python实现随机梯度下降法


Posted in Python onMarch 24, 2020

看这篇文章前强烈建议你看看上一篇python实现梯度下降法:

一、为什么要提出随机梯度下降算法

注意看梯度下降法权值的更新方式(推导过程在上一篇文章中有)

python实现随机梯度下降法

 也就是说每次更新权值python实现随机梯度下降法都需要遍历整个数据集(注意那个求和符号),当数据量小的时候,我们还能够接受这种算法,一旦数据量过大,那么使用该方法会使得收敛过程极度缓慢,并且当存在多个局部极小值时,无法保证搜索到全局最优解。为了解决这样的问题,引入了梯度下降法的进阶形式:随机梯度下降法。

二、核心思想

对于权值的更新不再通过遍历全部的数据集,而是选择其中的一个样本即可(对于程序员来说你的第一反应一定是:在这里需要一个随机函数来选择一个样本,不是吗?),一般来说其步长的选择比梯度下降法的步长要小一点,因为梯度下降法使用的是准确梯度,所以它可以朝着全局最优解(当问题为凸问题时)较大幅度的迭代下去,但是随机梯度法不行,因为它使用的是近似梯度,或者对于全局来说有时候它走的也许根本不是梯度下降的方向,故而它走的比较缓,同样这样带来的好处就是相比于梯度下降法,它不是那么容易陷入到局部最优解中去。

三、权值更新方式

python实现随机梯度下降法

(i表示样本标号下标,j表示样本维数下标)

四、代码实现(大体与梯度下降法相同,不同在于while循环中的内容)

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=1000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def SGD(samples, y, step_size=2, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.01 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1]) 
 #样本维数下标
 rand1 = np.random.randint(0,m,1)
 h = np.dot(theta,samples[rand1].T)
 #关键点,只需要一个样本点来更新权值
 for i in range(len(theta)):
 theta[i] =theta[i] - step_size*(1/m)*(h - y[rand1])*samples[rand1,i]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
 h = np.dot(theta.T, samples[i])
 #每组样本点损失的精度
 every_loss = (1/(var*m))*np.power((h - y[i]), 2)
 loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
 
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = SGD(samples, y)
 print(theta) # 会很接近[5, 7]
 
 painter3D(theta1,theta2,loss_list)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深度定制Python的Flask框架开发环境的一些技巧总结
Jul 12 Python
Python中XlsxWriter模块简介与用法分析
Apr 24 Python
python 每天如何定时启动爬虫任务(实现方法分享)
May 21 Python
解决Pycharm运行时找不到文件的问题
Oct 29 Python
解决Djang2.0.1中的reverse导入失败的问题
Aug 16 Python
python实现代码统计程序
Sep 19 Python
Python中实现输入超时及如何通过变量获取变量名
Jan 18 Python
在Python 的线程中运行协程的方法
Feb 24 Python
Python3.7安装pyaudio教程解析
Jul 24 Python
Python os库常用操作代码汇总
Nov 03 Python
Python机器学习应用之基于线性判别模型的分类篇详解
Jan 18 Python
一行Python命令实现批量加水印
Apr 07 Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
Python实现的读取/更改/写入xml文件操作示例
Aug 30 #Python
You might like
解析WordPress中控制用户登陆和判断用户登陆的PHP函数
2016/03/01 PHP
PHP面向对象程序设计内置标准类,普通数据类型转为对象类型示例
2019/06/12 PHP
PhpStorm的使用教程(本地运行PHP+远程开发+快捷键)
2020/03/26 PHP
用javascript获得地址栏参数的两种方法
2006/11/08 Javascript
十分钟打造AutoComplete自动完成效果代码
2009/12/26 Javascript
createElement与createDocumentFragment的点点区别小结
2011/12/19 Javascript
jQuery实现表格行上下移动和置顶效果
2015/06/05 Javascript
Uploadify上传文件方法
2016/03/16 Javascript
jQuery实现自动调用和触发某个事件的方法
2016/11/18 Javascript
使用BootStrap实现悬浮窗口的效果
2016/12/13 Javascript
微信小程序 表单Form实例详解(附源码)
2016/12/22 Javascript
微信小程序五星评分效果实现代码
2017/04/06 Javascript
Nodejs实现多房间简易聊天室功能
2017/06/20 NodeJs
JavaScript事件处理程序详解
2017/09/19 Javascript
Vue通过URL传参如何控制全局console.log的开关详解
2017/12/07 Javascript
Ionic学习日记实现验证码倒计时
2018/02/08 Javascript
webpack 插件html-webpack-plugin的具体使用
2018/04/09 Javascript
如何解决webpack-dev-server代理常切换问题
2019/01/09 Javascript
实例介绍JavaScript中多种组合继承
2019/01/20 Javascript
vue响应式更新机制及不使用框架实现简单的数据双向绑定问题
2019/06/27 Javascript
VUE+elementui组件在table-cell单元格中绘制微型echarts图
2020/04/20 Javascript
vue通过接口直接下载java生成好的Excel表格案例
2020/10/26 Javascript
[56:14]Fnatic vs OG 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
python简单获取本机计算机名和IP地址的方法
2015/06/03 Python
详解Python读取配置文件模块ConfigParser
2017/05/11 Python
Python读取word文本操作详解
2018/01/22 Python
Python 监测文件是否更新的方法
2019/06/10 Python
Python基础类继承重写实现原理解析
2020/04/03 Python
浅谈Python3中print函数的换行
2020/08/05 Python
python 利用panda 实现列联表(交叉表)
2021/02/06 Python
Parts Express:音频、视频和扬声器的第一来源
2017/04/25 全球购物
水果超市创业计划书
2014/01/27 职场文书
经济贸易专业自荐信
2014/06/11 职场文书
运动会演讲稿100字
2014/08/25 职场文书
导游欢送词
2015/01/31 职场文书
建国大业观后感600字
2015/06/01 职场文书