python实现随机梯度下降(SGD)


Posted in Python onMarch 24, 2020

使用神经网络进行样本训练,要实现随机梯度下降算法。这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义):

def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
 if test_data:
  n_test = len(test_data)#有多少个测试集
  n = len(training_data)
  for j in xrange(epochs):
   random.shuffle(training_data)
   mini_batches = [
    training_data[k:k+mini_batch_size] 
    for k in xrange(0,n,mini_batch_size)]
   for mini_batch in mini_batches:
    self.update_mini_batch(mini_batch, eta)
   if test_data:
    print "Epoch {0}: {1}/{2}".format(j, self.evaluate(test_data),n_test)
   else:
    print "Epoch {0} complete".format(j)

其中training_data是训练集,是由很多的tuples(元组)组成。每一个元组(x,y)代表一个实例,x是图像的向量表示,y是图像的类别。
epochs表示训练多少轮。
mini_batch_size表示每一次训练的实例个数。
eta表示学习率。
test_data表示测试集。
比较重要的函数是self.update_mini_batch,他是更新权重和偏置的关键函数,接下来就定义这个函数。

def update_mini_batch(self, mini_batch,eta): 
 nabla_b = [np.zeros(b.shape) for b in self.biases]
 nabla_w = [np.zeros(w.shape) for w in self.weights]
 for x,y in mini_batch:
  delta_nabla_b, delta_nable_w = self.backprop(x,y)#目标函数对b和w的偏导数
  nabla_b = [nb+dnb for nb,dnb in zip(nabla_b,delta_nabla_b)]
  nabla_w = [nw+dnw for nw,dnw in zip(nabla_w,delta_nabla_w)]#累加b和w
 #最终更新权重为
 self.weights = [w-(eta/len(mini_batch))*nw for w, nw in zip(self.weights, nabla_w)]
 self.baises = [b-(eta/len(mini_batch))*nb for b, nb in zip(self.baises, nabla_b)]

这个update_mini_batch函数根据你传入的一些数据进行更新神经网络的权重和偏置。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之深入变量和引用对象
Sep 24 Python
Python基于递归算法实现的汉诺塔与Fibonacci数列示例
Apr 18 Python
Python基于多线程操作数据库相关问题分析
Jul 11 Python
influx+grafana自定义python采集数据和一些坑的总结
Sep 17 Python
python根据list重命名文件夹里的所有文件实例
Oct 25 Python
图文详解python安装Scrapy框架步骤
May 20 Python
基于Django静态资源部署404的解决方法
Jul 28 Python
python bluetooth蓝牙信息获取蓝牙设备类型的方法
Nov 29 Python
通过 Python 和 OpenCV 实现目标数量监控
Jan 05 Python
python将dict中的unicode打印成中文实例
May 11 Python
Pandas缺失值2种处理方式代码实例
Jun 13 Python
Python用户自定义异常的实现
Dec 25 Python
Python实现将一个正整数分解质因数的方法分析
Dec 14 #Python
Python随机生成均匀分布在三角形内或者任意多边形内的点
Dec 14 #Python
rabbitmq(中间消息代理)在python中的使用详解
Dec 14 #Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
Dec 14 #Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 #Python
python编程通过蒙特卡洛法计算定积分详解
Dec 13 #Python
Python编程产生非均匀随机数的几种方法代码分享
Dec 13 #Python
You might like
完善CodeIgniter在IDE中代码提示功能的方法
2014/07/19 PHP
php 获取文件行数的方法总结
2016/10/11 PHP
jQuery 操作下拉列表框实现代码
2010/02/22 Javascript
jquery text(),val(),html()方法区别总结
2013/11/04 Javascript
jquery获取元素值的方法(常见的表单元素)
2013/11/15 Javascript
JavaScript判断变量是对象还是数组的方法
2014/08/28 Javascript
JavaScript实现更改网页背景与字体颜色的方法
2015/02/02 Javascript
jQuery on()方法使用技巧详解
2015/04/16 Javascript
jQuery模拟原生态App上拉刷新下拉加载更多页面及原理
2015/08/10 Javascript
Js 获取、判断浏览器版本信息的简单方法
2016/08/08 Javascript
详细谈谈AngularJS的子级作用域问题
2016/09/05 Javascript
Vue.js中数组变动的检测详解
2016/10/12 Javascript
深入理解React Native原生模块与JS模块通信的几种方式
2017/07/24 Javascript
JavaScript判断变量名是否存在数组中的实例
2017/12/28 Javascript
使用vue + less 实现简单换肤功能的示例
2018/02/21 Javascript
JavaScript中 ES6变量的结构赋值
2018/07/10 Javascript
详解javascript appendChild()的完整功能
2018/08/18 Javascript
bootstrap table实现合并单元格效果
2018/12/24 Javascript
react 原生实现头像滚动播放的示例
2020/04/21 Javascript
微信公众号网页分享功能开发的示例代码
2020/05/27 Javascript
vue中可编辑树状表格的实现代码
2020/10/31 Javascript
Python中selenium实现文件上传所有方法整理总结
2017/04/01 Python
Django admin美化插件suit使用示例
2017/12/12 Python
利用python库在局域网内传输文件的方法
2018/06/04 Python
对Python实现简单的API接口实例讲解
2018/12/10 Python
python hash每次调用结果不同的原因
2019/11/21 Python
Python 实现3种回归模型(Linear Regression,Lasso,Ridge)的示例
2020/10/15 Python
用canvas实现图片滤镜效果附演示
2013/11/05 HTML / CSS
海滩咖啡馆:Beach Cafe
2018/02/02 全球购物
意大利中国电子产品购物网站:Geekmall.com
2019/09/30 全球购物
团组织推优材料
2014/12/29 职场文书
大学军训通讯稿(2016最新版)
2015/12/21 职场文书
教师学习十八届五中全会精神心得体会
2016/01/05 职场文书
python 破解加密zip文件的密码
2021/04/22 Python
goland 恢复已更改文件的操作
2021/04/28 Golang
英镑符号 £
2022/02/17 杂记