python机器学习之神经网络实现


Posted in Python onOctober 13, 2018

神经网络在机器学习中有很大的应用,甚至涉及到方方面面。本文主要是简单介绍一下神经网络的基本理论概念和推算。同时也会介绍一下神经网络在数据分类方面的应用。

首先,当我们建立一个回归和分类模型的时候,无论是用最小二乘法(OLS)还是最大似然值(MLE)都用来使得残差达到最小。因此我们在建立模型的时候,都会有一个loss function。

而在神经网络里也不例外,也有个类似的loss function。

对回归而言:

python机器学习之神经网络实现

对分类而言:

python机器学习之神经网络实现

然后同样方法,对于W开始求导,求导为零就可以求出极值来。

关于式子中的W。我们在这里以三层的神经网络为例。先介绍一下神经网络的相关参数。

python机器学习之神经网络实现

第一层是输入层,第二层是隐藏层,第三层是输出层。

在X1,X2经过W1的加权后,达到隐藏层,然后经过W2的加权,到达输出层

其中,

python机器学习之神经网络实现

我们有:

python机器学习之神经网络实现

至此,我们建立了一个初级的三层神经网络。

当我们要求其的loss function最小时,我们需要逆向来求,也就是所谓的backpropagation。

我们要分别对W1和W2进行求导,然后求出其极值。

从右手边开始逆推,首先对W2进行求导。

代入损失函数公式:

python机器学习之神经网络实现

python机器学习之神经网络实现

然后,我们进行化简:

python机器学习之神经网络实现

化简到这里,我们同理再对W1进行求导。

python机器学习之神经网络实现

我们可以发现当我们在做bp网络时候,有一个逆推回去的误差项,其决定了loss function 的最终大小。

在实际的运算当中,我们会用到梯度求解,来求出极值点。

python机器学习之神经网络实现

总结一下来说,我们使用向前推进来理顺神经网络做到回归分类等模型。而向后推进来计算他的损失函数,使得参数W有一个最优解。

当然,和线性回归等模型相类似的是,我们也可以加上正则化的项来对W参数进行约束,以免使得模型的偏差太小,而导致在测试集的表现不佳。

python机器学习之神经网络实现

python机器学习之神经网络实现

Python 的实现:

使用了KERAS的库

解决线性回归: 

model.add(Dense(1, input_dim=n_features, activation='linear', use_bias=True))

# Use mean squared error for the loss metric and use the ADAM backprop algorithm
model.compile(loss='mean_squared_error', optimizer='adam')

# Train the network (learn the weights)
# We need to convert from DataFrame to NumpyArray
history = model.fit(X_train.values, y_train.values, epochs=100, 
     batch_size=1, verbose=2, validation_split=0)

解决多重分类问题: 

# create model
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=n_features))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
# Softmax output layer
model.add(Dense(7, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(X_train.values, y_train.values, epochs=20, batch_size=16)

y_pred = model.predict(X_test.values)

y_te = np.argmax(y_test.values, axis = 1)
y_pr = np.argmax(y_pred, axis = 1)

print(np.unique(y_pr))

print(classification_report(y_te, y_pr))

print(confusion_matrix(y_te, y_pr))

当我们选取最优参数时候,有很多种解决的途径。这里就介绍一种是gridsearchcv的方法,这是一种暴力检索的方法,遍历所有的设定参数来求得最优参数。

from sklearn.model_selection import GridSearchCV

def create_model(optimizer='rmsprop'):
 model = Sequential()
 model.add(Dense(64, activation='relu', input_dim=n_features))
 model.add(Dropout(0.5))
 model.add(Dense(64, activation='relu'))
 model.add(Dropout(0.5))
 model.add(Dense(7, activation='softmax'))
 model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
 
 return model

model = KerasClassifier(build_fn=create_model, verbose=0)

optimizers = ['rmsprop']
epochs = [5, 10, 15]
batches = [128]


param_grid = dict(optimizer=optimizers, epochs=epochs, batch_size=batches, verbose=['2'])
grid = GridSearchCV(estimator=model, param_grid=param_grid)

grid.fit(X_train.values, y_train.values)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python采用requests库模拟登录和抓取数据的简单示例
Jul 05 Python
Python中set与frozenset方法和区别详解
May 23 Python
浅谈Python的文件类型
May 30 Python
Python实现简单求解给定整数的质因数算法示例
Mar 25 Python
浅谈django的render函数的参数问题
Oct 16 Python
python引入不同文件夹下的自定义模块方法
Oct 27 Python
Python合并同一个文件夹下所有PDF文件的方法
Mar 11 Python
浅析Python3中的对象垃圾收集机制
Jun 06 Python
django认证系统实现自定义权限管理的方法
Aug 28 Python
Python reques接口测试框架实现代码
Jul 28 Python
Anaconda的安装与虚拟环境建立
Nov 18 Python
python编写函数注意事项总结
Mar 29 Python
Python pyinotify模块实现对文档的实时监控功能方法
Oct 13 #Python
基于pycharm导入模块显示不存在的解决方法
Oct 13 #Python
解决PyCharm import torch包失败的问题
Oct 13 #Python
python3+requests接口自动化session操作方法
Oct 13 #Python
解决pycharm无法识别本地site-packages的问题
Oct 13 #Python
解决PyCharm同目录下导入模块会报错的问题
Oct 13 #Python
python中单例常用的几种实现方法总结
Oct 13 #Python
You might like
用Flash图形化数据(一)
2006/10/09 PHP
给初学者的30条PHP最佳实践(荒野无灯)
2011/08/02 PHP
php页面缓存ob系列函数介绍
2012/10/18 PHP
探讨PHP函数ip2long转换IP时数值太大产生负数的解决方法
2013/06/06 PHP
ThinkPHP3.1.3版本新特性概述
2014/06/19 PHP
php常用hash加密函数
2014/11/22 PHP
PHP curl CURLOPT_RETURNTRANSFER参数的作用使用实例
2015/02/07 PHP
thinkPHP5实现数据库添加内容的方法
2017/10/25 PHP
JavaScript高级程序设计
2006/12/29 Javascript
围观tangram js库
2010/12/28 Javascript
ASP.NET jQuery 实例16 通过控件CustomValidator验证RadioButtonList
2012/02/03 Javascript
不使用jquery实现js打字效果示例分享
2014/01/19 Javascript
JQuery仿小米手机抢购页面倒计时效果
2014/12/16 Javascript
Bootstrap3 Grid system原理及应用详解
2016/09/30 Javascript
JavaScript实现的选择排序算法实例分析
2017/04/14 Javascript
JS实现的文字间歇循环滚动效果完整示例
2018/02/13 Javascript
vue+axios+promise实际开发用法详解
2018/10/15 Javascript
[01:01]青春无憾,一战成名——DOTA2全国高校联赛开启
2018/02/25 DOTA
Django框架中render_to_response()函数的使用方法
2015/07/16 Python
python的re正则表达式实例代码
2018/01/24 Python
python检测主机的连通性并记录到文件的实例
2018/06/21 Python
python opencv实现运动检测
2018/07/10 Python
python数据持久存储 pickle模块的基本使用方法解析
2019/08/30 Python
利用OpenCV和Python实现查找图片差异
2019/12/19 Python
助理政工师申报材料
2014/06/03 职场文书
营销学习心得体会
2014/09/12 职场文书
代领学位证书毕业证书委托书
2014/09/30 职场文书
群众路线个人剖析材料
2014/10/07 职场文书
党组织领导班子整改方案
2014/10/25 职场文书
关于群众路线的心得体会
2014/11/05 职场文书
单位介绍信格式
2015/01/31 职场文书
工程服务质量承诺书
2015/04/29 职场文书
24句精辟的现实社会语录,句句扎心,道尽人性
2019/08/29 职场文书
CSS3 制作精美的定价表
2021/04/06 HTML / CSS
使用Redis实现秒杀功能的简单方法
2021/05/08 Redis
python保存图片的四个常用方法
2022/02/28 Python