使用Keras实现简单线性回归模型操作


Posted in Python onJune 12, 2020

神经网络可以用来模拟回归问题 (regression),实质上是单输入单输出神经网络模型,例如给下面一组数据,用一条线来对数据进行拟合,并可以预测新输入 x 的输出值。

使用Keras实现简单线性回归模型操作

一、详细解读

我们通过这个简单的例子来熟悉Keras构建神经网络的步骤:

1.导入模块并生成数据

首先导入本例子需要的模块,numpy、Matplotlib、和keras.models、keras.layers模块。Sequential是多个网络层的线性堆叠,可以通过向Sequential模型传递一个layer的list来构造该模型,也可以通过.add()方法一个个的将layer加入模型中。layers.Dense 意思是这个神经层是全连接层。

2.建立模型

然后用 Sequential 建立 model,再用 model.add 添加神经层,添加的是 Dense 全连接神经层。参数有两个,(注意此处Keras 2.0.2版本中有变更)一个是输入数据的维度,另一个units代表神经元数,即输出单元数。如果需要添加下一个神经层的时候,不用再定义输入的纬度,因为它默认就把前一层的输出作为当前层的输入。在这个简单的例子里,只需要一层就够了。

3.激活模型

model.compile来激活模型,参数中,误差函数用的是 mse均方误差;优化器用的是 sgd 随机梯度下降法。

4.训练模型

训练的时候用 model.train_on_batch 一批一批的训练 X_train, Y_train。默认的返回值是 cost,每100步输出一下结果。

5.验证模型

用到的函数是 model.evaluate,输入测试集的x和y,输出 cost,weights 和 biases。其中 weights 和 biases 是取在模型的第一层 model.layers[0] 学习到的参数。从学习到的结果你可以看到, weights 比较接近0.5,bias 接近 2。

Weights= [[ 0.49136472]]

biases= [ 2.00405312]

6.可视化学习结果

最后可以画出预测结果,与测试集的值进行对比。

使用Keras实现简单线性回归模型操作

二、完整代码

import numpy as np
np.random.seed(1337) 
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt
 
# 生成数据
X = np.linspace(-1, 1, 200) #在返回(-1, 1)范围内的等差序列
np.random.shuffle(X) # 打乱顺序
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, )) #生成Y并添加噪声
# plot
plt.scatter(X, Y)
plt.show()
 
X_train, Y_train = X[:160], Y[:160]  # 前160组数据为训练数据集
X_test, Y_test = X[160:], Y[160:]  #后40组数据为测试数据集
 
# 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim=1, units=1))
 
# 选定loss函数和优化器
model.compile(loss='mse', optimizer='sgd')
 
# 训练过程
print('Training -----------')
for step in range(501):
 cost = model.train_on_batch(X_train, Y_train)
 if step % 50 == 0:
  print("After %d trainings, the cost: %f" % (step, cost))
 
# 测试过程
print('\nTesting ------------')
cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost:', cost)
W, b = model.layers[0].get_weights()
print('Weights=', W, '\nbiases=', b)
 
# 将训练结果绘出
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

三、其他补充

1. numpy.linspace

numpy.linspace(start, stop, num=50, endpoint=True,retstep=False,dtype=None)

返回等差序列,序列范围在(start,end),生成num个元素的np数组,如果endpoint为False,则生成num+1个但是返回num个,retstep=True则在其后返回步长.

>>> np.linspace(2.0, 3.0, num=5)
array([ 2. , 2.25, 2.5 , 2.75, 3. ])
>>> np.linspace(2.0, 3.0, num=5, endpoint=False)
array([ 2. , 2.2, 2.4, 2.6, 2.8])
>>> np.linspace(2.0, 3.0, num=5, retstep=True)
(array([ 2. , 2.25, 2.5 , 2.75, 3. ]), 0.25)

以上这篇使用Keras实现简单线性回归模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python二分法实现实例
Nov 21 Python
Django1.3添加app提示模块不存在的解决方法
Aug 26 Python
Python编程中的文件读写及相关的文件对象方法讲解
Jan 19 Python
python实现发送邮件功能
Jul 22 Python
python 获取当天凌晨零点的时间戳方法
May 22 Python
python实现字符串加密 生成唯一固定长度字符串
Mar 22 Python
Python用Try语句捕获异常的实例方法
Jun 26 Python
python实现DEM数据的阴影生成的方法
Jul 23 Python
Python整数与Numpy数据溢出问题解决
Sep 11 Python
python字符串替换re.sub()实例解析
Feb 09 Python
python 函数嵌套及多函数共同运行知识点讲解
Mar 03 Python
Python通过队列来实现进程间通信的示例
Oct 14 Python
Python实现Keras搭建神经网络训练分类模型教程
Jun 12 #Python
简单了解Python变量作用域正确使用方法
Jun 12 #Python
keras 读取多标签图像数据方式
Jun 12 #Python
Python数据可视化图实现过程详解
Jun 12 #Python
浅谈cv2.imread()和keras.preprocessing中的image.load_img()区别
Jun 12 #Python
升级keras解决load_weights()中的未定义skip_mismatch关键字问题
Jun 12 #Python
解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题
Jun 12 #Python
You might like
php chr() ord()中文截取乱码问题解决方法
2008/09/08 PHP
PHP 显示客户端IP与服务器IP的代码
2010/10/12 PHP
PHP读取txt文件的内容并赋值给数组的代码
2011/11/03 PHP
mysql,mysqli,PDO的各自不同介绍
2012/09/19 PHP
php实现xml转换数组的方法示例
2017/02/03 PHP
基于PHP的加载类操作以及其他两种魔术方法的应用实例
2017/08/28 PHP
thinkPHP框架中layer.js的封装与使用方法示例
2019/01/18 PHP
ExtJs的Date格式字符代码
2010/12/30 Javascript
Jquery 实现grid绑定模板
2015/01/28 Javascript
Jquery操作cookie记住用户名
2016/03/29 Javascript
浅析创建javascript对象的方法
2016/05/13 Javascript
Vue数据驱动模拟实现5
2017/01/13 Javascript
JS表单数据验证的正则表达式(常用)
2017/02/18 Javascript
详解JS中遍历语法的比较
2017/04/07 Javascript
vue数据传递--我有特殊的实现技巧
2018/03/20 Javascript
vue渲染时闪烁{{}}的问题及解决方法
2018/03/28 Javascript
js prototype深入理解及应用实例分析
2019/11/25 Javascript
javascript设计模式 ? 迭代器模式原理与用法实例分析
2020/04/17 Javascript
vue3.0中setup使用(两种用法)
2020/12/02 Vue.js
vue打开新窗口并实现传参的图文实例
2021/03/04 Vue.js
[03:09]DOTA2亚洲邀请赛 LGD战队出场宣传片
2015/02/07 DOTA
利用python实现简易版的贪吃蛇游戏(面向python小白)
2018/12/30 Python
详解python 破解网站反爬虫的两种简单方法
2020/02/09 Python
Python 自由定制表格的实现示例
2020/03/20 Python
Python基于jieba, wordcloud库生成中文词云
2020/05/13 Python
Node.js 和 Python之间该选择哪个?
2020/08/05 Python
不可轻视HTML5!App三年内将被html5顶替彻底消失
2015/11/18 HTML / CSS
HTML5实现预览本地图片
2016/02/17 HTML / CSS
Sephora丝芙兰菲律宾官方网站:购买化妆品和护肤品
2017/04/05 全球购物
Yves Rocher伊夫·黎雪美国官网:法国始创植物美肌1959
2019/01/09 全球购物
马来西亚在线购物:POPLOOK.com
2019/12/09 全球购物
SAZAC的动物连体衣和动物睡衣:Kigurumi Shop
2020/03/14 全球购物
商场总经理岗位职责
2014/02/03 职场文书
领导欢迎词范文
2015/01/26 职场文书
活动费用申请报告
2015/05/15 职场文书
《女娲补天》读后感5篇
2019/12/31 职场文书