使用Keras实现简单线性回归模型操作


Posted in Python onJune 12, 2020

神经网络可以用来模拟回归问题 (regression),实质上是单输入单输出神经网络模型,例如给下面一组数据,用一条线来对数据进行拟合,并可以预测新输入 x 的输出值。

使用Keras实现简单线性回归模型操作

一、详细解读

我们通过这个简单的例子来熟悉Keras构建神经网络的步骤:

1.导入模块并生成数据

首先导入本例子需要的模块,numpy、Matplotlib、和keras.models、keras.layers模块。Sequential是多个网络层的线性堆叠,可以通过向Sequential模型传递一个layer的list来构造该模型,也可以通过.add()方法一个个的将layer加入模型中。layers.Dense 意思是这个神经层是全连接层。

2.建立模型

然后用 Sequential 建立 model,再用 model.add 添加神经层,添加的是 Dense 全连接神经层。参数有两个,(注意此处Keras 2.0.2版本中有变更)一个是输入数据的维度,另一个units代表神经元数,即输出单元数。如果需要添加下一个神经层的时候,不用再定义输入的纬度,因为它默认就把前一层的输出作为当前层的输入。在这个简单的例子里,只需要一层就够了。

3.激活模型

model.compile来激活模型,参数中,误差函数用的是 mse均方误差;优化器用的是 sgd 随机梯度下降法。

4.训练模型

训练的时候用 model.train_on_batch 一批一批的训练 X_train, Y_train。默认的返回值是 cost,每100步输出一下结果。

5.验证模型

用到的函数是 model.evaluate,输入测试集的x和y,输出 cost,weights 和 biases。其中 weights 和 biases 是取在模型的第一层 model.layers[0] 学习到的参数。从学习到的结果你可以看到, weights 比较接近0.5,bias 接近 2。

Weights= [[ 0.49136472]]

biases= [ 2.00405312]

6.可视化学习结果

最后可以画出预测结果,与测试集的值进行对比。

使用Keras实现简单线性回归模型操作

二、完整代码

import numpy as np
np.random.seed(1337) 
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt
 
# 生成数据
X = np.linspace(-1, 1, 200) #在返回(-1, 1)范围内的等差序列
np.random.shuffle(X) # 打乱顺序
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, )) #生成Y并添加噪声
# plot
plt.scatter(X, Y)
plt.show()
 
X_train, Y_train = X[:160], Y[:160]  # 前160组数据为训练数据集
X_test, Y_test = X[160:], Y[160:]  #后40组数据为测试数据集
 
# 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim=1, units=1))
 
# 选定loss函数和优化器
model.compile(loss='mse', optimizer='sgd')
 
# 训练过程
print('Training -----------')
for step in range(501):
 cost = model.train_on_batch(X_train, Y_train)
 if step % 50 == 0:
  print("After %d trainings, the cost: %f" % (step, cost))
 
# 测试过程
print('\nTesting ------------')
cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost:', cost)
W, b = model.layers[0].get_weights()
print('Weights=', W, '\nbiases=', b)
 
# 将训练结果绘出
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

三、其他补充

1. numpy.linspace

numpy.linspace(start, stop, num=50, endpoint=True,retstep=False,dtype=None)

返回等差序列,序列范围在(start,end),生成num个元素的np数组,如果endpoint为False,则生成num+1个但是返回num个,retstep=True则在其后返回步长.

>>> np.linspace(2.0, 3.0, num=5)
array([ 2. , 2.25, 2.5 , 2.75, 3. ])
>>> np.linspace(2.0, 3.0, num=5, endpoint=False)
array([ 2. , 2.2, 2.4, 2.6, 2.8])
>>> np.linspace(2.0, 3.0, num=5, retstep=True)
(array([ 2. , 2.25, 2.5 , 2.75, 3. ]), 0.25)

以上这篇使用Keras实现简单线性回归模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现合并字典的方法
Jul 07 Python
利用Python如何生成随机密码
Apr 20 Python
Python自动化运维之IP地址处理模块详解
Dec 10 Python
Python模块文件结构代码详解
Feb 03 Python
python中sys.argv函数精简概括
Jul 08 Python
Django 对象关系映射(ORM)源码详解
Aug 06 Python
python SocketServer源码深入解读
Sep 17 Python
详解python中各种文件打开模式
Jan 19 Python
Python random库使用方法及异常处理方案
Mar 02 Python
Python 通过爬虫实现GitHub网页的模拟登录的示例代码
Aug 17 Python
Python使用lambda抛出异常实现方法解析
Aug 20 Python
Python 实现PS滤镜的旋涡特效
Dec 03 Python
Python实现Keras搭建神经网络训练分类模型教程
Jun 12 #Python
简单了解Python变量作用域正确使用方法
Jun 12 #Python
keras 读取多标签图像数据方式
Jun 12 #Python
Python数据可视化图实现过程详解
Jun 12 #Python
浅谈cv2.imread()和keras.preprocessing中的image.load_img()区别
Jun 12 #Python
升级keras解决load_weights()中的未定义skip_mismatch关键字问题
Jun 12 #Python
解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题
Jun 12 #Python
You might like
snoopy 强大的PHP采集类使用实例代码
2010/12/09 PHP
下拉列表多级联动dropDownList示例代码
2013/06/27 PHP
Zend Framework分页类用法详解
2016/03/22 PHP
jquery tab插件精简版分享
2011/09/10 Javascript
如何用JavaScript动态呼叫函数(两种方式)
2013/05/03 Javascript
原生javascript实现图片弹窗交互效果
2015/01/12 Javascript
Jquery Ajax xmlhttp请求成功问题
2015/02/04 Javascript
Bootstrap框架结合jQuery仿百度换肤功能实例解析
2016/09/17 Javascript
浅谈js对象的创建和对6种继承模式的理解和遐想
2016/10/16 Javascript
Bootstrap实现导航栏的2种方式
2016/11/28 Javascript
激动人心的 Angular HttpClient的源码解析
2017/07/10 Javascript
vue: WebStorm设置快速编译运行的方法
2018/10/18 Javascript
jQuery简单实现根据日期计算星期几的方法
2019/01/09 jQuery
vue element upload组件 file-list的动态绑定实现
2019/10/11 Javascript
JS实现的雪花飘落特效示例
2019/12/03 Javascript
使用Python的判断语句模拟三目运算
2015/04/24 Python
python PIL模块与随机生成中文验证码
2016/02/27 Python
PyQt实现界面翻转切换效果
2018/04/20 Python
python组合无重复三位数的实例
2018/11/13 Python
Python中请不要再用re.compile了
2019/06/30 Python
使用Python实现 学生学籍管理系统
2019/11/26 Python
python字典与json转换的方法总结
2020/12/28 Python
python基于爬虫+django,打造个性化API接口
2021/01/21 Python
对CSS3选择器的研究(详解)
2016/09/16 HTML / CSS
Elemental Herbology官网:英国美容品牌
2019/04/27 全球购物
伦敦最受欢迎的蛋糕店:Konditor & Cook
2019/11/01 全球购物
美国专业消费电子及摄影器材网站:B&H Photo Video
2019/12/18 全球购物
什么情况下你必须要把一个类定义为abstract的
2013/01/06 面试题
初二物理教学反思
2014/01/29 职场文书
家长对孩子评语
2014/01/30 职场文书
学校十一活动方案
2014/02/01 职场文书
《水乡歌》教学反思
2014/04/24 职场文书
2014年语文教研组工作总结
2014/12/06 职场文书
年度考核表个人总结
2015/03/06 职场文书
环保主题班会教案
2015/08/13 职场文书
比较几种Redis集群方案
2021/06/21 Redis