使用Keras实现简单线性回归模型操作


Posted in Python onJune 12, 2020

神经网络可以用来模拟回归问题 (regression),实质上是单输入单输出神经网络模型,例如给下面一组数据,用一条线来对数据进行拟合,并可以预测新输入 x 的输出值。

使用Keras实现简单线性回归模型操作

一、详细解读

我们通过这个简单的例子来熟悉Keras构建神经网络的步骤:

1.导入模块并生成数据

首先导入本例子需要的模块,numpy、Matplotlib、和keras.models、keras.layers模块。Sequential是多个网络层的线性堆叠,可以通过向Sequential模型传递一个layer的list来构造该模型,也可以通过.add()方法一个个的将layer加入模型中。layers.Dense 意思是这个神经层是全连接层。

2.建立模型

然后用 Sequential 建立 model,再用 model.add 添加神经层,添加的是 Dense 全连接神经层。参数有两个,(注意此处Keras 2.0.2版本中有变更)一个是输入数据的维度,另一个units代表神经元数,即输出单元数。如果需要添加下一个神经层的时候,不用再定义输入的纬度,因为它默认就把前一层的输出作为当前层的输入。在这个简单的例子里,只需要一层就够了。

3.激活模型

model.compile来激活模型,参数中,误差函数用的是 mse均方误差;优化器用的是 sgd 随机梯度下降法。

4.训练模型

训练的时候用 model.train_on_batch 一批一批的训练 X_train, Y_train。默认的返回值是 cost,每100步输出一下结果。

5.验证模型

用到的函数是 model.evaluate,输入测试集的x和y,输出 cost,weights 和 biases。其中 weights 和 biases 是取在模型的第一层 model.layers[0] 学习到的参数。从学习到的结果你可以看到, weights 比较接近0.5,bias 接近 2。

Weights= [[ 0.49136472]]

biases= [ 2.00405312]

6.可视化学习结果

最后可以画出预测结果,与测试集的值进行对比。

使用Keras实现简单线性回归模型操作

二、完整代码

import numpy as np
np.random.seed(1337) 
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt
 
# 生成数据
X = np.linspace(-1, 1, 200) #在返回(-1, 1)范围内的等差序列
np.random.shuffle(X) # 打乱顺序
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, )) #生成Y并添加噪声
# plot
plt.scatter(X, Y)
plt.show()
 
X_train, Y_train = X[:160], Y[:160]  # 前160组数据为训练数据集
X_test, Y_test = X[160:], Y[160:]  #后40组数据为测试数据集
 
# 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim=1, units=1))
 
# 选定loss函数和优化器
model.compile(loss='mse', optimizer='sgd')
 
# 训练过程
print('Training -----------')
for step in range(501):
 cost = model.train_on_batch(X_train, Y_train)
 if step % 50 == 0:
  print("After %d trainings, the cost: %f" % (step, cost))
 
# 测试过程
print('\nTesting ------------')
cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost:', cost)
W, b = model.layers[0].get_weights()
print('Weights=', W, '\nbiases=', b)
 
# 将训练结果绘出
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

三、其他补充

1. numpy.linspace

numpy.linspace(start, stop, num=50, endpoint=True,retstep=False,dtype=None)

返回等差序列,序列范围在(start,end),生成num个元素的np数组,如果endpoint为False,则生成num+1个但是返回num个,retstep=True则在其后返回步长.

>>> np.linspace(2.0, 3.0, num=5)
array([ 2. , 2.25, 2.5 , 2.75, 3. ])
>>> np.linspace(2.0, 3.0, num=5, endpoint=False)
array([ 2. , 2.2, 2.4, 2.6, 2.8])
>>> np.linspace(2.0, 3.0, num=5, retstep=True)
(array([ 2. , 2.25, 2.5 , 2.75, 3. ]), 0.25)

以上这篇使用Keras实现简单线性回归模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现遍历目录的方法【测试可用】
Mar 22 Python
Python中sort和sorted函数代码解析
Jan 25 Python
python使用Matplotlib绘制分段函数
Sep 25 Python
对python中的高效迭代器函数详解
Oct 18 Python
PyQt5实现类似别踩白块游戏
Jan 24 Python
Python使用Paramiko控制liunx第三方库
May 20 Python
解决pycharm中的run和debug失效无法点击运行
Jun 09 Python
python BeautifulSoup库的安装与使用
Dec 17 Python
Python使用cn2an实现中文数字与阿拉伯数字的相互转换
Mar 02 Python
学会Python数据可视化必须尝试这7个库
Jun 16 Python
利用Python读取微信朋友圈的多种方法总结
Aug 23 Python
详解Python+OpenCV进行基础的图像操作
Feb 15 Python
Python实现Keras搭建神经网络训练分类模型教程
Jun 12 #Python
简单了解Python变量作用域正确使用方法
Jun 12 #Python
keras 读取多标签图像数据方式
Jun 12 #Python
Python数据可视化图实现过程详解
Jun 12 #Python
浅谈cv2.imread()和keras.preprocessing中的image.load_img()区别
Jun 12 #Python
升级keras解决load_weights()中的未定义skip_mismatch关键字问题
Jun 12 #Python
解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题
Jun 12 #Python
You might like
php执行sql语句的写法
2009/03/10 PHP
php变量范围介绍
2012/10/15 PHP
Mysql中分页查询的两个解决方法比较
2013/05/02 PHP
Symfony生成二维码的方法
2016/02/04 PHP
用JavaScript事件串连执行多个处理过程的方法
2007/03/09 Javascript
jQuery Flash/MP3/Video多媒体插件
2010/01/18 Javascript
一个XML格式数据转换为图表的例子
2010/02/09 Javascript
基于jquery创建的一个图片、视频缓冲的效果样式插件
2012/08/28 Javascript
jQuery实现表头固定效果的实例代码
2013/05/24 Javascript
picLazyLoad 实现图片延时加载(包含背景图片)
2016/07/21 Javascript
js与jquery分别实现tab标签页功能的方法
2016/11/18 Javascript
JS匿名函数类生成方式实例分析
2016/11/26 Javascript
JS实现的简单拖拽功能示例
2017/03/13 Javascript
详解axios中封装使用、拦截特定请求、判断所有请求加载完毕)
2019/04/09 Javascript
js设计模式之代理模式及订阅发布模式实例详解
2019/08/15 Javascript
express中static中间件的具体使用方法
2019/10/17 Javascript
Ajax获取node服务器数据的完整步骤
2020/09/20 Javascript
在Python中使用Neo4j数据库的教程
2015/04/16 Python
Python实现队列的方法
2015/05/26 Python
轻松掌握python设计模式之访问者模式
2016/11/18 Python
从头学Python之编写可执行的.py文件
2017/11/28 Python
python实现逆序输出一个数字的示例讲解
2018/06/25 Python
解决matplotlib.pyplot在Jupyter notebook中不显示图像问题
2020/04/22 Python
HTML5图片层叠的实现示例
2020/07/07 HTML / CSS
如何提高SQL Server的安全性
2016/07/25 面试题
生物技术毕业生自荐信
2013/10/23 职场文书
工商管理实习生自我鉴定范文
2013/12/18 职场文书
网上开商店的创业计划书
2014/01/19 职场文书
交通安全寄语大全
2014/04/08 职场文书
元旦寄语大全
2014/04/10 职场文书
青年志愿者先进事迹
2014/05/06 职场文书
校运会口号
2014/06/18 职场文书
公司活动总结范文
2014/07/01 职场文书
2016年世界艾滋病日宣传活动总结
2016/04/01 职场文书
导游词之徐州云龙湖
2019/11/19 职场文书
python基础之类方法和静态方法
2021/10/24 Python