使用Keras实现简单线性回归模型操作


Posted in Python onJune 12, 2020

神经网络可以用来模拟回归问题 (regression),实质上是单输入单输出神经网络模型,例如给下面一组数据,用一条线来对数据进行拟合,并可以预测新输入 x 的输出值。

使用Keras实现简单线性回归模型操作

一、详细解读

我们通过这个简单的例子来熟悉Keras构建神经网络的步骤:

1.导入模块并生成数据

首先导入本例子需要的模块,numpy、Matplotlib、和keras.models、keras.layers模块。Sequential是多个网络层的线性堆叠,可以通过向Sequential模型传递一个layer的list来构造该模型,也可以通过.add()方法一个个的将layer加入模型中。layers.Dense 意思是这个神经层是全连接层。

2.建立模型

然后用 Sequential 建立 model,再用 model.add 添加神经层,添加的是 Dense 全连接神经层。参数有两个,(注意此处Keras 2.0.2版本中有变更)一个是输入数据的维度,另一个units代表神经元数,即输出单元数。如果需要添加下一个神经层的时候,不用再定义输入的纬度,因为它默认就把前一层的输出作为当前层的输入。在这个简单的例子里,只需要一层就够了。

3.激活模型

model.compile来激活模型,参数中,误差函数用的是 mse均方误差;优化器用的是 sgd 随机梯度下降法。

4.训练模型

训练的时候用 model.train_on_batch 一批一批的训练 X_train, Y_train。默认的返回值是 cost,每100步输出一下结果。

5.验证模型

用到的函数是 model.evaluate,输入测试集的x和y,输出 cost,weights 和 biases。其中 weights 和 biases 是取在模型的第一层 model.layers[0] 学习到的参数。从学习到的结果你可以看到, weights 比较接近0.5,bias 接近 2。

Weights= [[ 0.49136472]]

biases= [ 2.00405312]

6.可视化学习结果

最后可以画出预测结果,与测试集的值进行对比。

使用Keras实现简单线性回归模型操作

二、完整代码

import numpy as np
np.random.seed(1337) 
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt
 
# 生成数据
X = np.linspace(-1, 1, 200) #在返回(-1, 1)范围内的等差序列
np.random.shuffle(X) # 打乱顺序
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, )) #生成Y并添加噪声
# plot
plt.scatter(X, Y)
plt.show()
 
X_train, Y_train = X[:160], Y[:160]  # 前160组数据为训练数据集
X_test, Y_test = X[160:], Y[160:]  #后40组数据为测试数据集
 
# 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim=1, units=1))
 
# 选定loss函数和优化器
model.compile(loss='mse', optimizer='sgd')
 
# 训练过程
print('Training -----------')
for step in range(501):
 cost = model.train_on_batch(X_train, Y_train)
 if step % 50 == 0:
  print("After %d trainings, the cost: %f" % (step, cost))
 
# 测试过程
print('\nTesting ------------')
cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost:', cost)
W, b = model.layers[0].get_weights()
print('Weights=', W, '\nbiases=', b)
 
# 将训练结果绘出
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

三、其他补充

1. numpy.linspace

numpy.linspace(start, stop, num=50, endpoint=True,retstep=False,dtype=None)

返回等差序列,序列范围在(start,end),生成num个元素的np数组,如果endpoint为False,则生成num+1个但是返回num个,retstep=True则在其后返回步长.

>>> np.linspace(2.0, 3.0, num=5)
array([ 2. , 2.25, 2.5 , 2.75, 3. ])
>>> np.linspace(2.0, 3.0, num=5, endpoint=False)
array([ 2. , 2.2, 2.4, 2.6, 2.8])
>>> np.linspace(2.0, 3.0, num=5, retstep=True)
(array([ 2. , 2.25, 2.5 , 2.75, 3. ]), 0.25)

以上这篇使用Keras实现简单线性回归模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Docker上部署Python的Flask框架的教程
Apr 08 Python
python返回昨天日期的方法
May 13 Python
python使用matplotlib绘制柱状图教程
Feb 08 Python
Python探索之pLSA实现代码
Oct 25 Python
python SMTP实现发送带附件电子邮件
May 22 Python
Python实现字典(dict)的迭代操作示例
Jun 05 Python
win10 64bit下python NLTK安装教程
Sep 19 Python
win10下tensorflow和matplotlib安装教程
Sep 19 Python
Python在终端通过pip安装好包以后在Pycharm中依然无法使用的问题(三种解决方案)
Mar 10 Python
python print 格式化输出,动态指定长度的实现
Apr 12 Python
Python dict的常用方法示例代码
Jun 23 Python
Python pathlib模块使用方法及实例解析
Oct 05 Python
Python实现Keras搭建神经网络训练分类模型教程
Jun 12 #Python
简单了解Python变量作用域正确使用方法
Jun 12 #Python
keras 读取多标签图像数据方式
Jun 12 #Python
Python数据可视化图实现过程详解
Jun 12 #Python
浅谈cv2.imread()和keras.preprocessing中的image.load_img()区别
Jun 12 #Python
升级keras解决load_weights()中的未定义skip_mismatch关键字问题
Jun 12 #Python
解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题
Jun 12 #Python
You might like
php 生成签名及验证签名详解
2016/10/26 PHP
php 反斜杠处理函数addslashes()和stripslashes()实例详解
2016/12/25 PHP
ThinkPHP框架实现的邮箱激活功能示例
2018/06/15 PHP
php写入mysql中文乱码的实例解决方法
2019/09/17 PHP
PHP全局使用Laravel辅助函数dd
2019/12/26 PHP
HTML node相关的一些资料整理
2010/01/01 Javascript
Javascript创建Silverlight Plugin以及自定义nonSilverlight和lowSilverlight样式
2010/06/28 Javascript
javaScript数组迭代方法详解
2016/04/14 Javascript
即将发布的jQuery 3 有哪些新特性
2016/04/14 Javascript
如何解决IONIC页面底部被遮住无法向上滚动问题
2016/09/06 Javascript
JS动态的把左边列表添加到右边的实现代码(可上下移动)
2016/11/17 Javascript
vue组件如何被其他项目引用
2017/04/13 Javascript
Vue cli+mui 区域滚动的实例代码
2018/01/25 Javascript
Vue.js图片预览插件使用详解
2018/08/27 Javascript
微信小程序webview组件交互,内联h5页面并网页实现微信支付实现解析
2019/08/16 Javascript
antd vue 刷新保留当前页面路由,保留选中菜单,保留menu选中操作
2020/08/06 Javascript
如何在selenium中使用js实现定位
2020/08/18 Javascript
关于vue-cli3打包代码后白屏的解决方案
2020/09/02 Javascript
在Ubuntu系统下安装使用Python的GUI工具wxPython
2016/02/18 Python
Python实现快速排序算法及去重的快速排序的简单示例
2016/06/26 Python
Flask框架的学习指南之制作简单blog系统
2016/11/20 Python
Python AES加密模块用法分析
2017/05/22 Python
Python使用 Beanstalkd 做异步任务处理的方法
2018/04/24 Python
浅谈pycharm下找不到sqlalchemy的问题
2018/12/03 Python
pytorch获取vgg16-feature层输出的例子
2019/08/20 Python
详解python内置常用高阶函数(列出了5个常用的)
2020/02/21 Python
python shell命令行中import多层目录下的模块操作
2020/03/09 Python
英国和世界各地鲜花速递专家:Arena Flowers
2018/02/10 全球购物
ManoMano英国:欧洲第一家专注于DIY和园艺市场的电商平台
2020/03/12 全球购物
c/c++某大公司的两道笔试题
2014/02/02 面试题
杭州信雅达系统.NET工程师面试试题
2015/02/08 面试题
服务员岗位责任制
2014/02/11 职场文书
竞选班长的演讲稿
2014/04/24 职场文书
商业融资计划书
2014/04/29 职场文书
导游词之峨眉乐山/兵马俑/北京故宫御花园
2019/09/03 职场文书
Qt数据库应用之实现图片转pdf
2022/06/01 Java/Android