python实现随机梯度下降法


Posted in Python onMarch 24, 2020

看这篇文章前强烈建议你看看上一篇python实现梯度下降法:

一、为什么要提出随机梯度下降算法

注意看梯度下降法权值的更新方式(推导过程在上一篇文章中有)

python实现随机梯度下降法

 也就是说每次更新权值python实现随机梯度下降法都需要遍历整个数据集(注意那个求和符号),当数据量小的时候,我们还能够接受这种算法,一旦数据量过大,那么使用该方法会使得收敛过程极度缓慢,并且当存在多个局部极小值时,无法保证搜索到全局最优解。为了解决这样的问题,引入了梯度下降法的进阶形式:随机梯度下降法。

二、核心思想

对于权值的更新不再通过遍历全部的数据集,而是选择其中的一个样本即可(对于程序员来说你的第一反应一定是:在这里需要一个随机函数来选择一个样本,不是吗?),一般来说其步长的选择比梯度下降法的步长要小一点,因为梯度下降法使用的是准确梯度,所以它可以朝着全局最优解(当问题为凸问题时)较大幅度的迭代下去,但是随机梯度法不行,因为它使用的是近似梯度,或者对于全局来说有时候它走的也许根本不是梯度下降的方向,故而它走的比较缓,同样这样带来的好处就是相比于梯度下降法,它不是那么容易陷入到局部最优解中去。

三、权值更新方式

python实现随机梯度下降法

(i表示样本标号下标,j表示样本维数下标)

四、代码实现(大体与梯度下降法相同,不同在于while循环中的内容)

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=1000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def SGD(samples, y, step_size=2, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.01 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1]) 
 #样本维数下标
 rand1 = np.random.randint(0,m,1)
 h = np.dot(theta,samples[rand1].T)
 #关键点,只需要一个样本点来更新权值
 for i in range(len(theta)):
 theta[i] =theta[i] - step_size*(1/m)*(h - y[rand1])*samples[rand1,i]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
 h = np.dot(theta.T, samples[i])
 #每组样本点损失的精度
 every_loss = (1/(var*m))*np.power((h - y[i]), 2)
 loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
 
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = SGD(samples, y)
 print(theta) # 会很接近[5, 7]
 
 painter3D(theta1,theta2,loss_list)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Requests 基础入门
Apr 07 Python
python win32 简单操作方法
May 25 Python
python里使用正则的findall函数的实例详解
Oct 19 Python
python生成excel的实例代码
Nov 08 Python
python3.4爬虫demo
Jan 22 Python
Python实现操纵控制windows注册表的方法分析
May 24 Python
如何理解Python中的变量
Jun 01 Python
Python SQLAlchemy库的使用方法
Oct 13 Python
scrapy实践之翻页爬取的实现
Jan 05 Python
python asyncio 协程库的使用
Jan 21 Python
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
May 27 Python
Pygame如何使用精灵和碰撞检测
Nov 17 Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
Python实现的读取/更改/写入xml文件操作示例
Aug 30 #Python
You might like
discuz7 phpMysql操作类
2009/06/21 PHP
PHP函数篇之掌握ord()与chr()函数应用
2011/12/05 PHP
php中session使用示例
2014/03/29 PHP
Thinkphp模板中截取字符串函数简介
2014/06/17 PHP
PHP bin2hex()函数基础实例讲解
2019/02/11 PHP
Mootools 1.2教程 滑动效果(Slide)
2009/09/15 Javascript
24款非常有用的 jQuery 插件分享
2011/04/06 Javascript
使用jquery animate创建平滑滚动效果(可以是到顶部、到底部或指定地方)
2014/05/27 Javascript
分享28款免费实用的 JQuery 图片和内容滑块插件
2014/12/15 Javascript
jQuery中index()方法用法实例
2014/12/27 Javascript
jQuery创建DOM元素实例解析
2015/01/19 Javascript
jQuery中trigger()方法用法实例
2015/01/19 Javascript
JavaScript按值删除数组元素的方法
2015/04/24 Javascript
在jQuery中使用$而避免跟其它库产生冲突的方法
2015/08/13 Javascript
判断JS对象是否拥有某属性的方法推荐
2016/05/12 Javascript
Vue.JS入门教程之处理表单
2016/12/01 Javascript
从零开始用webpack构建一个vue3.0项目工程的实现
2020/09/24 Javascript
Python 正则表达式(转义问题)
2014/12/15 Python
用Python编写web API的教程
2015/04/30 Python
实例讲解Python中SocketServer模块处理网络请求的用法
2016/06/28 Python
在cmd命令行里进入和退出Python程序的方法
2018/05/12 Python
零基础使用Python读写处理Excel表格的方法
2019/05/02 Python
Python单元测试工具doctest和unittest使用解析
2019/09/02 Python
ALEX AND ANI:手镯,项链,耳环和更多
2017/04/20 全球购物
美体小铺美国官网:The Body Shop美国
2017/11/10 全球购物
匈牙利墨盒和碳粉购买网站:CDRmarket
2018/04/14 全球购物
德国内衣、泳装和睡衣网上商店:Bigsize Dessous
2018/07/09 全球购物
Intersport西班牙:在线体育商店
2019/11/06 全球购物
英国领先的鞋类零售商和顶级品牌的官方零售商:Wynsors
2020/02/17 全球购物
法律工作求职自荐信
2013/10/31 职场文书
中班幼儿评语大全
2014/04/30 职场文书
超市督导岗位职责
2015/04/10 职场文书
2016高考感言
2015/08/01 职场文书
nginx常用命令放入shell脚本详解
2021/03/31 Servers
Oracle安装TNS_ADMIN环境变量设置参考
2021/11/01 Oracle
MySQL学习之基础操作总结
2022/03/19 MySQL