python机器学习之神经网络实现


Posted in Python onOctober 13, 2018

神经网络在机器学习中有很大的应用,甚至涉及到方方面面。本文主要是简单介绍一下神经网络的基本理论概念和推算。同时也会介绍一下神经网络在数据分类方面的应用。

首先,当我们建立一个回归和分类模型的时候,无论是用最小二乘法(OLS)还是最大似然值(MLE)都用来使得残差达到最小。因此我们在建立模型的时候,都会有一个loss function。

而在神经网络里也不例外,也有个类似的loss function。

对回归而言:

python机器学习之神经网络实现

对分类而言:

python机器学习之神经网络实现

然后同样方法,对于W开始求导,求导为零就可以求出极值来。

关于式子中的W。我们在这里以三层的神经网络为例。先介绍一下神经网络的相关参数。

python机器学习之神经网络实现

第一层是输入层,第二层是隐藏层,第三层是输出层。

在X1,X2经过W1的加权后,达到隐藏层,然后经过W2的加权,到达输出层

其中,

python机器学习之神经网络实现

我们有:

python机器学习之神经网络实现

至此,我们建立了一个初级的三层神经网络。

当我们要求其的loss function最小时,我们需要逆向来求,也就是所谓的backpropagation。

我们要分别对W1和W2进行求导,然后求出其极值。

从右手边开始逆推,首先对W2进行求导。

代入损失函数公式:

python机器学习之神经网络实现

python机器学习之神经网络实现

然后,我们进行化简:

python机器学习之神经网络实现

化简到这里,我们同理再对W1进行求导。

python机器学习之神经网络实现

我们可以发现当我们在做bp网络时候,有一个逆推回去的误差项,其决定了loss function 的最终大小。

在实际的运算当中,我们会用到梯度求解,来求出极值点。

python机器学习之神经网络实现

总结一下来说,我们使用向前推进来理顺神经网络做到回归分类等模型。而向后推进来计算他的损失函数,使得参数W有一个最优解。

当然,和线性回归等模型相类似的是,我们也可以加上正则化的项来对W参数进行约束,以免使得模型的偏差太小,而导致在测试集的表现不佳。

python机器学习之神经网络实现

python机器学习之神经网络实现

Python 的实现:

使用了KERAS的库

解决线性回归: 

model.add(Dense(1, input_dim=n_features, activation='linear', use_bias=True))

# Use mean squared error for the loss metric and use the ADAM backprop algorithm
model.compile(loss='mean_squared_error', optimizer='adam')

# Train the network (learn the weights)
# We need to convert from DataFrame to NumpyArray
history = model.fit(X_train.values, y_train.values, epochs=100, 
     batch_size=1, verbose=2, validation_split=0)

解决多重分类问题: 

# create model
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=n_features))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
# Softmax output layer
model.add(Dense(7, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(X_train.values, y_train.values, epochs=20, batch_size=16)

y_pred = model.predict(X_test.values)

y_te = np.argmax(y_test.values, axis = 1)
y_pr = np.argmax(y_pred, axis = 1)

print(np.unique(y_pr))

print(classification_report(y_te, y_pr))

print(confusion_matrix(y_te, y_pr))

当我们选取最优参数时候,有很多种解决的途径。这里就介绍一种是gridsearchcv的方法,这是一种暴力检索的方法,遍历所有的设定参数来求得最优参数。

from sklearn.model_selection import GridSearchCV

def create_model(optimizer='rmsprop'):
 model = Sequential()
 model.add(Dense(64, activation='relu', input_dim=n_features))
 model.add(Dropout(0.5))
 model.add(Dense(64, activation='relu'))
 model.add(Dropout(0.5))
 model.add(Dense(7, activation='softmax'))
 model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
 
 return model

model = KerasClassifier(build_fn=create_model, verbose=0)

optimizers = ['rmsprop']
epochs = [5, 10, 15]
batches = [128]


param_grid = dict(optimizer=optimizers, epochs=epochs, batch_size=batches, verbose=['2'])
grid = GridSearchCV(estimator=model, param_grid=param_grid)

grid.fit(X_train.values, y_train.values)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python生成器的使用方法
Nov 21 Python
python基础教程之类class定义使用方法
Feb 20 Python
Python中利用sqrt()方法进行平方根计算的教程
May 15 Python
Python中字符串的格式化方法小结
May 03 Python
常见的python正则用法实例讲解
Jun 21 Python
利用python将xml文件解析成html文件的实现方法
Dec 22 Python
Python实现简单生成验证码功能【基于random模块】
Feb 10 Python
Python批量提取PDF文件中文本的脚本
Mar 14 Python
python移位运算的实现
Jul 15 Python
Python实现滑动平均(Moving Average)的例子
Aug 24 Python
Python树莓派学习笔记之UDP传输视频帧操作详解
Nov 15 Python
Python如何使用paramiko模块连接linux
Mar 18 Python
Python pyinotify模块实现对文档的实时监控功能方法
Oct 13 #Python
基于pycharm导入模块显示不存在的解决方法
Oct 13 #Python
解决PyCharm import torch包失败的问题
Oct 13 #Python
python3+requests接口自动化session操作方法
Oct 13 #Python
解决pycharm无法识别本地site-packages的问题
Oct 13 #Python
解决PyCharm同目录下导入模块会报错的问题
Oct 13 #Python
python中单例常用的几种实现方法总结
Oct 13 #Python
You might like
如何在PHP程序中防止盗链
2008/04/09 PHP
zend framework多模块多布局配置
2011/02/26 PHP
php中利用explode函数分割字符串到数组
2014/02/08 PHP
PHP基于SPL实现的迭代器模式示例
2018/04/22 PHP
jQuery 处理表单元素的代码
2010/02/15 Javascript
jquery实现输入框动态增减的实例代码
2013/07/14 Javascript
使用js的replace()方法查找字符示例代码
2013/10/28 Javascript
jquery设置text的值示例(设置文本框 DIV 表单值)
2014/01/06 Javascript
javascript实现微信分享
2014/12/23 Javascript
js的flv视频播放器插件使用方法
2015/06/23 Javascript
js实现的tab标签切换效果代码分享
2015/08/25 Javascript
浅析AngularJS中的指令
2016/03/20 Javascript
深入剖析JavaScript:Object类型
2016/05/10 Javascript
Bootstrap Table服务器分页与在线编辑应用总结
2016/08/08 Javascript
vuex操作state对象的实例代码
2018/04/25 Javascript
详解Vue中使用Echarts的两种方式
2018/07/03 Javascript
详解webpack loader和plugin编写
2018/10/12 Javascript
详解微信小程序开发之formId使用(模板消息)
2019/08/27 Javascript
vue实现员工信息录入功能
2020/06/11 Javascript
python中使用urllib2伪造HTTP报头的2个方法
2014/07/07 Python
详解Python3中的Sequence type的使用
2015/08/01 Python
python读取excel指定列数据并写入到新的excel方法
2018/07/10 Python
Python3 批量扫描端口的例子
2019/07/25 Python
python写程序统计词频的方法
2019/07/29 Python
Python高级特性 切片 迭代解析
2019/08/23 Python
使用Python来做一个屏幕录制工具的操作代码
2020/01/18 Python
HTML5拖拽文件到浏览器并实现文件上传下载功能代码
2013/06/06 HTML / CSS
Hush Puppies澳大利亚官网:舒适的男女休闲和正装鞋
2019/08/24 全球购物
新闻学专业应届生求职信
2013/11/08 职场文书
财务与信息服务专业推荐信
2013/11/28 职场文书
幼儿园中班上学期评语
2014/04/18 职场文书
2014年环境整治工作总结
2014/12/10 职场文书
结婚堵门保证书
2015/05/08 职场文书
教师外出学习心得体会
2016/01/18 职场文书
MySQL基础(一)
2021/04/05 MySQL
mysql查询的控制语句图文详解
2021/04/11 MySQL