python机器学习之神经网络实现


Posted in Python onOctober 13, 2018

神经网络在机器学习中有很大的应用,甚至涉及到方方面面。本文主要是简单介绍一下神经网络的基本理论概念和推算。同时也会介绍一下神经网络在数据分类方面的应用。

首先,当我们建立一个回归和分类模型的时候,无论是用最小二乘法(OLS)还是最大似然值(MLE)都用来使得残差达到最小。因此我们在建立模型的时候,都会有一个loss function。

而在神经网络里也不例外,也有个类似的loss function。

对回归而言:

python机器学习之神经网络实现

对分类而言:

python机器学习之神经网络实现

然后同样方法,对于W开始求导,求导为零就可以求出极值来。

关于式子中的W。我们在这里以三层的神经网络为例。先介绍一下神经网络的相关参数。

python机器学习之神经网络实现

第一层是输入层,第二层是隐藏层,第三层是输出层。

在X1,X2经过W1的加权后,达到隐藏层,然后经过W2的加权,到达输出层

其中,

python机器学习之神经网络实现

我们有:

python机器学习之神经网络实现

至此,我们建立了一个初级的三层神经网络。

当我们要求其的loss function最小时,我们需要逆向来求,也就是所谓的backpropagation。

我们要分别对W1和W2进行求导,然后求出其极值。

从右手边开始逆推,首先对W2进行求导。

代入损失函数公式:

python机器学习之神经网络实现

python机器学习之神经网络实现

然后,我们进行化简:

python机器学习之神经网络实现

化简到这里,我们同理再对W1进行求导。

python机器学习之神经网络实现

我们可以发现当我们在做bp网络时候,有一个逆推回去的误差项,其决定了loss function 的最终大小。

在实际的运算当中,我们会用到梯度求解,来求出极值点。

python机器学习之神经网络实现

总结一下来说,我们使用向前推进来理顺神经网络做到回归分类等模型。而向后推进来计算他的损失函数,使得参数W有一个最优解。

当然,和线性回归等模型相类似的是,我们也可以加上正则化的项来对W参数进行约束,以免使得模型的偏差太小,而导致在测试集的表现不佳。

python机器学习之神经网络实现

python机器学习之神经网络实现

Python 的实现:

使用了KERAS的库

解决线性回归: 

model.add(Dense(1, input_dim=n_features, activation='linear', use_bias=True))

# Use mean squared error for the loss metric and use the ADAM backprop algorithm
model.compile(loss='mean_squared_error', optimizer='adam')

# Train the network (learn the weights)
# We need to convert from DataFrame to NumpyArray
history = model.fit(X_train.values, y_train.values, epochs=100, 
     batch_size=1, verbose=2, validation_split=0)

解决多重分类问题: 

# create model
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=n_features))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
# Softmax output layer
model.add(Dense(7, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(X_train.values, y_train.values, epochs=20, batch_size=16)

y_pred = model.predict(X_test.values)

y_te = np.argmax(y_test.values, axis = 1)
y_pr = np.argmax(y_pred, axis = 1)

print(np.unique(y_pr))

print(classification_report(y_te, y_pr))

print(confusion_matrix(y_te, y_pr))

当我们选取最优参数时候,有很多种解决的途径。这里就介绍一种是gridsearchcv的方法,这是一种暴力检索的方法,遍历所有的设定参数来求得最优参数。

from sklearn.model_selection import GridSearchCV

def create_model(optimizer='rmsprop'):
 model = Sequential()
 model.add(Dense(64, activation='relu', input_dim=n_features))
 model.add(Dropout(0.5))
 model.add(Dense(64, activation='relu'))
 model.add(Dropout(0.5))
 model.add(Dense(7, activation='softmax'))
 model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
 
 return model

model = KerasClassifier(build_fn=create_model, verbose=0)

optimizers = ['rmsprop']
epochs = [5, 10, 15]
batches = [128]


param_grid = dict(optimizer=optimizers, epochs=epochs, batch_size=batches, verbose=['2'])
grid = GridSearchCV(estimator=model, param_grid=param_grid)

grid.fit(X_train.values, y_train.values)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python局域网ip扫描示例分享
Apr 03 Python
Python中函数的参数定义和可变参数用法实例分析
Jun 04 Python
玩转python爬虫之cookie使用方法
Feb 17 Python
python实现unicode转中文及转换默认编码的方法
Apr 29 Python
浅谈python之新式类
Aug 12 Python
python读取和保存图片5种方法对比
Sep 12 Python
详解python深浅拷贝区别
Jun 24 Python
Python包,__init__.py功能与用法分析
Jan 07 Python
tensorflow dataset.shuffle、dataset.batch、dataset.repeat顺序区别详解
Jun 03 Python
Selenium关闭INFO:CONSOLE提示的解决
Dec 07 Python
教你怎么用python爬取爱奇艺热门电影
May 20 Python
详解Python自动化之文件自动化处理
Jun 21 Python
Python pyinotify模块实现对文档的实时监控功能方法
Oct 13 #Python
基于pycharm导入模块显示不存在的解决方法
Oct 13 #Python
解决PyCharm import torch包失败的问题
Oct 13 #Python
python3+requests接口自动化session操作方法
Oct 13 #Python
解决pycharm无法识别本地site-packages的问题
Oct 13 #Python
解决PyCharm同目录下导入模块会报错的问题
Oct 13 #Python
python中单例常用的几种实现方法总结
Oct 13 #Python
You might like
用PHP进行MySQL删除记录操作代码
2008/06/07 PHP
php中常用的预定义变量小结
2012/05/09 PHP
php Ubb代码编辑器函数代码
2012/07/05 PHP
解析link_mysql的php版
2013/06/30 PHP
php比较相似字符串的方法
2015/06/05 PHP
详解PHP swoole process的使用方法
2017/08/26 PHP
jQuery实现动画效果的实例代码
2013/05/07 Javascript
JavaScript常用全局属性与方法记录积累
2013/07/03 Javascript
解析Jquery中如何把一段html代码动态写入到DIV中(实例说明)
2013/07/09 Javascript
移动节点的jquery代码
2014/01/13 Javascript
nodejs开发微博实例
2015/03/25 NodeJs
jQuery实现首页图片淡入淡出效果的方法
2015/06/10 Javascript
使用nodejs下载风景壁纸
2017/02/05 NodeJs
select自定义小三角样式代码(实用总结)
2017/08/18 Javascript
Node.js使用Koa搭建 基础项目
2018/01/08 Javascript
vue+element-ui动态生成多级表头的方法
2018/08/28 Javascript
使用vue-cli3 创建vue项目并配置VS Code 自动代码格式化 vue语法高亮问题
2019/05/14 Javascript
详解javascript中var与ES6规范中let、const区别与用法
2020/01/11 Javascript
JavaScript单线程和任务队列原理解析
2020/02/04 Javascript
[01:11:37]完美世界DOTA2联赛PWL S2 SZ vs FTD.C 第一场 11.19
2020/11/19 DOTA
详解Python中的__init__和__new__
2014/03/12 Python
在Mac OS上使用mod_wsgi连接Python与Apache服务器
2015/12/24 Python
Python接口自动化测试的实现
2020/08/28 Python
Python数据可视化常用4大绘图库原理详解
2020/10/23 Python
python中二分查找法的实现方法
2020/12/06 Python
CSS3 中的@keyframes介绍
2014/09/02 HTML / CSS
CSS3的RGBA中关于整数和百分比值的转换
2015/08/04 HTML / CSS
CSS3中Animation属性的使用详解
2015/08/06 HTML / CSS
H5页面适配iPhoneX(就是那么简单)
2019/12/02 HTML / CSS
Levi’s美国官网:美国著名的牛仔裤品牌
2016/08/19 全球购物
雅萌 (YA-MAN) :日本美容家电领域的龙头企业
2017/05/12 全球购物
银行奉献演讲稿
2014/09/16 职场文书
重阳节标语大全
2014/10/07 职场文书
学术会议通知范文
2015/04/15 职场文书
公司回复函格式
2015/07/14 职场文书
Python爬虫进阶之Beautiful Soup库详解
2021/04/29 Python