python机器学习之神经网络实现


Posted in Python onOctober 13, 2018

神经网络在机器学习中有很大的应用,甚至涉及到方方面面。本文主要是简单介绍一下神经网络的基本理论概念和推算。同时也会介绍一下神经网络在数据分类方面的应用。

首先,当我们建立一个回归和分类模型的时候,无论是用最小二乘法(OLS)还是最大似然值(MLE)都用来使得残差达到最小。因此我们在建立模型的时候,都会有一个loss function。

而在神经网络里也不例外,也有个类似的loss function。

对回归而言:

python机器学习之神经网络实现

对分类而言:

python机器学习之神经网络实现

然后同样方法,对于W开始求导,求导为零就可以求出极值来。

关于式子中的W。我们在这里以三层的神经网络为例。先介绍一下神经网络的相关参数。

python机器学习之神经网络实现

第一层是输入层,第二层是隐藏层,第三层是输出层。

在X1,X2经过W1的加权后,达到隐藏层,然后经过W2的加权,到达输出层

其中,

python机器学习之神经网络实现

我们有:

python机器学习之神经网络实现

至此,我们建立了一个初级的三层神经网络。

当我们要求其的loss function最小时,我们需要逆向来求,也就是所谓的backpropagation。

我们要分别对W1和W2进行求导,然后求出其极值。

从右手边开始逆推,首先对W2进行求导。

代入损失函数公式:

python机器学习之神经网络实现

python机器学习之神经网络实现

然后,我们进行化简:

python机器学习之神经网络实现

化简到这里,我们同理再对W1进行求导。

python机器学习之神经网络实现

我们可以发现当我们在做bp网络时候,有一个逆推回去的误差项,其决定了loss function 的最终大小。

在实际的运算当中,我们会用到梯度求解,来求出极值点。

python机器学习之神经网络实现

总结一下来说,我们使用向前推进来理顺神经网络做到回归分类等模型。而向后推进来计算他的损失函数,使得参数W有一个最优解。

当然,和线性回归等模型相类似的是,我们也可以加上正则化的项来对W参数进行约束,以免使得模型的偏差太小,而导致在测试集的表现不佳。

python机器学习之神经网络实现

python机器学习之神经网络实现

Python 的实现:

使用了KERAS的库

解决线性回归: 

model.add(Dense(1, input_dim=n_features, activation='linear', use_bias=True))

# Use mean squared error for the loss metric and use the ADAM backprop algorithm
model.compile(loss='mean_squared_error', optimizer='adam')

# Train the network (learn the weights)
# We need to convert from DataFrame to NumpyArray
history = model.fit(X_train.values, y_train.values, epochs=100, 
     batch_size=1, verbose=2, validation_split=0)

解决多重分类问题: 

# create model
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=n_features))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
# Softmax output layer
model.add(Dense(7, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(X_train.values, y_train.values, epochs=20, batch_size=16)

y_pred = model.predict(X_test.values)

y_te = np.argmax(y_test.values, axis = 1)
y_pr = np.argmax(y_pred, axis = 1)

print(np.unique(y_pr))

print(classification_report(y_te, y_pr))

print(confusion_matrix(y_te, y_pr))

当我们选取最优参数时候,有很多种解决的途径。这里就介绍一种是gridsearchcv的方法,这是一种暴力检索的方法,遍历所有的设定参数来求得最优参数。

from sklearn.model_selection import GridSearchCV

def create_model(optimizer='rmsprop'):
 model = Sequential()
 model.add(Dense(64, activation='relu', input_dim=n_features))
 model.add(Dropout(0.5))
 model.add(Dense(64, activation='relu'))
 model.add(Dropout(0.5))
 model.add(Dense(7, activation='softmax'))
 model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
 
 return model

model = KerasClassifier(build_fn=create_model, verbose=0)

optimizers = ['rmsprop']
epochs = [5, 10, 15]
batches = [128]


param_grid = dict(optimizer=optimizers, epochs=epochs, batch_size=batches, verbose=['2'])
grid = GridSearchCV(estimator=model, param_grid=param_grid)

grid.fit(X_train.values, y_train.values)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
初学python数组的处理代码
Jan 04 Python
jupyter安装小结
Mar 13 Python
python3.7.0的安装步骤
Aug 27 Python
Scrapy使用的基本流程与实例讲解
Oct 21 Python
python程序封装为win32服务的方法
Mar 07 Python
为什么Python中没有"a++"这种写法
Nov 27 Python
5分钟 Pipenv 上手指南
Dec 20 Python
torch 中各种图像格式转换的实现方法
Dec 26 Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 Python
通过cmd进入python的步骤
Jun 16 Python
Python中for后接else的语法使用
May 18 Python
python中pandas对多列进行分组统计的实现
Jun 18 Python
Python pyinotify模块实现对文档的实时监控功能方法
Oct 13 #Python
基于pycharm导入模块显示不存在的解决方法
Oct 13 #Python
解决PyCharm import torch包失败的问题
Oct 13 #Python
python3+requests接口自动化session操作方法
Oct 13 #Python
解决pycharm无法识别本地site-packages的问题
Oct 13 #Python
解决PyCharm同目录下导入模块会报错的问题
Oct 13 #Python
python中单例常用的几种实现方法总结
Oct 13 #Python
You might like
PHP网页游戏学习之Xnova(ogame)源码解读(十六)
2014/06/30 PHP
php json_encode()函数返回json数据实例代码
2014/10/10 PHP
PHP在同一域名下两个不同的项目做独立登录机制详解
2017/09/22 PHP
js通过googleAIP翻译PHP系统的语言配置的实现代码
2011/10/17 Javascript
window.requestAnimationFrame是什么意思,怎么用
2013/01/13 Javascript
node.js中的fs.fstatSync方法使用说明
2014/12/15 Javascript
深入理解JavaScript系列(30):设计模式之外观模式详解
2015/03/03 Javascript
在javascript中随机数 math random如何生成指定范围数值的随机数
2015/10/21 Javascript
vue实现列表的添加点击
2016/12/29 Javascript
Require.js的基本用法详解
2017/07/03 Javascript
微信小程序基于本地缓存实现点赞功能的方法
2017/12/18 Javascript
微信小程序实现多个按钮的颜色状态转换
2019/02/15 Javascript
element-ui表格合并span-method的实现方法
2019/05/21 Javascript
vue实现简单瀑布流布局
2020/05/28 Javascript
详解Vue3中对VDOM的改进
2020/04/23 Javascript
[42:24]完美世界DOTA2联赛PWL S2 LBZS vs FTD.C 第三场 11.27
2020/12/01 DOTA
Python json 错误xx is not JSON serializable解决办法
2017/03/15 Python
Django学习教程之静态文件的调用详解
2018/05/08 Python
多个应用共存的Django配置方法
2018/05/30 Python
python机器学习之KNN分类算法
2018/08/29 Python
深入浅析Python获取对象信息的函数type()、isinstance()、dir()
2018/09/17 Python
Python使用grequests(gevent+requests)并发发送请求过程解析
2019/09/25 Python
pycharm 设置项目的根目录教程
2020/02/12 Python
Pytorch数据拼接与拆分操作实现图解
2020/04/30 Python
基于python实现地址和经纬度转换
2020/05/19 Python
python实现scrapy爬虫每天定时抓取数据的示例代码
2021/01/27 Python
美国在线家装零售商:Build.com
2016/09/02 全球购物
澳大利亚首屈一指的鞋类品牌:Tony Bianco
2018/03/13 全球购物
介绍一下内联、左联、右联
2013/12/31 面试题
作为网站管理者应当如何防范XSS
2014/08/16 面试题
社区庆中秋节活动方案
2014/02/07 职场文书
服装区域经理岗位职责
2015/04/10 职场文书
教师创先争优承诺书
2015/04/27 职场文书
教你用eclipse连接mysql数据库
2021/04/22 MySQL
mysql事务对效率的影响分析总结
2021/10/24 MySQL
MySQL读取JSON转换的方式
2022/03/18 MySQL