python实现随机梯度下降法


Posted in Python onMarch 24, 2020

看这篇文章前强烈建议你看看上一篇python实现梯度下降法:

一、为什么要提出随机梯度下降算法

注意看梯度下降法权值的更新方式(推导过程在上一篇文章中有)

python实现随机梯度下降法

 也就是说每次更新权值python实现随机梯度下降法都需要遍历整个数据集(注意那个求和符号),当数据量小的时候,我们还能够接受这种算法,一旦数据量过大,那么使用该方法会使得收敛过程极度缓慢,并且当存在多个局部极小值时,无法保证搜索到全局最优解。为了解决这样的问题,引入了梯度下降法的进阶形式:随机梯度下降法。

二、核心思想

对于权值的更新不再通过遍历全部的数据集,而是选择其中的一个样本即可(对于程序员来说你的第一反应一定是:在这里需要一个随机函数来选择一个样本,不是吗?),一般来说其步长的选择比梯度下降法的步长要小一点,因为梯度下降法使用的是准确梯度,所以它可以朝着全局最优解(当问题为凸问题时)较大幅度的迭代下去,但是随机梯度法不行,因为它使用的是近似梯度,或者对于全局来说有时候它走的也许根本不是梯度下降的方向,故而它走的比较缓,同样这样带来的好处就是相比于梯度下降法,它不是那么容易陷入到局部最优解中去。

三、权值更新方式

python实现随机梯度下降法

(i表示样本标号下标,j表示样本维数下标)

四、代码实现(大体与梯度下降法相同,不同在于while循环中的内容)

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=1000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def SGD(samples, y, step_size=2, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.01 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1]) 
 #样本维数下标
 rand1 = np.random.randint(0,m,1)
 h = np.dot(theta,samples[rand1].T)
 #关键点,只需要一个样本点来更新权值
 for i in range(len(theta)):
 theta[i] =theta[i] - step_size*(1/m)*(h - y[rand1])*samples[rand1,i]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
 h = np.dot(theta.T, samples[i])
 #每组样本点损失的精度
 every_loss = (1/(var*m))*np.power((h - y[i]), 2)
 loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
 
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = SGD(samples, y)
 print(theta) # 会很接近[5, 7]
 
 painter3D(theta1,theta2,loss_list)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python在linux中输出带颜色的文字的方法
Jun 19 Python
python批量导入数据进Elasticsearch的实例
May 30 Python
如何在Python中实现goto语句的方法
May 18 Python
Python增强赋值和共享引用注意事项小结
May 28 Python
python顺序执行多个py文件的方法
Jun 29 Python
python opencv将图片转为灰度图的方法示例
Jul 31 Python
Python 使用多属性来进行排序
Sep 01 Python
使用Python代码实现Linux中的ls遍历目录命令的实例代码
Sep 07 Python
用Python写一个自动木马程序
Sep 17 Python
利用python实现.dcm格式图像转为.jpg格式
Jan 13 Python
多个python文件调用logging模块报错误
Feb 12 Python
python利用Excel读取和存储测试数据完成接口自动化教程
Apr 30 Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
Python实现的读取/更改/写入xml文件操作示例
Aug 30 #Python
You might like
如何实现给定日期的若干天以后的日期
2006/10/09 PHP
FirePHP 推荐一款PHP调试工具
2011/04/23 PHP
php数组函数序列之array_unshift() 在数组开头插入一个或多个元素
2011/11/07 PHP
php获取服务器信息的实现代码
2013/02/04 PHP
PHP中执行MYSQL事务解决数据写入不完整等情况
2014/01/07 PHP
php5.4传引用时报错问题分析
2016/01/22 PHP
PHP观察者模式示例【Laravel框架中有用到】
2018/06/15 PHP
Ajax+PHP实现的模拟进度条功能示例
2019/02/11 PHP
Laravel5.1 框架路由基础详解
2020/01/04 PHP
php 使用ActiveMQ发送消息,与处理消息操作示例
2020/02/23 PHP
如何做到打开一个页面,过几分钟自动转到另一页面
2007/04/20 Javascript
js限制文本框只能输入整数或者带小数点的数字
2015/04/27 Javascript
BootStrap制作导航条实例代码
2016/05/06 Javascript
jQuery Easyui datagrid连续发送两次请求问题
2016/12/13 Javascript
webpack入门必知必会
2017/01/16 Javascript
微信小程序自定义组件封装及父子间组件传值的方法
2018/08/28 Javascript
详解vue挂载到dom上会发生什么
2019/01/20 Javascript
vue实现Excel文件的上传与下载功能的两种方式
2019/06/28 Javascript
监控微信小程序中的慢HTTP请求过程详解
2019/07/05 Javascript
[57:53]DOTA2上海特级锦标赛主赛事日 - 2 败者组第二轮#3OG VS VP
2016/03/03 DOTA
python获得文件创建时间和修改时间的方法
2015/06/30 Python
Python 实现引用其他.py文件中的类和类的方法
2018/04/29 Python
python3 requests中使用ip代理池随机生成ip的实例
2018/05/07 Python
Python+OpenCV实现旋转文本校正方式
2020/01/09 Python
html5 跨文档消息传输示例探讨
2013/04/01 HTML / CSS
适合各种场合的美食礼品:Harry & David
2016/08/03 全球购物
Groupon荷兰官方网站:高达70%的折扣
2019/11/01 全球购物
大学生收银员求职信分享
2014/01/02 职场文书
致跳高运动员广播稿
2014/01/13 职场文书
励志演讲稿200字
2014/08/21 职场文书
作风建设整改方案
2014/10/27 职场文书
安全生产工作汇报材料
2014/10/28 职场文书
全国劳模先进事迹材料(2016精选版)
2016/02/25 职场文书
2019年二手房买卖合同范本
2019/10/14 职场文书
解决Django transaction进行事务管理踩过的坑
2021/04/24 Python
SQL Server查询某个字段在哪些表中存在
2022/03/03 SQL Server