python实现随机梯度下降(SGD)


Posted in Python onMarch 24, 2020

使用神经网络进行样本训练,要实现随机梯度下降算法。这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义):

def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
 if test_data:
  n_test = len(test_data)#有多少个测试集
  n = len(training_data)
  for j in xrange(epochs):
   random.shuffle(training_data)
   mini_batches = [
    training_data[k:k+mini_batch_size] 
    for k in xrange(0,n,mini_batch_size)]
   for mini_batch in mini_batches:
    self.update_mini_batch(mini_batch, eta)
   if test_data:
    print "Epoch {0}: {1}/{2}".format(j, self.evaluate(test_data),n_test)
   else:
    print "Epoch {0} complete".format(j)

其中training_data是训练集,是由很多的tuples(元组)组成。每一个元组(x,y)代表一个实例,x是图像的向量表示,y是图像的类别。
epochs表示训练多少轮。
mini_batch_size表示每一次训练的实例个数。
eta表示学习率。
test_data表示测试集。
比较重要的函数是self.update_mini_batch,他是更新权重和偏置的关键函数,接下来就定义这个函数。

def update_mini_batch(self, mini_batch,eta): 
 nabla_b = [np.zeros(b.shape) for b in self.biases]
 nabla_w = [np.zeros(w.shape) for w in self.weights]
 for x,y in mini_batch:
  delta_nabla_b, delta_nable_w = self.backprop(x,y)#目标函数对b和w的偏导数
  nabla_b = [nb+dnb for nb,dnb in zip(nabla_b,delta_nabla_b)]
  nabla_w = [nw+dnw for nw,dnw in zip(nabla_w,delta_nabla_w)]#累加b和w
 #最终更新权重为
 self.weights = [w-(eta/len(mini_batch))*nw for w, nw in zip(self.weights, nabla_w)]
 self.baises = [b-(eta/len(mini_batch))*nb for b, nb in zip(self.baises, nabla_b)]

这个update_mini_batch函数根据你传入的一些数据进行更新神经网络的权重和偏置。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python的Tornado框架实现一个简单的WebQQ机器人
Apr 24 Python
python登录pop3邮件服务器接收邮件的方法
Apr 30 Python
Python中关于使用模块的基础知识
May 24 Python
Python实现读取TXT文件数据并存进内置数据库SQLite3的方法
Aug 08 Python
python3+PyQt5实现柱状图
Apr 24 Python
Python 正则表达式 re.match/re.search/re.sub的使用解析
Jul 22 Python
Python Pandas对缺失值的处理方法
Sep 27 Python
Python Scrapy框架:通用爬虫之CrawlSpider用法简单示例
Apr 11 Python
如何在python中执行另一个py文件
Apr 30 Python
六种酷炫Python运行进度条效果的实现代码
Jul 17 Python
python 实现的车牌识别项目
Jan 25 Python
Pyside2中嵌入Matplotlib的绘图的实现
Feb 22 Python
Python实现将一个正整数分解质因数的方法分析
Dec 14 #Python
Python随机生成均匀分布在三角形内或者任意多边形内的点
Dec 14 #Python
rabbitmq(中间消息代理)在python中的使用详解
Dec 14 #Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
Dec 14 #Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 #Python
python编程通过蒙特卡洛法计算定积分详解
Dec 13 #Python
Python编程产生非均匀随机数的几种方法代码分享
Dec 13 #Python
You might like
深入PHP数据加密详解
2013/06/18 PHP
实例介绍PHP的Reflection反射机制
2014/08/05 PHP
php中ftp_chdir与ftp_cdup函数用法
2014/11/18 PHP
WordPress中用于获取文章作者与分类信息的方法整理
2015/12/17 PHP
PHP实现基于栈的后缀表达式求值功能
2017/11/10 PHP
php伪静态验证码不显示的解决方案
2019/09/26 PHP
在html页面中包含共享页面的方法
2008/10/24 Javascript
基于jquery的商品展示放大镜
2010/08/07 Javascript
js 创建快捷方式的代码(fso)
2010/11/19 Javascript
jquery 选择器引擎sizzle浅析
2013/02/06 Javascript
JavaScript事件处理器中的event参数使用介绍
2013/05/24 Javascript
javascript 10进制和62进制的相互转换
2014/07/31 Javascript
javascript结合CSS实现苹果开关按钮特效
2015/04/07 Javascript
JavaScript的RequireJS库入门指南
2015/07/01 Javascript
javascript实现移动端上的触屏拖拽功能
2016/03/04 Javascript
深入理解关于javascript中apply()和call()方法的区别
2016/04/12 Javascript
基于JavaScript实现Tab选项卡切换效果
2016/11/24 Javascript
nodejs 最新版安装npm 的使用详解
2018/01/18 NodeJs
微信小程序App生命周期详解
2018/01/31 Javascript
vue watch监听对象及对应值的变化详解
2018/02/24 Javascript
Vue $mount实战之实现消息弹窗组件
2019/04/22 Javascript
es6中reduce的基本使用方法
2019/09/10 Javascript
vue prop属性传值与传引用示例
2019/11/13 Javascript
VUE项目axios请求头更改Content-Type操作
2020/07/24 Javascript
[01:12:44]VG vs Mineski Supermajor 败者组 BO3 第二场 6.6
2018/06/07 DOTA
Python批量更改文件名的实现方法
2017/10/29 Python
python图书管理系统
2020/04/05 Python
详解Python Matplotlib解决绘图X轴值不按数组排序问题
2019/08/05 Python
Python sklearn库实现PCA教程(以鸢尾花分类为例)
2020/02/24 Python
Python gevent协程切换实现详解
2020/09/14 Python
Django 用户认证Auth组件的使用
2020/11/30 Python
Ann Taylor官方网站:美国最大的女性产品制造商之一
2016/09/14 全球购物
ghd澳大利亚官方网站:英国最受欢迎的美发工具品牌
2018/05/21 全球购物
化石印度尼西亚在线商店:Fossil Indonesia
2019/03/11 全球购物
刮痧观后感
2015/06/05 职场文书
解析Redis Cluster原理
2021/06/21 Redis