python实现随机梯度下降(SGD)


Posted in Python onMarch 24, 2020

使用神经网络进行样本训练,要实现随机梯度下降算法。这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义):

def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
 if test_data:
  n_test = len(test_data)#有多少个测试集
  n = len(training_data)
  for j in xrange(epochs):
   random.shuffle(training_data)
   mini_batches = [
    training_data[k:k+mini_batch_size] 
    for k in xrange(0,n,mini_batch_size)]
   for mini_batch in mini_batches:
    self.update_mini_batch(mini_batch, eta)
   if test_data:
    print "Epoch {0}: {1}/{2}".format(j, self.evaluate(test_data),n_test)
   else:
    print "Epoch {0} complete".format(j)

其中training_data是训练集,是由很多的tuples(元组)组成。每一个元组(x,y)代表一个实例,x是图像的向量表示,y是图像的类别。
epochs表示训练多少轮。
mini_batch_size表示每一次训练的实例个数。
eta表示学习率。
test_data表示测试集。
比较重要的函数是self.update_mini_batch,他是更新权重和偏置的关键函数,接下来就定义这个函数。

def update_mini_batch(self, mini_batch,eta): 
 nabla_b = [np.zeros(b.shape) for b in self.biases]
 nabla_w = [np.zeros(w.shape) for w in self.weights]
 for x,y in mini_batch:
  delta_nabla_b, delta_nable_w = self.backprop(x,y)#目标函数对b和w的偏导数
  nabla_b = [nb+dnb for nb,dnb in zip(nabla_b,delta_nabla_b)]
  nabla_w = [nw+dnw for nw,dnw in zip(nabla_w,delta_nabla_w)]#累加b和w
 #最终更新权重为
 self.weights = [w-(eta/len(mini_batch))*nw for w, nw in zip(self.weights, nabla_w)]
 self.baises = [b-(eta/len(mini_batch))*nb for b, nb in zip(self.baises, nabla_b)]

这个update_mini_batch函数根据你传入的一些数据进行更新神经网络的权重和偏置。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python控制台显示时钟的示例
Feb 24 Python
零基础写python爬虫之抓取百度贴吧代码分享
Nov 06 Python
12步入门Python中的decorator装饰器使用方法
Jun 20 Python
python去掉 unicode 字符串前面的u方法
Oct 21 Python
PyCharm代码提示忽略大小写设置方法
Oct 28 Python
解决Python中list里的中文输出到html模板里的问题
Dec 17 Python
python+opencv实现阈值分割
Dec 26 Python
python pandas模块基础学习详解
Jul 03 Python
Linux下通过python获取本机ip方法示例
Sep 06 Python
python中的subprocess.Popen()使用详解
Dec 25 Python
python高级特性简介
Aug 13 Python
Python学习之time模块的基本使用
Jan 17 Python
Python实现将一个正整数分解质因数的方法分析
Dec 14 #Python
Python随机生成均匀分布在三角形内或者任意多边形内的点
Dec 14 #Python
rabbitmq(中间消息代理)在python中的使用详解
Dec 14 #Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
Dec 14 #Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 #Python
python编程通过蒙特卡洛法计算定积分详解
Dec 13 #Python
Python编程产生非均匀随机数的几种方法代码分享
Dec 13 #Python
You might like
PHP Callable强制指定回调类型的方法
2016/08/30 PHP
PHP对XML内容进行修改和删除实例代码
2016/10/26 PHP
jquery text()要注意啦
2009/10/30 Javascript
javascript结合html5 canvas实现(可调画笔颜色/粗细/橡皮)的涂鸦板
2013/04/27 Javascript
jquery设置控件位置的方法
2013/08/21 Javascript
js中cookie的添加、取值、删除示例代码
2013/10/21 Javascript
jQuery拖拽 & 弹出层 介绍与示例
2013/12/27 Javascript
JavaScript按值删除数组元素的方法
2015/04/24 Javascript
详解AngularJS Filter(过滤器)用法
2015/12/28 Javascript
微信小程序 教程之WXSS
2016/10/18 Javascript
H5实现中奖记录逐行滚动切换效果
2017/03/13 Javascript
百度地图JavascriptApi Marker平滑移动及车头指向行径方向
2017/03/13 Javascript
JS实现仿饿了么在浏览器标签页失去焦点时网页Title改变
2017/06/01 Javascript
js登录滑动验证的实现(不滑动无法登陆)
2018/01/03 Javascript
angular写一个列表的选择全选交互组件的示例
2018/01/22 Javascript
Vue实现动态创建和删除数据的方法
2018/03/17 Javascript
详解Require.js与Sea.js的区别
2018/08/05 Javascript
原生js代码能实现call和bind吗
2019/07/31 Javascript
vue 根据选择条件显示指定参数的例子
2019/11/09 Javascript
jQuery实现鼠标放置名字上显示详细内容气泡提示框效果的方法分析
2020/04/04 jQuery
react 不用插件实现数字滚动的效果示例
2020/04/14 Javascript
[48:05]2018DOTA2亚洲邀请赛 3.31 小组赛 B组 VGJ.T vs VP
2018/03/31 DOTA
Python自动扫雷实现方法
2015/07/25 Python
在python2.7中用numpy.reshape 对图像进行切割的方法
2018/12/05 Python
Python操作远程服务器 paramiko模块详细介绍
2019/08/07 Python
利用matplotlib实现根据实时数据动态更新图形
2019/12/13 Python
如何快速理解python的垃圾回收机制
2020/09/01 Python
Python读取图像并显示灰度图的实现
2020/12/01 Python
电子商务专业毕业生工作推荐信
2013/11/17 职场文书
集体婚礼证婚词
2014/01/13 职场文书
幼儿园英语教学反思
2014/01/30 职场文书
人大调研汇报材料
2014/08/14 职场文书
水电维修专业推荐信
2014/09/06 职场文书
付款证明模板
2015/06/19 职场文书
十大好看的穿越动漫排名:《瑞克和莫蒂》第一,国漫《有药》在榜
2022/03/18 日漫
Go 中的空白标识符下划线
2022/03/25 Golang