python实现随机梯度下降法


Posted in Python onMarch 24, 2020

看这篇文章前强烈建议你看看上一篇python实现梯度下降法:

一、为什么要提出随机梯度下降算法

注意看梯度下降法权值的更新方式(推导过程在上一篇文章中有)

python实现随机梯度下降法

 也就是说每次更新权值python实现随机梯度下降法都需要遍历整个数据集(注意那个求和符号),当数据量小的时候,我们还能够接受这种算法,一旦数据量过大,那么使用该方法会使得收敛过程极度缓慢,并且当存在多个局部极小值时,无法保证搜索到全局最优解。为了解决这样的问题,引入了梯度下降法的进阶形式:随机梯度下降法。

二、核心思想

对于权值的更新不再通过遍历全部的数据集,而是选择其中的一个样本即可(对于程序员来说你的第一反应一定是:在这里需要一个随机函数来选择一个样本,不是吗?),一般来说其步长的选择比梯度下降法的步长要小一点,因为梯度下降法使用的是准确梯度,所以它可以朝着全局最优解(当问题为凸问题时)较大幅度的迭代下去,但是随机梯度法不行,因为它使用的是近似梯度,或者对于全局来说有时候它走的也许根本不是梯度下降的方向,故而它走的比较缓,同样这样带来的好处就是相比于梯度下降法,它不是那么容易陷入到局部最优解中去。

三、权值更新方式

python实现随机梯度下降法

(i表示样本标号下标,j表示样本维数下标)

四、代码实现(大体与梯度下降法相同,不同在于while循环中的内容)

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=1000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def SGD(samples, y, step_size=2, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.01 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1]) 
 #样本维数下标
 rand1 = np.random.randint(0,m,1)
 h = np.dot(theta,samples[rand1].T)
 #关键点,只需要一个样本点来更新权值
 for i in range(len(theta)):
 theta[i] =theta[i] - step_size*(1/m)*(h - y[rand1])*samples[rand1,i]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
 h = np.dot(theta.T, samples[i])
 #每组样本点损失的精度
 every_loss = (1/(var*m))*np.power((h - y[i]), 2)
 loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
 
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = SGD(samples, y)
 print(theta) # 会很接近[5, 7]
 
 painter3D(theta1,theta2,loss_list)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现股市信息下载的方法
Jun 15 Python
python获取局域网占带宽最大3个ip的方法
Jul 09 Python
python数据处理实战(必看篇)
Jun 11 Python
python3 shelve模块的详解
Jul 08 Python
Python实现批量读取图片并存入mongodb数据库的方法示例
Apr 02 Python
使用python对excle和json互相转换的示例
Oct 23 Python
Windows10下Tensorflow2.0 安装及环境配置教程(图文)
Nov 21 Python
将python包发布到PyPI和制作whl文件方式
Dec 25 Python
Python HTTP下载文件并显示下载进度条功能的实现
Apr 02 Python
python opencv把一张图片嵌入(叠加)到另一张图片上的实现代码
Jun 11 Python
Python unittest装饰器实现原理及代码
Sep 08 Python
scrapy与selenium结合爬取数据(爬取动态网站)的示例代码
Sep 28 Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
Python实现的读取/更改/写入xml文件操作示例
Aug 30 #Python
You might like
php strnatcmp()函数的用法总结
2013/11/27 PHP
PHP fclose函数用法总结
2019/02/15 PHP
laravel中数据显示方法(默认值和下拉option默认选中)
2019/10/11 PHP
PHP实现爬虫爬取图片代码实例
2021/03/03 PHP
ie 调试javascript的工具
2009/04/29 Javascript
document.createElement()用法
2013/03/13 Javascript
tangram框架响应式加载图片方法
2013/11/21 Javascript
Jquery设置attr的disabled属性控制某行显示或者隐藏
2014/09/25 Javascript
实用框架(iframe)操作代码
2014/10/23 Javascript
DOM基础教程之事件对象
2015/01/20 Javascript
浅谈Jquery核心函数
2015/06/18 Javascript
Bootstarp风格的toggle效果分享
2016/02/23 Javascript
微信小程序 基础组件与导航组件详细介绍
2017/02/21 Javascript
详解VUE2.X过滤器的使用方法
2018/01/11 Javascript
React Native 真机断点调试+跨域资源加载出错问题的解决方法
2018/01/18 Javascript
ES10 特性的完整指南小结
2019/03/04 Javascript
WebStorm无法正确识别Vue3组合式API的解决方案
2021/02/18 Vue.js
[02:30]辉夜杯主赛事第二日胜者组半决赛 CDEC.Y赛后采访
2015/12/26 DOTA
Python 自动安装 Rising 杀毒软件
2009/04/24 Python
Python Deque 模块使用详解
2014/07/04 Python
使用Python获取Linux系统的各种信息
2014/07/10 Python
Python使用PDFMiner解析PDF代码实例
2017/03/27 Python
Python学生成绩管理系统简洁版
2020/04/05 Python
Windows平台Python编程必会模块之pywin32介绍
2019/10/01 Python
python 上下文管理器及自定义原理解析
2019/11/19 Python
Pytorch实现的手写数字mnist识别功能完整示例
2019/12/13 Python
Python实现汇率转换操作
2020/05/03 Python
Python flask框架如何显示图像到web页面
2020/06/03 Python
python和JavaScript哪个容易上手
2020/06/23 Python
Stuart Weitzman美国官网:美国奢华鞋履品牌
2016/08/18 全球购物
斯凯奇新西兰官网:SKECHERS新西兰
2018/02/22 全球购物
美国性感内衣店:Yandy
2018/06/12 全球购物
DataList 能否分页,请问如何实现?
2015/05/03 面试题
大学毕业自我评价
2014/02/02 职场文书
导游词之上海豫园
2019/10/24 职场文书
python开发实时可视化仪表盘的示例
2021/05/07 Python