python实现随机梯度下降(SGD)


Posted in Python onMarch 24, 2020

使用神经网络进行样本训练,要实现随机梯度下降算法。这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义):

def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
 if test_data:
  n_test = len(test_data)#有多少个测试集
  n = len(training_data)
  for j in xrange(epochs):
   random.shuffle(training_data)
   mini_batches = [
    training_data[k:k+mini_batch_size] 
    for k in xrange(0,n,mini_batch_size)]
   for mini_batch in mini_batches:
    self.update_mini_batch(mini_batch, eta)
   if test_data:
    print "Epoch {0}: {1}/{2}".format(j, self.evaluate(test_data),n_test)
   else:
    print "Epoch {0} complete".format(j)

其中training_data是训练集,是由很多的tuples(元组)组成。每一个元组(x,y)代表一个实例,x是图像的向量表示,y是图像的类别。
epochs表示训练多少轮。
mini_batch_size表示每一次训练的实例个数。
eta表示学习率。
test_data表示测试集。
比较重要的函数是self.update_mini_batch,他是更新权重和偏置的关键函数,接下来就定义这个函数。

def update_mini_batch(self, mini_batch,eta): 
 nabla_b = [np.zeros(b.shape) for b in self.biases]
 nabla_w = [np.zeros(w.shape) for w in self.weights]
 for x,y in mini_batch:
  delta_nabla_b, delta_nable_w = self.backprop(x,y)#目标函数对b和w的偏导数
  nabla_b = [nb+dnb for nb,dnb in zip(nabla_b,delta_nabla_b)]
  nabla_w = [nw+dnw for nw,dnw in zip(nabla_w,delta_nabla_w)]#累加b和w
 #最终更新权重为
 self.weights = [w-(eta/len(mini_batch))*nw for w, nw in zip(self.weights, nabla_w)]
 self.baises = [b-(eta/len(mini_batch))*nb for b, nb in zip(self.baises, nabla_b)]

这个update_mini_batch函数根据你传入的一些数据进行更新神经网络的权重和偏置。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
phpsir 开发 一个检测百度关键字网站排名的python 程序
Sep 17 Python
Python中获取对象信息的方法
Apr 27 Python
python自定义异常实例详解
Jul 11 Python
python GUI实例学习
Nov 21 Python
对pandas中Series的map函数详解
Jul 25 Python
python实现杨氏矩阵查找
Mar 02 Python
Python 进程之间共享数据(全局变量)的方法
Jul 16 Python
Python 异常处理Ⅳ过程图解
Oct 18 Python
python标识符命名规范原理解析
Jan 10 Python
Windows下实现将Pascal VOC转化为TFRecords
Feb 17 Python
python爬虫中抓取指数的实例讲解
Dec 01 Python
python之随机数函数的实现示例
Dec 30 Python
Python实现将一个正整数分解质因数的方法分析
Dec 14 #Python
Python随机生成均匀分布在三角形内或者任意多边形内的点
Dec 14 #Python
rabbitmq(中间消息代理)在python中的使用详解
Dec 14 #Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
Dec 14 #Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 #Python
python编程通过蒙特卡洛法计算定积分详解
Dec 13 #Python
Python编程产生非均匀随机数的几种方法代码分享
Dec 13 #Python
You might like
小偷PHP+Html+缓存
2006/11/25 PHP
对squid中refresh_pattern的一些理解和建议
2009/04/17 PHP
PHP实现搜索相似图片
2015/09/22 PHP
Zend Framework教程之模型Model基本规则和使用方法
2016/03/04 PHP
php获取本机真实IP地址实例代码
2016/03/31 PHP
php简单实现数组分页的方法
2016/04/30 PHP
PHP请求Socket接口测试实例
2016/08/12 PHP
php获取目录下所有文件及目录(多种方法)(推荐)
2019/05/14 PHP
php 多进程编程父进程的阻塞与非阻塞实例分析
2020/02/22 PHP
比较简单的一个符合web标准的JS调用flash方法
2007/11/29 Javascript
[推荐]javascript 面向对象技术基础教程
2009/03/03 Javascript
深入理解Javascript闭包 新手版
2010/12/28 Javascript
js判断一个元素是否为另一个元素的子元素的代码
2012/03/21 Javascript
javascript抖动元素的小例子
2013/10/28 Javascript
jQuery使用prepend()方法在元素前添加内容用法实例
2015/03/26 Javascript
jquery.gridrotator实现响应式图片展示画廊效果
2015/06/23 Javascript
基于AngularJs + Bootstrap + AngularStrap相结合实现省市区联动代码
2016/05/30 Javascript
从源码看angular/material2 中 dialog模块的实现方法
2017/10/18 Javascript
修改vue+webpack run build的路径方法
2018/09/01 Javascript
使用validate.js实现表单数据提交前的验证方法
2018/09/04 Javascript
记一次用ts+vuecli4重构项目的实现
2020/05/21 Javascript
[01:20]DOTA2更新全新英雄 天涯墨客现已加入游戏
2018/08/25 DOTA
基于python的Tkinter实现一个简易计算器
2015/12/31 Python
使用Python进行二进制文件读写的简单方法(推荐)
2016/09/12 Python
[原创]Python入门教程5. 字典基本操作【定义、运算、常用函数】
2018/11/01 Python
Python3爬虫学习入门教程
2018/12/11 Python
解析Python的缩进规则的使用
2019/01/16 Python
python安装pywin32clipboard的操作方法
2019/01/24 Python
Python 字典中的所有方法及用法
2020/06/10 Python
班级口号大全
2014/06/09 职场文书
2014年大学班长工作总结
2014/11/14 职场文书
烈士陵园观后感
2015/06/08 职场文书
2016七夕情人节寄语
2015/12/04 职场文书
Python中异常处理用法
2021/11/27 Python
redis调用二维码时的不断刷新排查分析
2022/04/01 Redis
MySQL数据库Innodb 引擎实现mvcc锁
2022/05/06 MySQL