python实现随机梯度下降(SGD)


Posted in Python onMarch 24, 2020

使用神经网络进行样本训练,要实现随机梯度下降算法。这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义):

def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
 if test_data:
  n_test = len(test_data)#有多少个测试集
  n = len(training_data)
  for j in xrange(epochs):
   random.shuffle(training_data)
   mini_batches = [
    training_data[k:k+mini_batch_size] 
    for k in xrange(0,n,mini_batch_size)]
   for mini_batch in mini_batches:
    self.update_mini_batch(mini_batch, eta)
   if test_data:
    print "Epoch {0}: {1}/{2}".format(j, self.evaluate(test_data),n_test)
   else:
    print "Epoch {0} complete".format(j)

其中training_data是训练集,是由很多的tuples(元组)组成。每一个元组(x,y)代表一个实例,x是图像的向量表示,y是图像的类别。
epochs表示训练多少轮。
mini_batch_size表示每一次训练的实例个数。
eta表示学习率。
test_data表示测试集。
比较重要的函数是self.update_mini_batch,他是更新权重和偏置的关键函数,接下来就定义这个函数。

def update_mini_batch(self, mini_batch,eta): 
 nabla_b = [np.zeros(b.shape) for b in self.biases]
 nabla_w = [np.zeros(w.shape) for w in self.weights]
 for x,y in mini_batch:
  delta_nabla_b, delta_nable_w = self.backprop(x,y)#目标函数对b和w的偏导数
  nabla_b = [nb+dnb for nb,dnb in zip(nabla_b,delta_nabla_b)]
  nabla_w = [nw+dnw for nw,dnw in zip(nabla_w,delta_nabla_w)]#累加b和w
 #最终更新权重为
 self.weights = [w-(eta/len(mini_batch))*nw for w, nw in zip(self.weights, nabla_w)]
 self.baises = [b-(eta/len(mini_batch))*nb for b, nb in zip(self.baises, nabla_b)]

这个update_mini_batch函数根据你传入的一些数据进行更新神经网络的权重和偏置。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现生成简单的Makefile文件代码示例
Mar 10 Python
Python读取Excel的方法实例分析
Jul 11 Python
Python程序中用csv模块来操作csv文件的基本使用教程
Mar 03 Python
python时间日期函数与利用pandas进行时间序列处理详解
Mar 13 Python
python+webdriver自动化环境搭建步骤详解
Jun 03 Python
在python中利用numpy求解多项式以及多项式拟合的方法
Jul 03 Python
Django restframework 框架认证、权限、限流用法示例
Dec 21 Python
python @propert装饰器使用方法原理解析
Dec 25 Python
如何使用Python自动生成报表并以邮件发送
Oct 15 Python
详解Python魔法方法之描述符类
May 26 Python
教你怎么用Python操作MySql数据库
May 31 Python
python接口测试返回数据为字典取值方式
Feb 12 Python
Python实现将一个正整数分解质因数的方法分析
Dec 14 #Python
Python随机生成均匀分布在三角形内或者任意多边形内的点
Dec 14 #Python
rabbitmq(中间消息代理)在python中的使用详解
Dec 14 #Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
Dec 14 #Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 #Python
python编程通过蒙特卡洛法计算定积分详解
Dec 13 #Python
Python编程产生非均匀随机数的几种方法代码分享
Dec 13 #Python
You might like
thinkphp3.2.2实现生成多张缩略图的方法
2014/12/19 PHP
php查询相似度最高的字符串的方法
2015/03/12 PHP
php简单防盗链实现方法
2015/07/29 PHP
thinkphp微信开之安全模式消息加密解密不成功的解决办法
2015/12/02 PHP
prototype.js的Ajax对象
2006/09/23 Javascript
仅用[]()+!等符号就足以实现几乎任意Javascript代码
2010/03/01 Javascript
jQuery实现div浮动层跟随页面滚动效果
2014/02/11 Javascript
Angular实现form自动布局
2016/01/28 Javascript
BootStrap下拉框在firefox浏览器界面不友好的解决方案
2016/08/18 Javascript
Bootstrap模态窗口源码解析
2017/02/08 Javascript
Ionic学习日记实现验证码倒计时
2018/02/08 Javascript
vue $set 给数据赋值的实例
2019/11/09 Javascript
动态实现element ui的el-table某列数据不同样式的示例
2021/01/22 Javascript
[16:56]heroes英雄教学 司夜刺客
2014/09/18 DOTA
[04:49]期待西雅图之战 2016国际邀请赛中国区预选赛WINGS战队赛后采访
2016/06/29 DOTA
使用70行Python代码实现一个递归下降解析器的教程
2015/04/17 Python
Python检测网站链接是否已存在
2016/04/07 Python
python实现八大排序算法(1)
2017/09/14 Python
python生成随机图形验证码详解
2017/11/08 Python
基于Python socket的端口扫描程序实例代码
2018/02/09 Python
python3获取文件中url内容并下载代码实例
2019/12/27 Python
Pytorch: 自定义网络层实例
2020/01/07 Python
Django框架models使用group by详解
2020/03/11 Python
Python如何爬取51cto数据并存入MySQL
2020/08/25 Python
草莓网化妆品日本站:Strawberrynet日本
2017/10/20 全球购物
StubHub中国:购买和出售全球活动门票
2020/01/01 全球购物
仓管员岗位责任制
2014/02/19 职场文书
办公室主任竞聘演讲稿
2014/05/15 职场文书
行政秘书工作自我鉴定
2014/09/15 职场文书
商业用房租赁协议书
2014/10/13 职场文书
实训报告范文大全
2014/11/04 职场文书
小浪底导游词
2015/02/12 职场文书
应聘教师自荐信
2015/03/26 职场文书
复试通知单模板
2015/04/24 职场文书
家长反馈意见及建议
2015/06/03 职场文书
Django Paginator分页器的使用示例
2021/06/23 Python