python实现随机梯度下降(SGD)


Posted in Python onMarch 24, 2020

使用神经网络进行样本训练,要实现随机梯度下降算法。这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义):

def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
 if test_data:
  n_test = len(test_data)#有多少个测试集
  n = len(training_data)
  for j in xrange(epochs):
   random.shuffle(training_data)
   mini_batches = [
    training_data[k:k+mini_batch_size] 
    for k in xrange(0,n,mini_batch_size)]
   for mini_batch in mini_batches:
    self.update_mini_batch(mini_batch, eta)
   if test_data:
    print "Epoch {0}: {1}/{2}".format(j, self.evaluate(test_data),n_test)
   else:
    print "Epoch {0} complete".format(j)

其中training_data是训练集,是由很多的tuples(元组)组成。每一个元组(x,y)代表一个实例,x是图像的向量表示,y是图像的类别。
epochs表示训练多少轮。
mini_batch_size表示每一次训练的实例个数。
eta表示学习率。
test_data表示测试集。
比较重要的函数是self.update_mini_batch,他是更新权重和偏置的关键函数,接下来就定义这个函数。

def update_mini_batch(self, mini_batch,eta): 
 nabla_b = [np.zeros(b.shape) for b in self.biases]
 nabla_w = [np.zeros(w.shape) for w in self.weights]
 for x,y in mini_batch:
  delta_nabla_b, delta_nable_w = self.backprop(x,y)#目标函数对b和w的偏导数
  nabla_b = [nb+dnb for nb,dnb in zip(nabla_b,delta_nabla_b)]
  nabla_w = [nw+dnw for nw,dnw in zip(nabla_w,delta_nabla_w)]#累加b和w
 #最终更新权重为
 self.weights = [w-(eta/len(mini_batch))*nw for w, nw in zip(self.weights, nabla_w)]
 self.baises = [b-(eta/len(mini_batch))*nb for b, nb in zip(self.baises, nabla_b)]

这个update_mini_batch函数根据你传入的一些数据进行更新神经网络的权重和偏置。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅析Python中signal包的使用
Nov 13 Python
Python操作mysql数据库实现增删查改功能的方法
Jan 15 Python
python3实现基于用户的协同过滤
May 31 Python
解决pip install xxx报错SyntaxError: invalid syntax的问题
Nov 30 Python
Scrapy框架爬取西刺代理网免费高匿代理的实现代码
Feb 22 Python
Python简单I/O操作示例
Mar 18 Python
windows环境中利用celery实现简单任务队列过程解析
Nov 29 Python
tensorflow -gpu安装方法(不用自己装cuda,cdnn)
Jan 20 Python
利用python 读写csv文件
Sep 10 Python
浅谈Python xlwings 读取Excel文件的正确姿势
Feb 26 Python
python3读取文件指定行的三种方法
May 24 Python
Python学习之os包使用教程详解
Mar 21 Python
Python实现将一个正整数分解质因数的方法分析
Dec 14 #Python
Python随机生成均匀分布在三角形内或者任意多边形内的点
Dec 14 #Python
rabbitmq(中间消息代理)在python中的使用详解
Dec 14 #Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
Dec 14 #Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 #Python
python编程通过蒙特卡洛法计算定积分详解
Dec 13 #Python
Python编程产生非均匀随机数的几种方法代码分享
Dec 13 #Python
You might like
Session的工作方式
2006/10/09 PHP
PHP如何抛出异常处理错误
2011/03/02 PHP
php中使用__autoload()自动加载未定义类的实现代码
2013/02/06 PHP
如何修改和添加Apache的默认站点目录
2013/07/05 PHP
Symfony页面的基本创建实例详解
2015/01/26 PHP
php截取指定2个字符之间字符串的方法
2015/04/15 PHP
PHP计算数组中值的和与乘积的方法(array_sum与array_product函数)
2016/04/01 PHP
微信支付PHP SDK ―― 公众号支付代码详解
2016/09/13 PHP
Laravel框架实现超简单的分页效果示例
2019/02/08 PHP
tagName的使用,留一笔
2006/06/26 Javascript
JavaScript触发器详解
2007/03/10 Javascript
Javascript中valueOf与toString区别浅析
2013/03/19 Javascript
JavaScript实现的一个日期格式化函数分享
2014/12/06 Javascript
详解JavaScript对W3C DOM模版的支持情况
2015/06/16 Javascript
基于jQuery实现交互体验社会化分享代码附源码下载
2016/01/04 Javascript
javascript原生ajax写法分享
2016/04/10 Javascript
jQuery中iframe的操作(点击按钮新增窗口)
2016/04/20 Javascript
原生JS封装Ajax插件(同域、jsonp跨域)
2016/05/03 Javascript
Angular表单验证实例详解
2016/10/20 Javascript
关于微信上网页图片点击全屏放大效果
2016/12/19 Javascript
深入理解Vue Computed计算属性原理
2018/05/29 Javascript
python使用装饰器和线程限制函数执行时间的方法
2015/04/18 Python
VScode编写第一个Python程序HelloWorld步骤
2018/04/06 Python
python 从csv读数据到mysql的实例
2018/06/21 Python
深入浅析Python传值与传址
2018/07/10 Python
用python生成与调用cntk模型代码演示方法
2019/08/26 Python
关于ResNeXt网络的pytorch实现
2020/01/14 Python
Python实现迪杰斯特拉算法并生成最短路径的示例代码
2020/12/01 Python
Python修改DBF文件指定列
2020/12/19 Python
使用phonegap获取位置信息的实现方法
2017/03/31 HTML / CSS
化妆品促销方案
2014/02/24 职场文书
2014小学植树节活动总结
2014/03/10 职场文书
夫妻双方自愿离婚协议书怎么写
2014/12/01 职场文书
2015年度房地产工作总结
2015/04/09 职场文书
2019年冬至:天冷暖人心的问候祝福语大全
2019/12/20 职场文书
在容器中使用nginx搭建上传下载服务器
2022/05/11 Servers