python实现随机梯度下降法


Posted in Python onMarch 24, 2020

看这篇文章前强烈建议你看看上一篇python实现梯度下降法:

一、为什么要提出随机梯度下降算法

注意看梯度下降法权值的更新方式(推导过程在上一篇文章中有)

python实现随机梯度下降法

 也就是说每次更新权值python实现随机梯度下降法都需要遍历整个数据集(注意那个求和符号),当数据量小的时候,我们还能够接受这种算法,一旦数据量过大,那么使用该方法会使得收敛过程极度缓慢,并且当存在多个局部极小值时,无法保证搜索到全局最优解。为了解决这样的问题,引入了梯度下降法的进阶形式:随机梯度下降法。

二、核心思想

对于权值的更新不再通过遍历全部的数据集,而是选择其中的一个样本即可(对于程序员来说你的第一反应一定是:在这里需要一个随机函数来选择一个样本,不是吗?),一般来说其步长的选择比梯度下降法的步长要小一点,因为梯度下降法使用的是准确梯度,所以它可以朝着全局最优解(当问题为凸问题时)较大幅度的迭代下去,但是随机梯度法不行,因为它使用的是近似梯度,或者对于全局来说有时候它走的也许根本不是梯度下降的方向,故而它走的比较缓,同样这样带来的好处就是相比于梯度下降法,它不是那么容易陷入到局部最优解中去。

三、权值更新方式

python实现随机梯度下降法

(i表示样本标号下标,j表示样本维数下标)

四、代码实现(大体与梯度下降法相同,不同在于while循环中的内容)

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=1000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def SGD(samples, y, step_size=2, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.01 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1]) 
 #样本维数下标
 rand1 = np.random.randint(0,m,1)
 h = np.dot(theta,samples[rand1].T)
 #关键点,只需要一个样本点来更新权值
 for i in range(len(theta)):
 theta[i] =theta[i] - step_size*(1/m)*(h - y[rand1])*samples[rand1,i]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
 h = np.dot(theta.T, samples[i])
 #每组样本点损失的精度
 every_loss = (1/(var*m))*np.power((h - y[i]), 2)
 loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
 
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = SGD(samples, y)
 print(theta) # 会很接近[5, 7]
 
 painter3D(theta1,theta2,loss_list)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python XML RPC服务器端和客户端实例
Nov 22 Python
Python下rrdtool模块的基本使用方法
Nov 13 Python
深入理解python中sort()与sorted()的区别
Aug 29 Python
python Pexpect 实现输密码 scp 拷贝的方法
Jan 03 Python
对Python中DataFrame选择某列值为XX的行实例详解
Jan 29 Python
django与小程序实现登录验证功能的示例代码
Feb 19 Python
python-Web-flask-视图内容和模板知识点西宁街
Aug 23 Python
Python 闭包,函数分隔作用域,nonlocal声明非局部变量操作示例
Oct 14 Python
Python日志syslog使用原理详解
Feb 18 Python
Python实现清理微信僵尸粉功能示例【基于itchat模块】
May 29 Python
python之语音识别speech模块
Sep 09 Python
使用Python+OpenCV进行卡类型及16位卡号数字的OCR功能
Aug 30 Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
Python实现的读取/更改/写入xml文件操作示例
Aug 30 #Python
You might like
全国FM电台频率大全 - 12 安徽省
2020/03/11 无线电
php下的权限算法的实现
2007/04/28 PHP
使用 eAccelerator加速PHP代码的方法
2007/09/30 PHP
PHP 错误处理机制
2015/07/06 PHP
必须收藏的php实用代码片段
2016/02/02 PHP
详解PHP的Yii框架中的Controller控制器
2016/03/29 PHP
PHP自定义多进制的方法
2016/11/03 PHP
PHP命令Command模式用法实例分析
2018/08/08 PHP
Laravel框架数据库迁移操作实例详解
2020/04/06 PHP
游戏人文件夹程序 ver 4.03
2006/07/14 Javascript
一个js的tab切换效果代码[代码分离]
2010/04/11 Javascript
JavaScript游戏之优化篇
2010/11/08 Javascript
实例说明为什么不要行内使用javascript
2014/04/18 Javascript
使用原生js写的一个简单slider
2014/04/29 Javascript
JS实现仿Windows经典风格的选项卡Tab切换代码
2015/10/20 Javascript
bootstrap与Jquery UI 按钮样式冲突的解决办法
2016/09/23 Javascript
详解vue数据渲染出现闪烁问题
2017/06/29 Javascript
Webpack path与publicPath的区别详解
2018/05/03 Javascript
基于vue v-for 多层循环嵌套获取行数的方法
2018/09/26 Javascript
Vue中的Props(不可变状态)
2018/09/29 Javascript
使用koa2创建web项目的方法步骤
2019/03/12 Javascript
JavaScript获取当前url路径过程解析
2019/12/27 Javascript
JavaScript基于面向对象实现的无缝滚动轮播示例
2020/01/17 Javascript
JavaScript实现字符串与HTML格式相互转换
2020/03/17 Javascript
JS实现放大镜效果
2020/09/21 Javascript
angular8.5集成TinyMce5的使用和详细配置(推荐)
2020/11/16 Javascript
解决Python pandas plot输出图形中显示中文乱码问题
2018/12/12 Python
Python计算时间间隔(精确到微妙)的代码实例
2019/02/26 Python
python ffmpeg任意提取视频帧的方法
2020/02/21 Python
OpenCV Python实现拼图小游戏
2020/03/23 Python
联想印度官方网上商店:Lenovo India
2019/08/24 全球购物
经理职责范文
2013/11/08 职场文书
网络维护中文求职信
2014/01/03 职场文书
计算机专业优秀大学生自我总结
2014/01/21 职场文书
公务员廉洁从政心得体会
2016/01/19 职场文书
Matlab如何实现矩阵复制扩充
2021/06/02 Python