python机器学习之神经网络实现


Posted in Python onOctober 13, 2018

神经网络在机器学习中有很大的应用,甚至涉及到方方面面。本文主要是简单介绍一下神经网络的基本理论概念和推算。同时也会介绍一下神经网络在数据分类方面的应用。

首先,当我们建立一个回归和分类模型的时候,无论是用最小二乘法(OLS)还是最大似然值(MLE)都用来使得残差达到最小。因此我们在建立模型的时候,都会有一个loss function。

而在神经网络里也不例外,也有个类似的loss function。

对回归而言:

python机器学习之神经网络实现

对分类而言:

python机器学习之神经网络实现

然后同样方法,对于W开始求导,求导为零就可以求出极值来。

关于式子中的W。我们在这里以三层的神经网络为例。先介绍一下神经网络的相关参数。

python机器学习之神经网络实现

第一层是输入层,第二层是隐藏层,第三层是输出层。

在X1,X2经过W1的加权后,达到隐藏层,然后经过W2的加权,到达输出层

其中,

python机器学习之神经网络实现

我们有:

python机器学习之神经网络实现

至此,我们建立了一个初级的三层神经网络。

当我们要求其的loss function最小时,我们需要逆向来求,也就是所谓的backpropagation。

我们要分别对W1和W2进行求导,然后求出其极值。

从右手边开始逆推,首先对W2进行求导。

代入损失函数公式:

python机器学习之神经网络实现

python机器学习之神经网络实现

然后,我们进行化简:

python机器学习之神经网络实现

化简到这里,我们同理再对W1进行求导。

python机器学习之神经网络实现

我们可以发现当我们在做bp网络时候,有一个逆推回去的误差项,其决定了loss function 的最终大小。

在实际的运算当中,我们会用到梯度求解,来求出极值点。

python机器学习之神经网络实现

总结一下来说,我们使用向前推进来理顺神经网络做到回归分类等模型。而向后推进来计算他的损失函数,使得参数W有一个最优解。

当然,和线性回归等模型相类似的是,我们也可以加上正则化的项来对W参数进行约束,以免使得模型的偏差太小,而导致在测试集的表现不佳。

python机器学习之神经网络实现

python机器学习之神经网络实现

Python 的实现:

使用了KERAS的库

解决线性回归: 

model.add(Dense(1, input_dim=n_features, activation='linear', use_bias=True))

# Use mean squared error for the loss metric and use the ADAM backprop algorithm
model.compile(loss='mean_squared_error', optimizer='adam')

# Train the network (learn the weights)
# We need to convert from DataFrame to NumpyArray
history = model.fit(X_train.values, y_train.values, epochs=100, 
     batch_size=1, verbose=2, validation_split=0)

解决多重分类问题: 

# create model
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=n_features))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
# Softmax output layer
model.add(Dense(7, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(X_train.values, y_train.values, epochs=20, batch_size=16)

y_pred = model.predict(X_test.values)

y_te = np.argmax(y_test.values, axis = 1)
y_pr = np.argmax(y_pred, axis = 1)

print(np.unique(y_pr))

print(classification_report(y_te, y_pr))

print(confusion_matrix(y_te, y_pr))

当我们选取最优参数时候,有很多种解决的途径。这里就介绍一种是gridsearchcv的方法,这是一种暴力检索的方法,遍历所有的设定参数来求得最优参数。

from sklearn.model_selection import GridSearchCV

def create_model(optimizer='rmsprop'):
 model = Sequential()
 model.add(Dense(64, activation='relu', input_dim=n_features))
 model.add(Dropout(0.5))
 model.add(Dense(64, activation='relu'))
 model.add(Dropout(0.5))
 model.add(Dense(7, activation='softmax'))
 model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
 
 return model

model = KerasClassifier(build_fn=create_model, verbose=0)

optimizers = ['rmsprop']
epochs = [5, 10, 15]
batches = [128]


param_grid = dict(optimizer=optimizers, epochs=epochs, batch_size=batches, verbose=['2'])
grid = GridSearchCV(estimator=model, param_grid=param_grid)

grid.fit(X_train.values, y_train.values)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 基础学习第二弹 类属性和实例属性
Aug 27 Python
pycharm 使用心得(三)Hello world!
Jun 05 Python
python实现井字棋游戏
Mar 30 Python
Python3.2模拟实现webqq登录
Feb 15 Python
详解Python import方法引入模块的实例
Aug 02 Python
Python根据已知邻接矩阵绘制无向图操作示例
Jun 23 Python
详解Python3中的迭代器和生成器及其区别
Oct 09 Python
python射线法判断检测点是否位于区域外接矩形内
Jun 28 Python
pandas数据选取:df[] df.loc[] df.iloc[] df.ix[] df.at[] df.iat[]
Apr 24 Python
让Django的BooleanField支持字符串形式的输入方式
May 20 Python
Python打包exe时各种异常处理方案总结
May 18 Python
Python四款GUI图形界面库介绍
Jun 05 Python
Python pyinotify模块实现对文档的实时监控功能方法
Oct 13 #Python
基于pycharm导入模块显示不存在的解决方法
Oct 13 #Python
解决PyCharm import torch包失败的问题
Oct 13 #Python
python3+requests接口自动化session操作方法
Oct 13 #Python
解决pycharm无法识别本地site-packages的问题
Oct 13 #Python
解决PyCharm同目录下导入模块会报错的问题
Oct 13 #Python
python中单例常用的几种实现方法总结
Oct 13 #Python
You might like
DISCUZ 分页代码
2007/01/02 PHP
ThinkPHP控制器详解
2015/07/27 PHP
利用PHP实现一个简单的用户登记表示例
2017/04/25 PHP
javascript setAttribute, getAttribute 在不同浏览器上的不同表现
2010/08/05 Javascript
谈谈JavaScript中的函数与闭包
2013/04/14 Javascript
javascript中对Attr(dom中属性)的操作示例讲解
2013/12/02 Javascript
js抽奖实现随机抽奖代码效果
2013/12/02 Javascript
使用jQuery简单实现模拟浏览器搜索功能
2014/12/21 Javascript
基于jQuery+PHP+Mysql实现在线拍照和在线浏览照片
2015/09/06 Javascript
javascript常见数字进制转换实例分析
2016/04/21 Javascript
浅谈Javascript数组(推荐)
2016/05/17 Javascript
利用Angularjs实现幻灯片效果
2016/09/07 Javascript
ionic2懒加载配置详解
2017/09/01 Javascript
JS获取当前地理位置的方法
2017/10/25 Javascript
axios对请求各种异常情况处理的封装方法
2018/09/25 Javascript
监听element-ui table滚动事件的方法
2019/03/26 Javascript
layer弹出层取消遮罩的方法
2019/09/25 Javascript
原生js实现贪食蛇小游戏的思路详解
2019/11/26 Javascript
Vue-cli 移动端布局和动画使用详解
2020/08/10 Javascript
[14:50]2018DOTA2亚洲邀请赛开幕式
2018/04/03 DOTA
在DigitalOcean的服务器上部署flaskblog应用
2015/12/19 Python
python实现实时监控文件的方法
2016/08/26 Python
python实现决策树、随机森林的简单原理
2018/03/26 Python
更换Django默认的模板引擎为jinja2的实现方法
2018/05/28 Python
一看就懂得Python的math模块
2018/10/21 Python
python3 cvs将数据读取为字典的方法
2018/12/22 Python
Python机器学习工具scikit-learn的使用笔记
2021/01/28 Python
Python 中的函数装饰器和闭包详解
2021/02/06 Python
h5调用摄像头的实现方法
2016/06/01 HTML / CSS
英国婚礼商城:Wedding Mall
2019/11/02 全球购物
学生打架检讨书大全
2014/01/23 职场文书
2014年公司迎新年活动方案
2014/02/24 职场文书
博士给导师的自荐信
2015/03/06 职场文书
2015年统计员个人工作总结
2015/07/23 职场文书
SQL SERVER触发器详解
2022/02/24 SQL Server
python中数组和列表的简单实例
2022/03/25 Python