python实现随机梯度下降(SGD)


Posted in Python onMarch 24, 2020

使用神经网络进行样本训练,要实现随机梯度下降算法。这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义):

def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
 if test_data:
  n_test = len(test_data)#有多少个测试集
  n = len(training_data)
  for j in xrange(epochs):
   random.shuffle(training_data)
   mini_batches = [
    training_data[k:k+mini_batch_size] 
    for k in xrange(0,n,mini_batch_size)]
   for mini_batch in mini_batches:
    self.update_mini_batch(mini_batch, eta)
   if test_data:
    print "Epoch {0}: {1}/{2}".format(j, self.evaluate(test_data),n_test)
   else:
    print "Epoch {0} complete".format(j)

其中training_data是训练集,是由很多的tuples(元组)组成。每一个元组(x,y)代表一个实例,x是图像的向量表示,y是图像的类别。
epochs表示训练多少轮。
mini_batch_size表示每一次训练的实例个数。
eta表示学习率。
test_data表示测试集。
比较重要的函数是self.update_mini_batch,他是更新权重和偏置的关键函数,接下来就定义这个函数。

def update_mini_batch(self, mini_batch,eta): 
 nabla_b = [np.zeros(b.shape) for b in self.biases]
 nabla_w = [np.zeros(w.shape) for w in self.weights]
 for x,y in mini_batch:
  delta_nabla_b, delta_nable_w = self.backprop(x,y)#目标函数对b和w的偏导数
  nabla_b = [nb+dnb for nb,dnb in zip(nabla_b,delta_nabla_b)]
  nabla_w = [nw+dnw for nw,dnw in zip(nabla_w,delta_nabla_w)]#累加b和w
 #最终更新权重为
 self.weights = [w-(eta/len(mini_batch))*nw for w, nw in zip(self.weights, nabla_w)]
 self.baises = [b-(eta/len(mini_batch))*nb for b, nb in zip(self.baises, nabla_b)]

这个update_mini_batch函数根据你传入的一些数据进行更新神经网络的权重和偏置。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
更改Python命令行交互提示符的方法
Jan 14 Python
探究Python多进程编程下线程之间变量的共享问题
May 05 Python
Flask的图形化管理界面搭建框架Flask-Admin的使用教程
Jun 13 Python
详解python eval函数的妙用
Nov 16 Python
Python使用matplotlib绘制多个图形单独显示的方法示例
Mar 14 Python
快速解决PyCharm无法引用matplotlib的问题
May 24 Python
将Dataframe数据转化为ndarry数据的方法
Jun 28 Python
解决Mac下使用python的坑
Aug 13 Python
Python开发之基于模板匹配的信用卡数字识别功能
Jan 13 Python
django admin管理工具自定义时间区间筛选器DateRangeFilter介绍
May 19 Python
零基础学python应该从哪里入手
Aug 11 Python
class类在python中获取金融数据的实例方法
Dec 10 Python
Python实现将一个正整数分解质因数的方法分析
Dec 14 #Python
Python随机生成均匀分布在三角形内或者任意多边形内的点
Dec 14 #Python
rabbitmq(中间消息代理)在python中的使用详解
Dec 14 #Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
Dec 14 #Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 #Python
python编程通过蒙特卡洛法计算定积分详解
Dec 13 #Python
Python编程产生非均匀随机数的几种方法代码分享
Dec 13 #Python
You might like
一篇入门的php Class 文章
2007/04/04 PHP
一个显示效果非常不错的PHP错误、异常处理类
2014/03/21 PHP
php获取网页请求状态程序示例
2014/06/17 PHP
Thinkphp的volist标签嵌套循环使用教程
2014/07/08 PHP
php的ddos攻击解决方法
2015/01/08 PHP
PHP中的魔术方法总结和使用实例
2015/05/11 PHP
php实现的Curl封装类Curl.class.php用法实例分析
2015/09/25 PHP
PHP实现求连续子数组最大和问题2种解决方法
2017/12/26 PHP
PHP7原生MySQL数据库操作实现代码
2020/07/03 PHP
jQuery中:gt选择器用法实例
2014/12/29 Javascript
jQuery模拟黑客帝国矩阵效果实例
2015/06/28 Javascript
jquery结合html实现中英文页面切换
2016/11/29 Javascript
js遍历json的key和value的实例
2017/01/22 Javascript
JS实现含有中文字符串的友好截取功能分析
2017/03/13 Javascript
常用的几个JQuery代码片段
2017/03/13 Javascript
AngularJS实现进度条功能示例
2017/07/05 Javascript
cocos creator Touch事件应用(触控选择多个子节点的实例)
2017/09/10 Javascript
JavaScript 下载svg图片为png格式
2018/06/21 Javascript
Vue+element-ui 实现表格的分页功能示例
2018/08/18 Javascript
Angular6 Filter实现页面搜索的示例代码
2018/12/02 Javascript
vue 之 css module的使用方法
2018/12/04 Javascript
js面向对象之实现淘宝放大镜
2020/01/15 Javascript
[02:43]DOTA2英雄基础教程 半人马战行者
2014/01/13 DOTA
一个检测OpenSSL心脏出血漏洞的Python脚本分享
2014/04/10 Python
python将.ppm格式图片转换成.jpg格式文件的方法
2018/10/27 Python
python实现简单名片管理系统
2018/11/30 Python
详解CSS3中强大的filter(滤镜)属性
2017/06/29 HTML / CSS
HTML5 在canvas中绘制矩形附效果图
2014/06/23 HTML / CSS
新年寄语大全
2014/04/12 职场文书
关于美容院的活动方案
2014/08/14 职场文书
2014年第四季度入党积极分子思想汇报(十八届四中全会)
2014/11/03 职场文书
2015年度保密工作总结
2015/04/24 职场文书
学生会工作感言
2015/08/07 职场文书
优秀共产党员主要事迹材料
2015/11/05 职场文书
导游词之永泰公主墓
2019/12/04 职场文书
使用Ajax实现无刷新上传文件
2022/04/12 Javascript