python实现随机梯度下降法


Posted in Python onMarch 24, 2020

看这篇文章前强烈建议你看看上一篇python实现梯度下降法:

一、为什么要提出随机梯度下降算法

注意看梯度下降法权值的更新方式(推导过程在上一篇文章中有)

python实现随机梯度下降法

 也就是说每次更新权值python实现随机梯度下降法都需要遍历整个数据集(注意那个求和符号),当数据量小的时候,我们还能够接受这种算法,一旦数据量过大,那么使用该方法会使得收敛过程极度缓慢,并且当存在多个局部极小值时,无法保证搜索到全局最优解。为了解决这样的问题,引入了梯度下降法的进阶形式:随机梯度下降法。

二、核心思想

对于权值的更新不再通过遍历全部的数据集,而是选择其中的一个样本即可(对于程序员来说你的第一反应一定是:在这里需要一个随机函数来选择一个样本,不是吗?),一般来说其步长的选择比梯度下降法的步长要小一点,因为梯度下降法使用的是准确梯度,所以它可以朝着全局最优解(当问题为凸问题时)较大幅度的迭代下去,但是随机梯度法不行,因为它使用的是近似梯度,或者对于全局来说有时候它走的也许根本不是梯度下降的方向,故而它走的比较缓,同样这样带来的好处就是相比于梯度下降法,它不是那么容易陷入到局部最优解中去。

三、权值更新方式

python实现随机梯度下降法

(i表示样本标号下标,j表示样本维数下标)

四、代码实现(大体与梯度下降法相同,不同在于while循环中的内容)

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=1000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def SGD(samples, y, step_size=2, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.01 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1]) 
 #样本维数下标
 rand1 = np.random.randint(0,m,1)
 h = np.dot(theta,samples[rand1].T)
 #关键点,只需要一个样本点来更新权值
 for i in range(len(theta)):
 theta[i] =theta[i] - step_size*(1/m)*(h - y[rand1])*samples[rand1,i]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
 h = np.dot(theta.T, samples[i])
 #每组样本点损失的精度
 every_loss = (1/(var*m))*np.power((h - y[i]), 2)
 loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
 
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = SGD(samples, y)
 print(theta) # 会很接近[5, 7]
 
 painter3D(theta1,theta2,loss_list)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现探测socket和web服务示例
Mar 28 Python
python任务调度实例分析
May 19 Python
简介Python设计模式中的代理模式与模板方法模式编程
Feb 02 Python
回调函数的意义以及python实现实例
Jun 20 Python
Pyqt实现无边框窗口拖动以及窗口大小改变
Apr 19 Python
Python实现重建二叉树的三种方法详解
Jun 23 Python
python dict 相同key 合并value的实例
Jan 21 Python
python读取图片的几种方式及图像宽和高的存储顺序
Feb 11 Python
Python如何使用正则表达式爬取京东商品信息
Jun 01 Python
PyCharm Community安装与配置的详细教程
Nov 24 Python
python中最小二乘法详细讲解
Feb 19 Python
Python 数据可视化神器Pyecharts绘制图像练习
Feb 28 Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
Python实现的读取/更改/写入xml文件操作示例
Aug 30 #Python
You might like
常用表单验证类,有了这个,一般的验证就都齐了。
2006/12/06 PHP
利用discuz实现PHP大文件上传应用实例代码
2008/11/14 PHP
PHP页面实现定时跳转的方法
2014/10/31 PHP
javascript动态添加表格数据行(ASP后台数据库保存例子)
2010/05/08 Javascript
分享一个自定义的console类 让你不再纠结JS中的调试代码的兼容
2012/04/20 Javascript
JS获取url链接字符串 location.href
2013/12/23 Javascript
js对象基础实例分析
2015/01/13 Javascript
jQuery基于$.ajax设置移动端click超时处理方法
2016/05/14 Javascript
AngularJS 所有版本下载地址
2016/09/14 Javascript
JavaScript正则表达式exec/g实现多次循环用法示例
2017/01/17 Javascript
JS仿Base.js实现的继承示例
2017/04/07 Javascript
javascript实现动态显示颜色块的报表效果
2017/04/10 Javascript
vue用ant design中table表格,点击某行时触发的事件操作
2020/10/28 Javascript
JavaScript点击按钮生成4位随机验证码
2021/01/28 Javascript
python使用reportlab画图示例(含中文汉字)
2013/12/03 Python
python获取android设备的GPS信息脚本分享
2015/03/06 Python
Python实现程序的单一实例用法分析
2015/06/03 Python
Python MySQLdb 使用utf-8 编码插入中文数据问题
2018/03/13 Python
对Python3.6 IDLE常用快捷键介绍
2018/07/16 Python
Python 画出来六维图
2019/07/26 Python
用Python批量把文件复制到另一个文件夹的实现方法
2019/08/16 Python
Python任务自动化工具tox使用教程
2020/03/17 Python
Python是怎样处理json模块的
2020/07/16 Python
python实现图像随机裁剪的示例代码
2020/12/10 Python
Data URI scheme详解和使用实例及图片base64编码实现方法
2014/05/08 HTML / CSS
解释i节点在文件系统中的作用
2013/11/26 面试题
大学生个人推荐信范文
2013/11/25 职场文书
年度考核自我评价
2014/01/25 职场文书
小学生手册家长评语
2014/04/16 职场文书
积极向上的团队口号
2014/06/06 职场文书
详细的本科生职业生涯规划范文
2014/09/16 职场文书
2014大四本科生自我鉴定总结
2014/10/04 职场文书
2014年小学辅导员工作总结
2014/12/23 职场文书
2015年万圣节活动总结
2015/03/24 职场文书
七夕情人节问候语
2015/11/11 职场文书
教你利用Nginx 服务搭建子域环境提升二维地图加载性能的步骤
2021/09/25 Servers