python机器学习之神经网络实现


Posted in Python onOctober 13, 2018

神经网络在机器学习中有很大的应用,甚至涉及到方方面面。本文主要是简单介绍一下神经网络的基本理论概念和推算。同时也会介绍一下神经网络在数据分类方面的应用。

首先,当我们建立一个回归和分类模型的时候,无论是用最小二乘法(OLS)还是最大似然值(MLE)都用来使得残差达到最小。因此我们在建立模型的时候,都会有一个loss function。

而在神经网络里也不例外,也有个类似的loss function。

对回归而言:

python机器学习之神经网络实现

对分类而言:

python机器学习之神经网络实现

然后同样方法,对于W开始求导,求导为零就可以求出极值来。

关于式子中的W。我们在这里以三层的神经网络为例。先介绍一下神经网络的相关参数。

python机器学习之神经网络实现

第一层是输入层,第二层是隐藏层,第三层是输出层。

在X1,X2经过W1的加权后,达到隐藏层,然后经过W2的加权,到达输出层

其中,

python机器学习之神经网络实现

我们有:

python机器学习之神经网络实现

至此,我们建立了一个初级的三层神经网络。

当我们要求其的loss function最小时,我们需要逆向来求,也就是所谓的backpropagation。

我们要分别对W1和W2进行求导,然后求出其极值。

从右手边开始逆推,首先对W2进行求导。

代入损失函数公式:

python机器学习之神经网络实现

python机器学习之神经网络实现

然后,我们进行化简:

python机器学习之神经网络实现

化简到这里,我们同理再对W1进行求导。

python机器学习之神经网络实现

我们可以发现当我们在做bp网络时候,有一个逆推回去的误差项,其决定了loss function 的最终大小。

在实际的运算当中,我们会用到梯度求解,来求出极值点。

python机器学习之神经网络实现

总结一下来说,我们使用向前推进来理顺神经网络做到回归分类等模型。而向后推进来计算他的损失函数,使得参数W有一个最优解。

当然,和线性回归等模型相类似的是,我们也可以加上正则化的项来对W参数进行约束,以免使得模型的偏差太小,而导致在测试集的表现不佳。

python机器学习之神经网络实现

python机器学习之神经网络实现

Python 的实现:

使用了KERAS的库

解决线性回归: 

model.add(Dense(1, input_dim=n_features, activation='linear', use_bias=True))

# Use mean squared error for the loss metric and use the ADAM backprop algorithm
model.compile(loss='mean_squared_error', optimizer='adam')

# Train the network (learn the weights)
# We need to convert from DataFrame to NumpyArray
history = model.fit(X_train.values, y_train.values, epochs=100, 
     batch_size=1, verbose=2, validation_split=0)

解决多重分类问题: 

# create model
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=n_features))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
# Softmax output layer
model.add(Dense(7, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(X_train.values, y_train.values, epochs=20, batch_size=16)

y_pred = model.predict(X_test.values)

y_te = np.argmax(y_test.values, axis = 1)
y_pr = np.argmax(y_pred, axis = 1)

print(np.unique(y_pr))

print(classification_report(y_te, y_pr))

print(confusion_matrix(y_te, y_pr))

当我们选取最优参数时候,有很多种解决的途径。这里就介绍一种是gridsearchcv的方法,这是一种暴力检索的方法,遍历所有的设定参数来求得最优参数。

from sklearn.model_selection import GridSearchCV

def create_model(optimizer='rmsprop'):
 model = Sequential()
 model.add(Dense(64, activation='relu', input_dim=n_features))
 model.add(Dropout(0.5))
 model.add(Dense(64, activation='relu'))
 model.add(Dropout(0.5))
 model.add(Dense(7, activation='softmax'))
 model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
 
 return model

model = KerasClassifier(build_fn=create_model, verbose=0)

optimizers = ['rmsprop']
epochs = [5, 10, 15]
batches = [128]


param_grid = dict(optimizer=optimizers, epochs=epochs, batch_size=batches, verbose=['2'])
grid = GridSearchCV(estimator=model, param_grid=param_grid)

grid.fit(X_train.values, y_train.values)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的map、reduce和filter浅析
Apr 26 Python
Python实现的选择排序算法示例
Nov 29 Python
Django JWT Token RestfulAPI用户认证详解
Jan 23 Python
python 求一个列表中所有元素的乘积实例
Jun 11 Python
使用python实现滑动验证码功能
Aug 05 Python
python2使用bs4爬取腾讯社招过程解析
Aug 14 Python
使用selenium和pyquery爬取京东商品列表过程解析
Aug 15 Python
Python closure闭包解释及其注意点详解
Aug 28 Python
Python3直接爬取图片URL并保存示例
Dec 18 Python
解决Python logging模块无法正常输出日志的问题
Feb 21 Python
Django中的模型类设计及展示示例详解
May 29 Python
Python实现生活常识解答机器人
Jun 28 Python
Python pyinotify模块实现对文档的实时监控功能方法
Oct 13 #Python
基于pycharm导入模块显示不存在的解决方法
Oct 13 #Python
解决PyCharm import torch包失败的问题
Oct 13 #Python
python3+requests接口自动化session操作方法
Oct 13 #Python
解决pycharm无法识别本地site-packages的问题
Oct 13 #Python
解决PyCharm同目录下导入模块会报错的问题
Oct 13 #Python
python中单例常用的几种实现方法总结
Oct 13 #Python
You might like
PHP 检查扩展库或函数是否可用的代码
2010/04/06 PHP
基于Zend的Captcha机制的应用
2013/05/02 PHP
10个php函数实用却不常见
2015/10/13 PHP
适用于初学者的简易PHP文件上传类
2015/10/29 PHP
PHP文件缓存smarty模板应用实例分析
2016/02/26 PHP
tp5实现微信小程序多图片上传到服务器功能
2018/07/16 PHP
XML的代替者----JSON
2007/07/21 Javascript
functional继承模式 摘自javascript:the good parts
2011/06/20 Javascript
防止按钮在短时间内被多次点击的方法
2014/03/10 Javascript
让JavaScript和其它资源并发下载的方法
2014/10/16 Javascript
JavaScript中使用数组方法汇总
2016/02/16 Javascript
js停止冒泡和阻止浏览器默认行为的简单方法
2016/05/15 Javascript
最全的Javascript编码规范(推荐)
2016/06/22 Javascript
原生js实现旋转木马轮播图效果
2017/02/27 Javascript
详解如何使用Node.js编写命令工具——以vue-cli为例
2017/06/29 Javascript
JS设置随机出现2个数字的实例代码
2017/07/19 Javascript
ReactNative Image组件使用详解
2017/08/07 Javascript
vue中子组件调用兄弟组件方法
2018/07/06 Javascript
layer插件select选中默认值的方法
2018/08/14 Javascript
详解微信小程序之scroll-view的flex布局问题
2019/01/16 Javascript
JS通过ajax + 多列布局 + 自动加载实现瀑布流效果
2019/05/30 Javascript
Element-ui中元素滚动时el-option超出元素区域的问题
2019/05/30 Javascript
JS如何实现手机端输入验证码效果
2020/05/13 Javascript
[04:03]辉夜杯主赛事 12月25日RECAP精彩回顾
2015/12/26 DOTA
Python实现识别手写数字大纲
2018/01/29 Python
Python json模块dumps、loads操作示例
2018/09/06 Python
python+flask实现API的方法
2018/11/21 Python
python自定义函数def的应用详解
2020/06/03 Python
机电专业大学生职业规划书范文
2014/02/25 职场文书
海飞丝广告词
2014/03/20 职场文书
合伙经营协议书
2014/04/18 职场文书
求职信范文大全
2014/05/26 职场文书
国际经济贸易专业自荐信
2014/06/13 职场文书
婚礼伴郎致辞
2015/07/28 职场文书
Css预编语言及区别详解
2021/04/25 HTML / CSS
mysql连接查询中and与where的区别浅析
2021/07/01 MySQL