使用Keras实现简单线性回归模型操作


Posted in Python onJune 12, 2020

神经网络可以用来模拟回归问题 (regression),实质上是单输入单输出神经网络模型,例如给下面一组数据,用一条线来对数据进行拟合,并可以预测新输入 x 的输出值。

使用Keras实现简单线性回归模型操作

一、详细解读

我们通过这个简单的例子来熟悉Keras构建神经网络的步骤:

1.导入模块并生成数据

首先导入本例子需要的模块,numpy、Matplotlib、和keras.models、keras.layers模块。Sequential是多个网络层的线性堆叠,可以通过向Sequential模型传递一个layer的list来构造该模型,也可以通过.add()方法一个个的将layer加入模型中。layers.Dense 意思是这个神经层是全连接层。

2.建立模型

然后用 Sequential 建立 model,再用 model.add 添加神经层,添加的是 Dense 全连接神经层。参数有两个,(注意此处Keras 2.0.2版本中有变更)一个是输入数据的维度,另一个units代表神经元数,即输出单元数。如果需要添加下一个神经层的时候,不用再定义输入的纬度,因为它默认就把前一层的输出作为当前层的输入。在这个简单的例子里,只需要一层就够了。

3.激活模型

model.compile来激活模型,参数中,误差函数用的是 mse均方误差;优化器用的是 sgd 随机梯度下降法。

4.训练模型

训练的时候用 model.train_on_batch 一批一批的训练 X_train, Y_train。默认的返回值是 cost,每100步输出一下结果。

5.验证模型

用到的函数是 model.evaluate,输入测试集的x和y,输出 cost,weights 和 biases。其中 weights 和 biases 是取在模型的第一层 model.layers[0] 学习到的参数。从学习到的结果你可以看到, weights 比较接近0.5,bias 接近 2。

Weights= [[ 0.49136472]]

biases= [ 2.00405312]

6.可视化学习结果

最后可以画出预测结果,与测试集的值进行对比。

使用Keras实现简单线性回归模型操作

二、完整代码

import numpy as np
np.random.seed(1337) 
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt
 
# 生成数据
X = np.linspace(-1, 1, 200) #在返回(-1, 1)范围内的等差序列
np.random.shuffle(X) # 打乱顺序
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, )) #生成Y并添加噪声
# plot
plt.scatter(X, Y)
plt.show()
 
X_train, Y_train = X[:160], Y[:160]  # 前160组数据为训练数据集
X_test, Y_test = X[160:], Y[160:]  #后40组数据为测试数据集
 
# 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim=1, units=1))
 
# 选定loss函数和优化器
model.compile(loss='mse', optimizer='sgd')
 
# 训练过程
print('Training -----------')
for step in range(501):
 cost = model.train_on_batch(X_train, Y_train)
 if step % 50 == 0:
  print("After %d trainings, the cost: %f" % (step, cost))
 
# 测试过程
print('\nTesting ------------')
cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost:', cost)
W, b = model.layers[0].get_weights()
print('Weights=', W, '\nbiases=', b)
 
# 将训练结果绘出
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

三、其他补充

1. numpy.linspace

numpy.linspace(start, stop, num=50, endpoint=True,retstep=False,dtype=None)

返回等差序列,序列范围在(start,end),生成num个元素的np数组,如果endpoint为False,则生成num+1个但是返回num个,retstep=True则在其后返回步长.

>>> np.linspace(2.0, 3.0, num=5)
array([ 2. , 2.25, 2.5 , 2.75, 3. ])
>>> np.linspace(2.0, 3.0, num=5, endpoint=False)
array([ 2. , 2.2, 2.4, 2.6, 2.8])
>>> np.linspace(2.0, 3.0, num=5, retstep=True)
(array([ 2. , 2.25, 2.5 , 2.75, 3. ]), 0.25)

以上这篇使用Keras实现简单线性回归模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中time()方法的使用的教程
May 22 Python
python操作ssh实现服务器日志下载的方法
Jun 03 Python
Pycharm编辑器技巧之自动导入模块详解
Jul 18 Python
Python 多线程的实例详解
Sep 07 Python
python 筛选数据集中列中value长度大于20的数据集方法
Jun 14 Python
python中pika模块问题的深入探究
Oct 13 Python
Django csrf 两种方法设置form的实例
Feb 03 Python
python数据分析:关键字提取方式
Feb 24 Python
python 函数嵌套及多函数共同运行知识点讲解
Mar 03 Python
python数据库开发之MongoDB安装及Python3操作MongoDB数据库详细方法与实例
Mar 18 Python
iPython pylab模式启动方式
Apr 24 Python
Linux安装Python3如何和系统自带的Python2并存
Jul 23 Python
Python实现Keras搭建神经网络训练分类模型教程
Jun 12 #Python
简单了解Python变量作用域正确使用方法
Jun 12 #Python
keras 读取多标签图像数据方式
Jun 12 #Python
Python数据可视化图实现过程详解
Jun 12 #Python
浅谈cv2.imread()和keras.preprocessing中的image.load_img()区别
Jun 12 #Python
升级keras解决load_weights()中的未定义skip_mismatch关键字问题
Jun 12 #Python
解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题
Jun 12 #Python
You might like
PHP实现将textarea的值根据回车换行拆分至数组
2015/06/10 PHP
thinkphp命名空间用法实例详解
2015/12/30 PHP
php解决DOM乱码的方法示例代码
2016/11/20 PHP
php常用字符串String函数实例总结【转换,替换,计算,截取,加密】
2016/12/07 PHP
setTimeout 不断吐食CPU的问题分析
2009/04/01 Javascript
jquery 查找iframe父级页面元素的实现代码
2011/08/28 Javascript
创建公共调用 jQuery Ajax 带返回值
2012/08/01 Javascript
jquery鼠标滑过提示title具体实现代码
2013/08/06 Javascript
多选列表框动态添加,移动,删除,全选等操作的简单实例
2014/01/13 Javascript
浅谈jQuery中的事件
2015/03/23 Javascript
js日期范围初始化得到前一个月日期的方法
2015/05/05 Javascript
jQuery中 prop() attr()使用详解
2015/05/19 Javascript
详解JavaScript中循环控制语句的用法
2015/06/03 Javascript
深入浅析JavaScript中prototype和proto的关系
2015/11/15 Javascript
javascript瀑布流式图片懒加载实例解析与优化
2016/02/23 Javascript
javascirpt实现2个iframe之间传值的方法
2016/06/30 Javascript
php输出全部gb2312编码内的汉字方法
2017/03/04 Javascript
javascript中json对象json数组json字符串互转及取值方法
2017/04/19 Javascript
Node.js pipe实现源码解析
2017/08/12 Javascript
深入理解ES7的async/await的用法
2017/09/09 Javascript
Vue Cli3 创建项目的方法步骤
2018/10/15 Javascript
vue2路由基本用法实例分析
2020/03/06 Javascript
[01:28:56]2014 DOTA2华西杯精英邀请赛 5 24 CIS VS DK
2014/05/26 DOTA
[01:12:27]EG vs Secret 2018国际邀请赛淘汰赛BO3 第二场 8.22
2018/08/23 DOTA
Python多进程multiprocessing.Pool类详解
2018/04/27 Python
VSCode中自动为Python文件添加头部注释
2019/11/14 Python
css3编写浏览器背景渐变背景色的方法
2018/03/05 HTML / CSS
如何在Canvas上的图形/图像绑定事件监听的实现
2020/09/16 HTML / CSS
Prototype如何更新局部页面
2013/03/03 面试题
《尊严》教学反思
2014/02/11 职场文书
2014年语文教师工作总结
2014/12/18 职场文书
培训通知书模板
2015/04/17 职场文书
债务纠纷代理词
2015/05/25 职场文书
红色电影观后感
2015/06/18 职场文书
2016年学校安全教育月活动总结
2016/04/06 职场文书
OpenFeign实现远程调用
2022/08/14 Java/Android