python机器学习之神经网络实现


Posted in Python onOctober 13, 2018

神经网络在机器学习中有很大的应用,甚至涉及到方方面面。本文主要是简单介绍一下神经网络的基本理论概念和推算。同时也会介绍一下神经网络在数据分类方面的应用。

首先,当我们建立一个回归和分类模型的时候,无论是用最小二乘法(OLS)还是最大似然值(MLE)都用来使得残差达到最小。因此我们在建立模型的时候,都会有一个loss function。

而在神经网络里也不例外,也有个类似的loss function。

对回归而言:

python机器学习之神经网络实现

对分类而言:

python机器学习之神经网络实现

然后同样方法,对于W开始求导,求导为零就可以求出极值来。

关于式子中的W。我们在这里以三层的神经网络为例。先介绍一下神经网络的相关参数。

python机器学习之神经网络实现

第一层是输入层,第二层是隐藏层,第三层是输出层。

在X1,X2经过W1的加权后,达到隐藏层,然后经过W2的加权,到达输出层

其中,

python机器学习之神经网络实现

我们有:

python机器学习之神经网络实现

至此,我们建立了一个初级的三层神经网络。

当我们要求其的loss function最小时,我们需要逆向来求,也就是所谓的backpropagation。

我们要分别对W1和W2进行求导,然后求出其极值。

从右手边开始逆推,首先对W2进行求导。

代入损失函数公式:

python机器学习之神经网络实现

python机器学习之神经网络实现

然后,我们进行化简:

python机器学习之神经网络实现

化简到这里,我们同理再对W1进行求导。

python机器学习之神经网络实现

我们可以发现当我们在做bp网络时候,有一个逆推回去的误差项,其决定了loss function 的最终大小。

在实际的运算当中,我们会用到梯度求解,来求出极值点。

python机器学习之神经网络实现

总结一下来说,我们使用向前推进来理顺神经网络做到回归分类等模型。而向后推进来计算他的损失函数,使得参数W有一个最优解。

当然,和线性回归等模型相类似的是,我们也可以加上正则化的项来对W参数进行约束,以免使得模型的偏差太小,而导致在测试集的表现不佳。

python机器学习之神经网络实现

python机器学习之神经网络实现

Python 的实现:

使用了KERAS的库

解决线性回归: 

model.add(Dense(1, input_dim=n_features, activation='linear', use_bias=True))

# Use mean squared error for the loss metric and use the ADAM backprop algorithm
model.compile(loss='mean_squared_error', optimizer='adam')

# Train the network (learn the weights)
# We need to convert from DataFrame to NumpyArray
history = model.fit(X_train.values, y_train.values, epochs=100, 
     batch_size=1, verbose=2, validation_split=0)

解决多重分类问题: 

# create model
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=n_features))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
# Softmax output layer
model.add(Dense(7, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(X_train.values, y_train.values, epochs=20, batch_size=16)

y_pred = model.predict(X_test.values)

y_te = np.argmax(y_test.values, axis = 1)
y_pr = np.argmax(y_pred, axis = 1)

print(np.unique(y_pr))

print(classification_report(y_te, y_pr))

print(confusion_matrix(y_te, y_pr))

当我们选取最优参数时候,有很多种解决的途径。这里就介绍一种是gridsearchcv的方法,这是一种暴力检索的方法,遍历所有的设定参数来求得最优参数。

from sklearn.model_selection import GridSearchCV

def create_model(optimizer='rmsprop'):
 model = Sequential()
 model.add(Dense(64, activation='relu', input_dim=n_features))
 model.add(Dropout(0.5))
 model.add(Dense(64, activation='relu'))
 model.add(Dropout(0.5))
 model.add(Dense(7, activation='softmax'))
 model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
 
 return model

model = KerasClassifier(build_fn=create_model, verbose=0)

optimizers = ['rmsprop']
epochs = [5, 10, 15]
batches = [128]


param_grid = dict(optimizer=optimizers, epochs=epochs, batch_size=batches, verbose=['2'])
grid = GridSearchCV(estimator=model, param_grid=param_grid)

grid.fit(X_train.values, y_train.values)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
忘记ftp密码使用python ftplib库暴力破解密码的方法示例
Jan 22 Python
python smtplib模块发送SSL/TLS安全邮件实例
Apr 08 Python
Python写入CSV文件的方法
Jul 08 Python
详解Python各大聊天系统的屏蔽脏话功能原理
Dec 01 Python
python3爬取各类天气信息
Feb 24 Python
用python与文件进行交互的方法
Mar 01 Python
15行Python代码实现网易云热门歌单实例教程
Mar 10 Python
使用python socket分发大文件的实现方法
Jul 08 Python
CentOS 7如何实现定时执行python脚本
Jun 24 Python
Python数据分析库pandas高级接口dt的使用详解
Dec 11 Python
python 多线程爬取壁纸网站的示例
Feb 20 Python
python_tkinter弹出对话框创建
Mar 20 Python
Python pyinotify模块实现对文档的实时监控功能方法
Oct 13 #Python
基于pycharm导入模块显示不存在的解决方法
Oct 13 #Python
解决PyCharm import torch包失败的问题
Oct 13 #Python
python3+requests接口自动化session操作方法
Oct 13 #Python
解决pycharm无法识别本地site-packages的问题
Oct 13 #Python
解决PyCharm同目录下导入模块会报错的问题
Oct 13 #Python
python中单例常用的几种实现方法总结
Oct 13 #Python
You might like
PHP下使用CURL方式POST数据至API接口的代码
2013/02/14 PHP
关于PHP模板Smarty的初级使用方法以及心得分享
2013/06/21 PHP
php处理restful请求的路由类分享
2014/02/27 PHP
一些技巧性实用js代码小结
2009/10/14 Javascript
jQuery 获取URL参数的插件
2010/03/04 Javascript
如何确保JavaScript的执行顺序 之jQuery.html深度分析
2011/03/03 Javascript
javascript中节点的最近的相关节点访问方法
2013/03/20 Javascript
js验证身份证号有效性并提示对应信息
2015/10/19 Javascript
Javascript之Number对象介绍
2016/06/07 Javascript
jQuery原理系列-常用Dom操作详解
2016/06/07 Javascript
自动化测试读写64位操作系统的注册表
2016/08/15 Javascript
JavaScript实现图片拖曳效果
2017/09/08 Javascript
node版本管理工具n包使用教程详解
2018/11/09 Javascript
微信小程序MUI导航栏透明渐变功能示例(通过改变opacity实现)
2019/01/24 Javascript
JS实现前端路由功能示例【原生路由】
2020/05/29 Javascript
python 实现红包随机生成算法的简单实例
2017/01/04 Python
Pandas过滤dataframe中包含特定字符串的数据方法
2018/11/07 Python
Python 处理图片像素点的实例
2019/01/08 Python
详解python中递归函数
2019/04/16 Python
python文件转为exe文件的方法及用法详解
2019/07/08 Python
pandas DataFrame行或列的删除方法的实现示例
2019/08/02 Python
PyCharm MySQL可视化Database配置过程图解
2020/06/09 Python
python实现杨辉三角的几种方法代码实例
2021/03/02 Python
html5自带表单验证体验优化及提示气泡修改功能
2017/09/12 HTML / CSS
写一个函数,求一个字符串的长度。在main函数中输入字符串,并输出其长度
2015/11/18 面试题
别名指示符是什么
2012/10/08 面试题
自荐信格式范文
2013/10/07 职场文书
2015年度党风廉政建设工作情况汇报
2015/01/02 职场文书
单位实习鉴定评语
2015/01/04 职场文书
向雷锋同志学习倡议书
2015/04/27 职场文书
小学教师教学随笔
2015/08/14 职场文书
党务工作者主要事迹材料
2015/11/03 职场文书
SQLServer2008提示评估期已过解决方案
2021/04/12 SQL Server
python实现调用摄像头并拍照发邮箱
2021/04/27 Python
python通过函数名调用函数的几种方法总结
2021/06/07 Python
Python面向对象之内置函数相关知识总结
2021/06/24 Python