python实现随机梯度下降(SGD)


Posted in Python onMarch 24, 2020

使用神经网络进行样本训练,要实现随机梯度下降算法。这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义):

def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
 if test_data:
  n_test = len(test_data)#有多少个测试集
  n = len(training_data)
  for j in xrange(epochs):
   random.shuffle(training_data)
   mini_batches = [
    training_data[k:k+mini_batch_size] 
    for k in xrange(0,n,mini_batch_size)]
   for mini_batch in mini_batches:
    self.update_mini_batch(mini_batch, eta)
   if test_data:
    print "Epoch {0}: {1}/{2}".format(j, self.evaluate(test_data),n_test)
   else:
    print "Epoch {0} complete".format(j)

其中training_data是训练集,是由很多的tuples(元组)组成。每一个元组(x,y)代表一个实例,x是图像的向量表示,y是图像的类别。
epochs表示训练多少轮。
mini_batch_size表示每一次训练的实例个数。
eta表示学习率。
test_data表示测试集。
比较重要的函数是self.update_mini_batch,他是更新权重和偏置的关键函数,接下来就定义这个函数。

def update_mini_batch(self, mini_batch,eta): 
 nabla_b = [np.zeros(b.shape) for b in self.biases]
 nabla_w = [np.zeros(w.shape) for w in self.weights]
 for x,y in mini_batch:
  delta_nabla_b, delta_nable_w = self.backprop(x,y)#目标函数对b和w的偏导数
  nabla_b = [nb+dnb for nb,dnb in zip(nabla_b,delta_nabla_b)]
  nabla_w = [nw+dnw for nw,dnw in zip(nabla_w,delta_nabla_w)]#累加b和w
 #最终更新权重为
 self.weights = [w-(eta/len(mini_batch))*nw for w, nw in zip(self.weights, nabla_w)]
 self.baises = [b-(eta/len(mini_batch))*nb for b, nb in zip(self.baises, nabla_b)]

这个update_mini_batch函数根据你传入的一些数据进行更新神经网络的权重和偏置。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 读写txt文件 json文件的实现方法
Oct 22 Python
Python入门_浅谈字符串的分片与索引、字符串的方法
May 16 Python
在windows下Python打印彩色字体的方法
May 15 Python
Python实现合并两个列表的方法分析
May 28 Python
python实现飞机大战微信小游戏
Mar 21 Python
解决PyCharm import torch包失败的问题
Oct 13 Python
解决python opencv无法显示图片的问题
Oct 28 Python
Python获取时间范围内日期列表和周列表的函数
Aug 05 Python
Python 实现自动获取种子磁力链接方式
Jan 16 Python
快速解决Django关闭Debug模式无法加载media图片与static静态文件
Apr 07 Python
Java Unsafe类实现原理及测试代码
Sep 15 Python
Python使用sql语句对mysql数据库多条件模糊查询的思路详解
Apr 12 Python
Python实现将一个正整数分解质因数的方法分析
Dec 14 #Python
Python随机生成均匀分布在三角形内或者任意多边形内的点
Dec 14 #Python
rabbitmq(中间消息代理)在python中的使用详解
Dec 14 #Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
Dec 14 #Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 #Python
python编程通过蒙特卡洛法计算定积分详解
Dec 13 #Python
Python编程产生非均匀随机数的几种方法代码分享
Dec 13 #Python
You might like
php实现文章置顶功能的方法
2016/10/20 PHP
PHP获取文件扩展名的常用方法小结【五种方式】
2018/04/27 PHP
laravel ORM关联关系中的 with和whereHas用法
2019/10/16 PHP
js 数组实现一个类似ruby的迭代器
2009/10/27 Javascript
JavaScript中匿名函数用法实例
2015/03/23 Javascript
jquery利用json实现页面之间传值的实例解析
2016/12/12 Javascript
ajax的分页查询示例(不刷新页面)
2017/01/11 Javascript
JavaScript使用链式方法封装jQuery中CSS()方法示例
2017/04/07 jQuery
使用jQuery和ajax代替iframe的方法(详解)
2017/04/12 jQuery
Vue官方推荐AJAX组件axios.js使用方法详解与API
2018/10/09 Javascript
详解webpack打包后如何调试的方法步骤
2018/11/07 Javascript
浅谈Javascript中的对象和继承
2019/04/19 Javascript
cordova+vue+webapp使用html5获取地理位置的方法
2019/07/06 Javascript
原生JS实现留言板
2020/03/26 Javascript
JavaScript实现网页计算器功能
2020/10/29 Javascript
python迭代器与生成器详解
2016/03/10 Python
Python获取文件所在目录和文件名的方法
2017/01/12 Python
Python中使用支持向量机(SVM)算法
2017/12/26 Python
使用PM2+nginx部署python项目的方法示例
2018/11/07 Python
python正向最大匹配分词和逆向最大匹配分词的实例
2018/11/14 Python
将string类型的数据类型转换为spark rdd时报错的解决方法
2019/02/18 Python
Python中正则表达式的用法总结
2019/02/22 Python
Python配置文件处理的方法教程
2019/08/29 Python
python中time库的实例使用方法
2019/10/31 Python
python实现mean-shift聚类算法
2020/06/10 Python
Ryderwear澳洲官网:澳大利亚高端健身训练装备品牌
2018/09/18 全球购物
英国手机壳购买网站:Case Hut
2019/04/11 全球购物
初中生三年学习生活的自我评价
2013/11/03 职场文书
乡下人家教学反思
2014/02/01 职场文书
给老师的检讨书
2014/02/11 职场文书
个人函授自我鉴定
2014/03/25 职场文书
安康杯竞赛活动总结
2014/05/05 职场文书
法定代表人授权委托书范文
2014/08/02 职场文书
未受刑事制裁公证证明
2014/09/20 职场文书
辞职信格式范文
2015/05/13 职场文书
严以用权专题学习研讨会发言材料
2015/11/09 职场文书