python实现随机梯度下降法


Posted in Python onMarch 24, 2020

看这篇文章前强烈建议你看看上一篇python实现梯度下降法:

一、为什么要提出随机梯度下降算法

注意看梯度下降法权值的更新方式(推导过程在上一篇文章中有)

python实现随机梯度下降法

 也就是说每次更新权值python实现随机梯度下降法都需要遍历整个数据集(注意那个求和符号),当数据量小的时候,我们还能够接受这种算法,一旦数据量过大,那么使用该方法会使得收敛过程极度缓慢,并且当存在多个局部极小值时,无法保证搜索到全局最优解。为了解决这样的问题,引入了梯度下降法的进阶形式:随机梯度下降法。

二、核心思想

对于权值的更新不再通过遍历全部的数据集,而是选择其中的一个样本即可(对于程序员来说你的第一反应一定是:在这里需要一个随机函数来选择一个样本,不是吗?),一般来说其步长的选择比梯度下降法的步长要小一点,因为梯度下降法使用的是准确梯度,所以它可以朝着全局最优解(当问题为凸问题时)较大幅度的迭代下去,但是随机梯度法不行,因为它使用的是近似梯度,或者对于全局来说有时候它走的也许根本不是梯度下降的方向,故而它走的比较缓,同样这样带来的好处就是相比于梯度下降法,它不是那么容易陷入到局部最优解中去。

三、权值更新方式

python实现随机梯度下降法

(i表示样本标号下标,j表示样本维数下标)

四、代码实现(大体与梯度下降法相同,不同在于while循环中的内容)

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=1000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def SGD(samples, y, step_size=2, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.01 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1]) 
 #样本维数下标
 rand1 = np.random.randint(0,m,1)
 h = np.dot(theta,samples[rand1].T)
 #关键点,只需要一个样本点来更新权值
 for i in range(len(theta)):
 theta[i] =theta[i] - step_size*(1/m)*(h - y[rand1])*samples[rand1,i]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
 h = np.dot(theta.T, samples[i])
 #每组样本点损失的精度
 every_loss = (1/(var*m))*np.power((h - y[i]), 2)
 loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
 
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = SGD(samples, y)
 print(theta) # 会很接近[5, 7]
 
 painter3D(theta1,theta2,loss_list)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python模拟鼠标拖动操作的方法
Mar 11 Python
Python实现压缩与解压gzip大文件的方法
Sep 18 Python
python数据结构之链表的实例讲解
Jul 25 Python
Python标准模块--ContextManager上下文管理器的具体用法
Nov 27 Python
使用python实现knn算法
Dec 20 Python
python的Tqdm模块的使用
Jan 10 Python
解决pandas 作图无法显示中文的问题
May 24 Python
pandas值替换方法
Jul 10 Python
利用Python模拟登录pastebin.com的实现方法
Jul 12 Python
Python基于Dlib的人脸识别系统的实现
Feb 26 Python
python中turtle库的简单使用教程
Nov 11 Python
Python如何加载模型并查看网络
Jul 15 Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
Python实现的读取/更改/写入xml文件操作示例
Aug 30 #Python
You might like
php和数据库结合的一个简单的web实例 代码分析 (php初学者)
2011/07/28 PHP
基于递归实现的php树形菜单代码
2014/11/19 PHP
PHP MVC框架路由学习笔记
2016/03/02 PHP
php使用gearman进行任务分发操作实例详解
2020/02/26 PHP
js 蒙版进度条(结合图片)
2010/03/10 Javascript
jQuery1.6 类型判断实现代码
2011/09/01 Javascript
jquery 滚动条事件简单实例
2013/07/12 Javascript
js动态添加删除,后台取数据(示例代码)
2013/11/25 Javascript
关于js中for in的缺陷浅析
2013/12/02 Javascript
最短的IE判断var ie=!-[1,]分析
2014/05/28 Javascript
利用JavaScript的AngularJS库制作电子名片的方法
2015/06/18 Javascript
jQuery复制表单元素附源码分享效果演示
2015/09/30 Javascript
AngularJS转换响应内容
2016/01/27 Javascript
JS 滚动事件window.onscroll与position:fixed写兼容IE6的回到顶部组件
2016/10/10 Javascript
Bootstrop实现多级下拉菜单功能
2016/11/24 Javascript
BootStrap模态框和select2合用时input无法获取焦点的解决方法
2017/09/01 Javascript
利用js实现简易红绿灯
2020/10/15 Javascript
python根据出生年份简单计算生肖的方法
2015/03/27 Python
单链表反转python实现代码示例
2018/02/08 Python
如何用python整理附件
2018/05/13 Python
Python实现从SQL型数据库读写dataframe型数据的方法【基于pandas】
2019/03/18 Python
Python numpy线性代数用法实例解析
2019/11/15 Python
python GUI编程(Tkinter) 创建子窗口及在窗口上用图片绘图实例
2020/03/04 Python
python实现简单井字棋小游戏
2020/03/05 Python
Python实现自动装机功能案例分析
2020/10/22 Python
python实现ping命令小程序
2020/12/28 Python
新西兰领先的鞋类和靴子网上商城:Merchant 1948
2017/09/08 全球购物
意大利香水和化妆品购物网站:Parfimo.it
2019/10/06 全球购物
积极分子思想汇报
2014/01/04 职场文书
2014年两会学习心得范例
2014/03/17 职场文书
中医学专业自荐信范文
2014/04/01 职场文书
计划生育证明书写要求
2014/09/17 职场文书
2014光棍节大学生联谊活动方案
2014/10/10 职场文书
辛亥革命观后感
2015/06/02 职场文书
Nginx域名转发https访问的实现
2021/03/31 Servers
ConstraintValidator类如何实现自定义注解校验前端传参
2021/06/18 Java/Android