使用Keras实现简单线性回归模型操作


Posted in Python onJune 12, 2020

神经网络可以用来模拟回归问题 (regression),实质上是单输入单输出神经网络模型,例如给下面一组数据,用一条线来对数据进行拟合,并可以预测新输入 x 的输出值。

使用Keras实现简单线性回归模型操作

一、详细解读

我们通过这个简单的例子来熟悉Keras构建神经网络的步骤:

1.导入模块并生成数据

首先导入本例子需要的模块,numpy、Matplotlib、和keras.models、keras.layers模块。Sequential是多个网络层的线性堆叠,可以通过向Sequential模型传递一个layer的list来构造该模型,也可以通过.add()方法一个个的将layer加入模型中。layers.Dense 意思是这个神经层是全连接层。

2.建立模型

然后用 Sequential 建立 model,再用 model.add 添加神经层,添加的是 Dense 全连接神经层。参数有两个,(注意此处Keras 2.0.2版本中有变更)一个是输入数据的维度,另一个units代表神经元数,即输出单元数。如果需要添加下一个神经层的时候,不用再定义输入的纬度,因为它默认就把前一层的输出作为当前层的输入。在这个简单的例子里,只需要一层就够了。

3.激活模型

model.compile来激活模型,参数中,误差函数用的是 mse均方误差;优化器用的是 sgd 随机梯度下降法。

4.训练模型

训练的时候用 model.train_on_batch 一批一批的训练 X_train, Y_train。默认的返回值是 cost,每100步输出一下结果。

5.验证模型

用到的函数是 model.evaluate,输入测试集的x和y,输出 cost,weights 和 biases。其中 weights 和 biases 是取在模型的第一层 model.layers[0] 学习到的参数。从学习到的结果你可以看到, weights 比较接近0.5,bias 接近 2。

Weights= [[ 0.49136472]]

biases= [ 2.00405312]

6.可视化学习结果

最后可以画出预测结果,与测试集的值进行对比。

使用Keras实现简单线性回归模型操作

二、完整代码

import numpy as np
np.random.seed(1337) 
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt
 
# 生成数据
X = np.linspace(-1, 1, 200) #在返回(-1, 1)范围内的等差序列
np.random.shuffle(X) # 打乱顺序
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, )) #生成Y并添加噪声
# plot
plt.scatter(X, Y)
plt.show()
 
X_train, Y_train = X[:160], Y[:160]  # 前160组数据为训练数据集
X_test, Y_test = X[160:], Y[160:]  #后40组数据为测试数据集
 
# 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim=1, units=1))
 
# 选定loss函数和优化器
model.compile(loss='mse', optimizer='sgd')
 
# 训练过程
print('Training -----------')
for step in range(501):
 cost = model.train_on_batch(X_train, Y_train)
 if step % 50 == 0:
  print("After %d trainings, the cost: %f" % (step, cost))
 
# 测试过程
print('\nTesting ------------')
cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost:', cost)
W, b = model.layers[0].get_weights()
print('Weights=', W, '\nbiases=', b)
 
# 将训练结果绘出
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

三、其他补充

1. numpy.linspace

numpy.linspace(start, stop, num=50, endpoint=True,retstep=False,dtype=None)

返回等差序列,序列范围在(start,end),生成num个元素的np数组,如果endpoint为False,则生成num+1个但是返回num个,retstep=True则在其后返回步长.

>>> np.linspace(2.0, 3.0, num=5)
array([ 2. , 2.25, 2.5 , 2.75, 3. ])
>>> np.linspace(2.0, 3.0, num=5, endpoint=False)
array([ 2. , 2.2, 2.4, 2.6, 2.8])
>>> np.linspace(2.0, 3.0, num=5, retstep=True)
(array([ 2. , 2.25, 2.5 , 2.75, 3. ]), 0.25)

以上这篇使用Keras实现简单线性回归模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
分析Python的Django框架的运行方式及处理流程
Apr 08 Python
Python中unittest模块做UT(单元测试)使用实例
Jun 12 Python
python爬取NUS-WIDE数据库图片
Oct 05 Python
浅谈python jieba分词模块的基本用法
Nov 09 Python
python3实现SMTP发送邮件详细教程
Jun 19 Python
numpy向空的二维数组中添加元素的方法
Nov 01 Python
flask应用部署到服务器的方法
Jul 12 Python
python tkinter实现屏保程序
Jul 30 Python
Python超越函数积分运算以及绘图实现代码
Nov 20 Python
django xadmin 管理器常用显示设置方式
Mar 11 Python
Python线程协作threading.Condition实现过程解析
Mar 12 Python
python实现图像全景拼接
Mar 27 Python
Python实现Keras搭建神经网络训练分类模型教程
Jun 12 #Python
简单了解Python变量作用域正确使用方法
Jun 12 #Python
keras 读取多标签图像数据方式
Jun 12 #Python
Python数据可视化图实现过程详解
Jun 12 #Python
浅谈cv2.imread()和keras.preprocessing中的image.load_img()区别
Jun 12 #Python
升级keras解决load_weights()中的未定义skip_mismatch关键字问题
Jun 12 #Python
解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题
Jun 12 #Python
You might like
php实现文件上传及头像预览功能
2017/01/15 PHP
jQuery 事件队列调整方法
2009/09/18 Javascript
jquery实现metro效果示例代码
2013/09/06 Javascript
Mac OS X 系统下安装和部署Egret引擎开发环境
2014/09/03 Javascript
JavaScript中的全局对象介绍
2015/01/01 Javascript
jQuery 3.0 的变化及使用方法
2016/02/01 Javascript
jQuery hover事件简单实现同时绑定2个方法
2016/06/07 Javascript
JavaScript必知必会(七)js对象继承
2016/06/08 Javascript
AngularJs定制样式插入到ueditor中的问题小结
2016/08/01 Javascript
JS简单实现点击复制链接的方法
2016/08/03 Javascript
jQuery事件用法详解
2016/10/06 Javascript
Vue.js实现的计算器功能完整示例
2018/07/11 Javascript
基于vue+axios+lrz.js微信端图片压缩上传方法
2019/06/25 Javascript
js实现网页随机验证码
2020/10/19 Javascript
修改NPM全局模式的默认安装路径的方法
2020/12/15 Javascript
Python编程语言的35个与众不同之处(语言特征和使用技巧)
2014/07/07 Python
Python实现数据库编程方法详解
2015/06/09 Python
使用Turtle画正螺旋线的方法
2017/09/22 Python
python利用正则表达式搜索单词示例代码
2017/09/24 Python
用不到50行的Python代码构建最小的区块链
2017/11/16 Python
python使用tkinter实现简单计算器
2018/01/30 Python
Ubuntu18.04下python版本完美切换的解决方法
2019/06/14 Python
在Python中append以及extend返回None的例子
2019/07/20 Python
python匿名函数lambda原理及实例解析
2020/02/07 Python
Python numpy多维数组实现原理详解
2020/03/10 Python
python如何支持并发方法详解
2020/07/25 Python
Python实现随机爬山算法
2021/01/29 Python
英国品牌男装折扣网站:Brown Bag
2018/03/08 全球购物
Zooplus葡萄牙:欧洲领先的网上宠物商店
2018/07/01 全球购物
动物科学专业毕业生的自我评价
2013/11/29 职场文书
毕业生大学生活自我总结
2014/01/31 职场文书
奉献爱心演讲稿
2014/09/04 职场文书
公民代理授权委托书
2014/09/24 职场文书
2015年学校办公室工作总结
2015/05/26 职场文书
课程设计感想范文
2015/08/11 职场文书
2016年学生会感恩节活动总结
2016/04/01 职场文书