python实现随机梯度下降法


Posted in Python onMarch 24, 2020

看这篇文章前强烈建议你看看上一篇python实现梯度下降法:

一、为什么要提出随机梯度下降算法

注意看梯度下降法权值的更新方式(推导过程在上一篇文章中有)

python实现随机梯度下降法

 也就是说每次更新权值python实现随机梯度下降法都需要遍历整个数据集(注意那个求和符号),当数据量小的时候,我们还能够接受这种算法,一旦数据量过大,那么使用该方法会使得收敛过程极度缓慢,并且当存在多个局部极小值时,无法保证搜索到全局最优解。为了解决这样的问题,引入了梯度下降法的进阶形式:随机梯度下降法。

二、核心思想

对于权值的更新不再通过遍历全部的数据集,而是选择其中的一个样本即可(对于程序员来说你的第一反应一定是:在这里需要一个随机函数来选择一个样本,不是吗?),一般来说其步长的选择比梯度下降法的步长要小一点,因为梯度下降法使用的是准确梯度,所以它可以朝着全局最优解(当问题为凸问题时)较大幅度的迭代下去,但是随机梯度法不行,因为它使用的是近似梯度,或者对于全局来说有时候它走的也许根本不是梯度下降的方向,故而它走的比较缓,同样这样带来的好处就是相比于梯度下降法,它不是那么容易陷入到局部最优解中去。

三、权值更新方式

python实现随机梯度下降法

(i表示样本标号下标,j表示样本维数下标)

四、代码实现(大体与梯度下降法相同,不同在于while循环中的内容)

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=1000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def SGD(samples, y, step_size=2, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.01 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1]) 
 #样本维数下标
 rand1 = np.random.randint(0,m,1)
 h = np.dot(theta,samples[rand1].T)
 #关键点,只需要一个样本点来更新权值
 for i in range(len(theta)):
 theta[i] =theta[i] - step_size*(1/m)*(h - y[rand1])*samples[rand1,i]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
 h = np.dot(theta.T, samples[i])
 #每组样本点损失的精度
 every_loss = (1/(var*m))*np.power((h - y[i]), 2)
 loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
 
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = SGD(samples, y)
 print(theta) # 会很接近[5, 7]
 
 painter3D(theta1,theta2,loss_list)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python多线程实例教程
Sep 06 Python
在Docker上部署Python的Flask框架的教程
Apr 08 Python
Python中__init__.py文件的作用详解
Sep 18 Python
Python基于list的append和pop方法实现堆栈与队列功能示例
Jul 24 Python
解析Python中的eval()、exec()及其相关函数
Dec 20 Python
如何利用Boost.Python实现Python C/C++混合编程详解
Nov 08 Python
Django中的静态文件管理过程解析
Aug 01 Python
Python 静态方法和类方法实例分析
Nov 21 Python
Python实现桌面翻译工具【新手必学】
Feb 12 Python
jupyter notebook 调用环境中的Keras或者pytorch教程
Apr 14 Python
python def 定义函数,调用函数方式
Jun 02 Python
python模块如何查看
Jun 16 Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
Python实现的读取/更改/写入xml文件操作示例
Aug 30 #Python
You might like
如何使用PHP获取网络上文件
2006/10/09 PHP
php通过COM类调用组件的实现代码
2012/01/11 PHP
php根据分类合并数组的方法实例详解
2013/11/06 PHP
php操作xml入门之xml标签的属性分析
2015/01/23 PHP
yii2.0实现验证用户名与邮箱功能
2015/12/22 PHP
PHP验证终端类型是否为手机的简单实例
2017/02/07 PHP
解决php写入数据库乱码的问题
2019/09/17 PHP
让插入到 innerHTML 中的 script 跑起来的实现代码
2006/07/01 Javascript
javascript 有用的脚本函数
2009/05/07 Javascript
JS修改css样式style浅谈
2013/05/06 Javascript
jquery遍历之parent()和parents()的区别及parentsUntil()方法详解
2013/12/02 Javascript
javascript常用函数归纳整理
2014/10/31 Javascript
基于jQuery实现复选框的全选 全不选 反选功能
2014/11/24 Javascript
AngularJS身份验证的方法
2016/02/17 Javascript
提高JavaScript执行效率的23个实用技巧
2017/03/01 Javascript
JSONP解决JS跨域问题的实现
2020/05/25 Javascript
javascript实现放大镜功能
2020/12/09 Javascript
[01:51]2018年度CS GO最具人气外援-完美盛典
2018/12/16 DOTA
[04:03][TI9趣味短片] 小鸽子茶话会
2019/08/20 DOTA
python解析xml文件实例分享
2013/12/04 Python
python3.6使用urllib完成下载的实例
2018/12/19 Python
安装好Pycharm后如何配置Python解释器简易教程
2019/06/28 Python
浅析PyTorch中nn.Module的使用
2019/08/18 Python
使用 pytorch 创建神经网络拟合sin函数的实现
2020/02/24 Python
基于Python共轭梯度法与最速下降法之间的对比
2020/04/02 Python
python实时监控logstash日志代码
2020/04/27 Python
浅谈keras保存模型中的save()和save_weights()区别
2020/05/21 Python
零基础学Python之前需要学c语言吗
2020/07/21 Python
世界汽车零件:World Car Parts
2019/09/04 全球购物
澳大利亚波希米亚风时尚品牌:Tree of Life
2019/09/15 全球购物
100%法国制造的游戏和玩具:Les Jouets Français
2021/03/02 全球购物
大学学习生活感言
2014/01/18 职场文书
公司离职证明标准样本
2014/10/05 职场文书
综合素质评价个性与发展自我评价
2015/03/06 职场文书
《自己去吧》教学反思
2016/02/16 职场文书
python中使用 unittest.TestCase单元测试的用例详解
2021/08/30 Python