python实现随机梯度下降(SGD)


Posted in Python onMarch 24, 2020

使用神经网络进行样本训练,要实现随机梯度下降算法。这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义):

def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
 if test_data:
  n_test = len(test_data)#有多少个测试集
  n = len(training_data)
  for j in xrange(epochs):
   random.shuffle(training_data)
   mini_batches = [
    training_data[k:k+mini_batch_size] 
    for k in xrange(0,n,mini_batch_size)]
   for mini_batch in mini_batches:
    self.update_mini_batch(mini_batch, eta)
   if test_data:
    print "Epoch {0}: {1}/{2}".format(j, self.evaluate(test_data),n_test)
   else:
    print "Epoch {0} complete".format(j)

其中training_data是训练集,是由很多的tuples(元组)组成。每一个元组(x,y)代表一个实例,x是图像的向量表示,y是图像的类别。
epochs表示训练多少轮。
mini_batch_size表示每一次训练的实例个数。
eta表示学习率。
test_data表示测试集。
比较重要的函数是self.update_mini_batch,他是更新权重和偏置的关键函数,接下来就定义这个函数。

def update_mini_batch(self, mini_batch,eta): 
 nabla_b = [np.zeros(b.shape) for b in self.biases]
 nabla_w = [np.zeros(w.shape) for w in self.weights]
 for x,y in mini_batch:
  delta_nabla_b, delta_nable_w = self.backprop(x,y)#目标函数对b和w的偏导数
  nabla_b = [nb+dnb for nb,dnb in zip(nabla_b,delta_nabla_b)]
  nabla_w = [nw+dnw for nw,dnw in zip(nabla_w,delta_nabla_w)]#累加b和w
 #最终更新权重为
 self.weights = [w-(eta/len(mini_batch))*nw for w, nw in zip(self.weights, nabla_w)]
 self.baises = [b-(eta/len(mini_batch))*nb for b, nb in zip(self.baises, nabla_b)]

这个update_mini_batch函数根据你传入的一些数据进行更新神经网络的权重和偏置。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实例之wxpython中Frame使用方法
Jun 09 Python
浅析python中SQLAlchemy排序的一个坑
Feb 24 Python
梯度下降法介绍及利用Python实现的方法示例
Jul 12 Python
python监控键盘输入实例代码
Feb 09 Python
python破解zip加密文件的方法
May 31 Python
浅谈python中np.array的shape( ,)与( ,1)的区别
Jun 04 Python
解决django 新增加用户信息出现错误的问题
Jul 28 Python
PyInstaller将Python文件打包为exe后如何反编译(破解源码)以及防止反编译
Apr 15 Python
PyQT5 实现快捷键复制表格数据的方法示例
Jun 19 Python
详解Django ORM引发的数据库N+1性能问题
Oct 12 Python
详解python polyscope库的安装和例程
Nov 13 Python
教你怎么用Python生成九宫格照片
May 20 Python
Python实现将一个正整数分解质因数的方法分析
Dec 14 #Python
Python随机生成均匀分布在三角形内或者任意多边形内的点
Dec 14 #Python
rabbitmq(中间消息代理)在python中的使用详解
Dec 14 #Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
Dec 14 #Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 #Python
python编程通过蒙特卡洛法计算定积分详解
Dec 13 #Python
Python编程产生非均匀随机数的几种方法代码分享
Dec 13 #Python
You might like
php获得当前的脚本网址
2007/12/10 PHP
php5中date()得出的时间为什么不是当前时间的解决方法
2008/06/30 PHP
PHP用SAX解析XML的实现代码与问题分析
2011/08/22 PHP
PHP读取zip文件的方法示例
2016/11/17 PHP
Yii2 批量插入、更新数据实例
2017/03/15 PHP
PHP的Trait机制原理与用法分析
2019/10/18 PHP
用JavaScript 处理 URL 的两个函数代码
2007/08/13 Javascript
JS操作JSON要领详细总结
2013/08/25 Javascript
JSON+HTML实现国家省市联动选择效果
2014/05/18 Javascript
javascript实用方法总结
2015/02/06 Javascript
JS实现漂亮的窗口拖拽效果(可改变大小、最大化、最小化、关闭)
2015/10/10 Javascript
zTree插件下拉树使用入门教程
2016/04/11 Javascript
瀑布流的实现方式(原生js+jquery+css3)
2020/06/28 Javascript
JS转换HTML转义符的方法
2016/08/24 Javascript
Node.js配合node-http-proxy解决本地开发ajax跨域问题
2016/08/31 Javascript
微信小程序 地图定位简单实例
2016/10/14 Javascript
React Router基础使用
2017/01/17 Javascript
vue.js实现备忘录功能的方法
2017/07/10 Javascript
ES6 Promise对象概念及用法实例详解
2019/10/15 Javascript
[46:16]2018DOTA2亚洲邀请赛3月30日 小组赛B组 iG VS VP
2018/03/31 DOTA
Python中实现两个字典(dict)合并的方法
2014/09/23 Python
python多线程之事件Event的使用详解
2018/04/27 Python
selenium+python 对输入框的输入处理方法
2018/10/11 Python
通过python改变图片特定区域的颜色详解
2019/07/15 Python
python实现tail实时查看服务器日志示例
2019/12/24 Python
Brora官网:英国领先的羊绒服装品牌
2019/08/28 全球购物
分解成质因数(如435234=251*17*17*3*2,据说是华为笔试题)
2014/07/16 面试题
Java基础类库面试题
2013/09/04 面试题
演讲稿怎么写
2014/01/07 职场文书
大学毕业感言
2014/01/10 职场文书
网络教育自我鉴定
2014/02/04 职场文书
幼儿园家长评语
2014/02/10 职场文书
护理中职生求职信范文
2014/02/24 职场文书
大学生实习证明
2015/06/16 职场文书
交流会主持词
2015/07/02 职场文书
入党转正申请自我鉴定
2019/06/25 职场文书