python实现随机梯度下降(SGD)


Posted in Python onMarch 24, 2020

使用神经网络进行样本训练,要实现随机梯度下降算法。这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义):

def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
 if test_data:
  n_test = len(test_data)#有多少个测试集
  n = len(training_data)
  for j in xrange(epochs):
   random.shuffle(training_data)
   mini_batches = [
    training_data[k:k+mini_batch_size] 
    for k in xrange(0,n,mini_batch_size)]
   for mini_batch in mini_batches:
    self.update_mini_batch(mini_batch, eta)
   if test_data:
    print "Epoch {0}: {1}/{2}".format(j, self.evaluate(test_data),n_test)
   else:
    print "Epoch {0} complete".format(j)

其中training_data是训练集,是由很多的tuples(元组)组成。每一个元组(x,y)代表一个实例,x是图像的向量表示,y是图像的类别。
epochs表示训练多少轮。
mini_batch_size表示每一次训练的实例个数。
eta表示学习率。
test_data表示测试集。
比较重要的函数是self.update_mini_batch,他是更新权重和偏置的关键函数,接下来就定义这个函数。

def update_mini_batch(self, mini_batch,eta): 
 nabla_b = [np.zeros(b.shape) for b in self.biases]
 nabla_w = [np.zeros(w.shape) for w in self.weights]
 for x,y in mini_batch:
  delta_nabla_b, delta_nable_w = self.backprop(x,y)#目标函数对b和w的偏导数
  nabla_b = [nb+dnb for nb,dnb in zip(nabla_b,delta_nabla_b)]
  nabla_w = [nw+dnw for nw,dnw in zip(nabla_w,delta_nabla_w)]#累加b和w
 #最终更新权重为
 self.weights = [w-(eta/len(mini_batch))*nw for w, nw in zip(self.weights, nabla_w)]
 self.baises = [b-(eta/len(mini_batch))*nb for b, nb in zip(self.baises, nabla_b)]

这个update_mini_batch函数根据你传入的一些数据进行更新神经网络的权重和偏置。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用M2Crypto模块实现AES加密的教程
Apr 08 Python
在Python的while循环中使用else以及循环嵌套的用法
Oct 14 Python
Python实现1-9数组形成的结果为100的所有运算式的示例
Nov 03 Python
基于windows下pip安装python模块时报错总结
Jun 12 Python
OpenCV+python手势识别框架和实例讲解
Aug 03 Python
Python3.5 Pandas模块之Series用法实例分析
Apr 23 Python
Appium+python自动化之连接模拟器并启动淘宝APP(超详解)
Jun 17 Python
Python流程控制 while循环实现解析
Sep 02 Python
解决springboot yml配置 logging.level 报错问题
Feb 21 Python
python批量修改xml属性的实现方式
Mar 05 Python
python3中for循环踩过的坑记录
Dec 14 Python
python解析json数据
Apr 29 Python
Python实现将一个正整数分解质因数的方法分析
Dec 14 #Python
Python随机生成均匀分布在三角形内或者任意多边形内的点
Dec 14 #Python
rabbitmq(中间消息代理)在python中的使用详解
Dec 14 #Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
Dec 14 #Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 #Python
python编程通过蒙特卡洛法计算定积分详解
Dec 13 #Python
Python编程产生非均匀随机数的几种方法代码分享
Dec 13 #Python
You might like
索尼SONY ICF-SW7600GR电路分析与改良
2021/03/02 无线电
使用 eAccelerator加速PHP代码的方法
2007/09/30 PHP
php单件模式结合命令链模式使用说明
2008/09/07 PHP
基于php的微信公众平台开发入门实例
2015/04/15 PHP
php计算税后工资的方法
2015/07/28 PHP
PHP简单处理表单输入的特殊字符的方法
2016/02/03 PHP
PHP中的表达式简述
2016/05/29 PHP
Js callBack 返回前一页的js方法
2008/11/30 Javascript
JavaScript.The.Good.Parts阅读笔记(二)作用域&闭包&减缓全局空间污染
2010/11/16 Javascript
js实现文本框中输入文字页面中div层同步获取文本框内容的方法
2015/03/03 Javascript
深入理解JS addLoadEvent函数
2016/05/20 Javascript
jQuery Easyui datagrid/treegrid 清空数据
2016/07/09 Javascript
Vue组件BootPage实现简单的分页功能
2016/09/12 Javascript
谈谈JS中常遇到的浏览器兼容问题和解决方法
2016/12/17 Javascript
jquery自定义插件结合baiduTemplate.js实现异步刷新(附源码)
2016/12/22 Javascript
JS实现根据密码长度显示安全条功能
2017/03/08 Javascript
Vue 表单控件绑定的实现示例
2017/08/11 Javascript
vue注册组件的几种方式总结
2018/03/08 Javascript
vue ssr 实现方式(学习笔记)
2019/01/18 Javascript
vue插件mescroll.js实现移动端上拉加载和下拉刷新
2019/03/07 Javascript
一篇文章介绍redux、react-redux、redux-saga总结
2019/05/23 Javascript
element的el-table中记录滚动条位置的示例代码
2019/11/06 Javascript
JS实现的进制转换,浮点数相加,数字判断操作示例
2019/11/09 Javascript
[01:18:21]EG vs TNC Supermajor小组赛B组败者组第一轮 BO3 第一场 6.2
2018/06/03 DOTA
django接入新浪微博OAuth的方法
2015/06/29 Python
python3 读取Excel表格中的数据
2018/10/16 Python
Python3 中作为一等对象的函数解析
2019/12/11 Python
Django admin管理工具TabularInline类用法详解
2020/05/14 Python
基于HTML5+CSS3实现简单的时钟效果
2017/09/11 HTML / CSS
美国婴儿用品及配件购买网站:Munchkin
2019/04/03 全球购物
Dyson戴森波兰官网:Dyson.pl
2019/08/05 全球购物
如何查找网页漏洞
2016/06/22 面试题
精神文明建设先进工作者事迹材料
2014/05/02 职场文书
营销与策划实训报告
2014/11/05 职场文书
个人承诺书格式范文
2015/04/29 职场文书
MYSQL(电话号码,身份证)数据脱敏的实现
2021/05/28 MySQL