使用Keras实现简单线性回归模型操作


Posted in Python onJune 12, 2020

神经网络可以用来模拟回归问题 (regression),实质上是单输入单输出神经网络模型,例如给下面一组数据,用一条线来对数据进行拟合,并可以预测新输入 x 的输出值。

使用Keras实现简单线性回归模型操作

一、详细解读

我们通过这个简单的例子来熟悉Keras构建神经网络的步骤:

1.导入模块并生成数据

首先导入本例子需要的模块,numpy、Matplotlib、和keras.models、keras.layers模块。Sequential是多个网络层的线性堆叠,可以通过向Sequential模型传递一个layer的list来构造该模型,也可以通过.add()方法一个个的将layer加入模型中。layers.Dense 意思是这个神经层是全连接层。

2.建立模型

然后用 Sequential 建立 model,再用 model.add 添加神经层,添加的是 Dense 全连接神经层。参数有两个,(注意此处Keras 2.0.2版本中有变更)一个是输入数据的维度,另一个units代表神经元数,即输出单元数。如果需要添加下一个神经层的时候,不用再定义输入的纬度,因为它默认就把前一层的输出作为当前层的输入。在这个简单的例子里,只需要一层就够了。

3.激活模型

model.compile来激活模型,参数中,误差函数用的是 mse均方误差;优化器用的是 sgd 随机梯度下降法。

4.训练模型

训练的时候用 model.train_on_batch 一批一批的训练 X_train, Y_train。默认的返回值是 cost,每100步输出一下结果。

5.验证模型

用到的函数是 model.evaluate,输入测试集的x和y,输出 cost,weights 和 biases。其中 weights 和 biases 是取在模型的第一层 model.layers[0] 学习到的参数。从学习到的结果你可以看到, weights 比较接近0.5,bias 接近 2。

Weights= [[ 0.49136472]]

biases= [ 2.00405312]

6.可视化学习结果

最后可以画出预测结果,与测试集的值进行对比。

使用Keras实现简单线性回归模型操作

二、完整代码

import numpy as np
np.random.seed(1337) 
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt
 
# 生成数据
X = np.linspace(-1, 1, 200) #在返回(-1, 1)范围内的等差序列
np.random.shuffle(X) # 打乱顺序
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, )) #生成Y并添加噪声
# plot
plt.scatter(X, Y)
plt.show()
 
X_train, Y_train = X[:160], Y[:160]  # 前160组数据为训练数据集
X_test, Y_test = X[160:], Y[160:]  #后40组数据为测试数据集
 
# 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim=1, units=1))
 
# 选定loss函数和优化器
model.compile(loss='mse', optimizer='sgd')
 
# 训练过程
print('Training -----------')
for step in range(501):
 cost = model.train_on_batch(X_train, Y_train)
 if step % 50 == 0:
  print("After %d trainings, the cost: %f" % (step, cost))
 
# 测试过程
print('\nTesting ------------')
cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost:', cost)
W, b = model.layers[0].get_weights()
print('Weights=', W, '\nbiases=', b)
 
# 将训练结果绘出
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

三、其他补充

1. numpy.linspace

numpy.linspace(start, stop, num=50, endpoint=True,retstep=False,dtype=None)

返回等差序列,序列范围在(start,end),生成num个元素的np数组,如果endpoint为False,则生成num+1个但是返回num个,retstep=True则在其后返回步长.

>>> np.linspace(2.0, 3.0, num=5)
array([ 2. , 2.25, 2.5 , 2.75, 3. ])
>>> np.linspace(2.0, 3.0, num=5, endpoint=False)
array([ 2. , 2.2, 2.4, 2.6, 2.8])
>>> np.linspace(2.0, 3.0, num=5, retstep=True)
(array([ 2. , 2.25, 2.5 , 2.75, 3. ]), 0.25)

以上这篇使用Keras实现简单线性回归模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
优化Python代码使其加快作用域内的查找
Mar 30 Python
Python 私有函数的实例详解
Sep 11 Python
Python3安装Scrapy的方法步骤
Nov 23 Python
Python实现生成随机数据插入mysql数据库的方法
Dec 25 Python
Django使用Celery异步任务队列的使用
Mar 13 Python
Python continue继续循环用法总结
Jun 10 Python
python3 cvs将数据读取为字典的方法
Dec 22 Python
Pycharm 设置默认头的图文教程
Jan 17 Python
Python3远程监控程序的实现方法
Jul 15 Python
python global和nonlocal用法解析
Feb 03 Python
Django关于admin的使用技巧和知识点
Feb 10 Python
经验丰富程序员才知道的8种高级Python技巧
Jul 27 Python
Python实现Keras搭建神经网络训练分类模型教程
Jun 12 #Python
简单了解Python变量作用域正确使用方法
Jun 12 #Python
keras 读取多标签图像数据方式
Jun 12 #Python
Python数据可视化图实现过程详解
Jun 12 #Python
浅谈cv2.imread()和keras.preprocessing中的image.load_img()区别
Jun 12 #Python
升级keras解决load_weights()中的未定义skip_mismatch关键字问题
Jun 12 #Python
解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题
Jun 12 #Python
You might like
PHP的SQL注入过程分析
2012/01/06 PHP
JavaScript判断一个URL链接是否有效的实现方法
2011/10/08 Javascript
js点击选择文本的方法
2015/02/09 Javascript
分享一些常用的jQuery动画事件和动画函数
2015/11/27 Javascript
Bootstrap输入框组件简单实现代码
2017/03/06 Javascript
详解Vue.js中.native修饰符
2018/04/24 Javascript
jQuery常见的遍历DOM操作详解
2018/09/05 jQuery
JavaScript学习笔记之数组基本操作示例
2019/01/09 Javascript
Angular中使用ng-zorro图标库部分图标不能正常显示问题
2019/04/22 Javascript
从零到一详聊创建Vue工程及遇到的常见问题
2019/04/25 Javascript
利用d3.js制作连线动画图与编辑器的方法实例
2019/09/05 Javascript
浅谈vue使用axios的回调函数中this不指向vue实例,为undefined
2020/09/21 Javascript
python实现批量下载新浪博客的方法
2015/06/15 Python
python使用MySQLdb访问mysql数据库的方法
2015/08/03 Python
Python中字符串的常见操作技巧总结
2016/07/28 Python
对numpy中轴与维度的理解
2018/04/18 Python
Python3多线程操作简单示例
2018/05/22 Python
Python Tkinter模块实现时钟功能应用示例
2018/07/23 Python
python将视频转换为全字符视频
2019/04/26 Python
Python学习笔记之集合的概念和简单使用示例
2019/08/22 Python
python RC4加密操作示例【测试可用】
2019/09/26 Python
Pytorch 之修改Tensor部分值方式
2019/12/27 Python
python使用gdal对shp读取,新建和更新的实例
2020/03/10 Python
Tensorflow中的图(tf.Graph)和会话(tf.Session)的实现
2020/04/22 Python
django queryset 去重 .distinct()说明
2020/05/19 Python
台湾演唱会订票网站:StubHub台湾
2019/06/11 全球购物
中国领先的汽车保养服务平台:途虎养车
2019/10/18 全球购物
如何获取某个日期是当月的最后一天
2013/12/05 面试题
教师党员承诺书
2014/03/25 职场文书
反邪教标语
2014/06/23 职场文书
小学兴趣小组活动总结
2014/07/07 职场文书
学习十八大标语
2014/10/09 职场文书
各类场合主持词开场白范文集锦
2019/08/16 职场文书
pytest进阶教程之fixture函数详解
2021/03/29 Python
使用HTML+Css+transform实现3D导航栏的示例代码
2021/03/31 HTML / CSS
阿里云服务器部署mongodb的详细过程
2021/09/04 MongoDB