使用Keras实现简单线性回归模型操作


Posted in Python onJune 12, 2020

神经网络可以用来模拟回归问题 (regression),实质上是单输入单输出神经网络模型,例如给下面一组数据,用一条线来对数据进行拟合,并可以预测新输入 x 的输出值。

使用Keras实现简单线性回归模型操作

一、详细解读

我们通过这个简单的例子来熟悉Keras构建神经网络的步骤:

1.导入模块并生成数据

首先导入本例子需要的模块,numpy、Matplotlib、和keras.models、keras.layers模块。Sequential是多个网络层的线性堆叠,可以通过向Sequential模型传递一个layer的list来构造该模型,也可以通过.add()方法一个个的将layer加入模型中。layers.Dense 意思是这个神经层是全连接层。

2.建立模型

然后用 Sequential 建立 model,再用 model.add 添加神经层,添加的是 Dense 全连接神经层。参数有两个,(注意此处Keras 2.0.2版本中有变更)一个是输入数据的维度,另一个units代表神经元数,即输出单元数。如果需要添加下一个神经层的时候,不用再定义输入的纬度,因为它默认就把前一层的输出作为当前层的输入。在这个简单的例子里,只需要一层就够了。

3.激活模型

model.compile来激活模型,参数中,误差函数用的是 mse均方误差;优化器用的是 sgd 随机梯度下降法。

4.训练模型

训练的时候用 model.train_on_batch 一批一批的训练 X_train, Y_train。默认的返回值是 cost,每100步输出一下结果。

5.验证模型

用到的函数是 model.evaluate,输入测试集的x和y,输出 cost,weights 和 biases。其中 weights 和 biases 是取在模型的第一层 model.layers[0] 学习到的参数。从学习到的结果你可以看到, weights 比较接近0.5,bias 接近 2。

Weights= [[ 0.49136472]]

biases= [ 2.00405312]

6.可视化学习结果

最后可以画出预测结果,与测试集的值进行对比。

使用Keras实现简单线性回归模型操作

二、完整代码

import numpy as np
np.random.seed(1337) 
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt
 
# 生成数据
X = np.linspace(-1, 1, 200) #在返回(-1, 1)范围内的等差序列
np.random.shuffle(X) # 打乱顺序
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, )) #生成Y并添加噪声
# plot
plt.scatter(X, Y)
plt.show()
 
X_train, Y_train = X[:160], Y[:160]  # 前160组数据为训练数据集
X_test, Y_test = X[160:], Y[160:]  #后40组数据为测试数据集
 
# 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim=1, units=1))
 
# 选定loss函数和优化器
model.compile(loss='mse', optimizer='sgd')
 
# 训练过程
print('Training -----------')
for step in range(501):
 cost = model.train_on_batch(X_train, Y_train)
 if step % 50 == 0:
  print("After %d trainings, the cost: %f" % (step, cost))
 
# 测试过程
print('\nTesting ------------')
cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost:', cost)
W, b = model.layers[0].get_weights()
print('Weights=', W, '\nbiases=', b)
 
# 将训练结果绘出
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

三、其他补充

1. numpy.linspace

numpy.linspace(start, stop, num=50, endpoint=True,retstep=False,dtype=None)

返回等差序列,序列范围在(start,end),生成num个元素的np数组,如果endpoint为False,则生成num+1个但是返回num个,retstep=True则在其后返回步长.

>>> np.linspace(2.0, 3.0, num=5)
array([ 2. , 2.25, 2.5 , 2.75, 3. ])
>>> np.linspace(2.0, 3.0, num=5, endpoint=False)
array([ 2. , 2.2, 2.4, 2.6, 2.8])
>>> np.linspace(2.0, 3.0, num=5, retstep=True)
(array([ 2. , 2.25, 2.5 , 2.75, 3. ]), 0.25)

以上这篇使用Keras实现简单线性回归模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现Mysql数据库连接池实例详解
Apr 11 Python
Python常见加密模块用法分析【MD5,sha,crypt模块】
May 24 Python
聊聊Python中的pypy
Jan 12 Python
浅谈Python中的私有变量
Feb 28 Python
python 高效去重复 支持GB级别大文件的示例代码
Nov 08 Python
Python中的引用知识点总结
May 20 Python
Python的numpy库下的几个小函数的用法(小结)
Jul 12 Python
python集合删除多种方法详解
Feb 10 Python
python程序输出无内容的解决方式
Apr 09 Python
pandas将list数据拆分成行或列的实现
Dec 13 Python
Python用SSH连接到网络设备
Feb 18 Python
python利用opencv实现颜色检测
Feb 23 Python
Python实现Keras搭建神经网络训练分类模型教程
Jun 12 #Python
简单了解Python变量作用域正确使用方法
Jun 12 #Python
keras 读取多标签图像数据方式
Jun 12 #Python
Python数据可视化图实现过程详解
Jun 12 #Python
浅谈cv2.imread()和keras.preprocessing中的image.load_img()区别
Jun 12 #Python
升级keras解决load_weights()中的未定义skip_mismatch关键字问题
Jun 12 #Python
解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题
Jun 12 #Python
You might like
用mysql_fetch_array()获取当前行数据的方法详解
2013/06/05 PHP
浅析php面向对象public private protected 访问修饰符
2013/06/30 PHP
php中file_get_contents与curl性能比较分析
2014/11/08 PHP
PHP中关于php.ini参数优化详解
2020/02/28 PHP
javascript四舍五入函数代码分享(保留后几位)
2013/12/10 Javascript
js关于字符长度限制的问题示例探讨
2014/01/24 Javascript
JS实现判断滚动条滚到页面底部并执行事件的方法
2014/12/18 Javascript
js实现点击图片改变页面背景图的方法
2015/02/28 Javascript
AspNet中使用JQuery上传插件Uploadify详解
2015/05/20 Javascript
JQuery Ajax WebService传递参数的简单实例
2016/11/02 Javascript
基于jQuery插件jqzoom实现的图片放大镜效果示例
2017/01/23 Javascript
bootstrap vue.js实现tab效果
2017/02/07 Javascript
js 实现省市区三级联动菜单效果
2017/02/20 Javascript
微信小程序动态显示项目倒计时效果
2017/06/13 Javascript
JavaScript数组push方法使用注意事项
2017/10/30 Javascript
JavaScript实现的级联算法示例【省市二级联动功能】
2018/12/25 Javascript
vue+elementUI 复杂表单的验证、数据提交方案问题
2019/06/24 Javascript
小程序中this.setData的使用和注意事项
2019/08/28 Javascript
使用layui 的layedit定义自己的toolbar方法
2019/09/18 Javascript
Python和php通信乱码问题解决方法
2014/04/15 Python
跟老齐学Python之有容乃大的list(4)
2014/09/28 Python
python使用三角迭代计算圆周率PI的方法
2015/03/20 Python
Python数据结构与算法之常见的分配排序法示例【桶排序与基数排序】
2017/12/15 Python
Python实现进程同步和通信的方法
2018/01/02 Python
详谈python在windows中的文件路径问题
2018/04/28 Python
pymongo中group by的操作方法教程
2019/03/22 Python
Python 类方法和实例方法(@classmethod),静态方法(@staticmethod)原理与用法分析
2019/09/20 Python
python爬虫 正则表达式解析
2019/09/28 Python
python读取raw binary图片并提取统计信息的实例
2020/01/09 Python
如何在django中实现分页功能
2020/04/22 Python
北美主要的汽车零部件零售商:AutoShack.com
2019/02/23 全球购物
晚宴邀请函范文
2014/01/15 职场文书
企业内控岗位的职责
2014/02/07 职场文书
在宿舍喝酒的检讨书
2014/09/28 职场文书
大学生毕业评语
2014/12/31 职场文书
建筑安全员岗位职责
2015/02/15 职场文书