python实现随机梯度下降法


Posted in Python onMarch 24, 2020

看这篇文章前强烈建议你看看上一篇python实现梯度下降法:

一、为什么要提出随机梯度下降算法

注意看梯度下降法权值的更新方式(推导过程在上一篇文章中有)

python实现随机梯度下降法

 也就是说每次更新权值python实现随机梯度下降法都需要遍历整个数据集(注意那个求和符号),当数据量小的时候,我们还能够接受这种算法,一旦数据量过大,那么使用该方法会使得收敛过程极度缓慢,并且当存在多个局部极小值时,无法保证搜索到全局最优解。为了解决这样的问题,引入了梯度下降法的进阶形式:随机梯度下降法。

二、核心思想

对于权值的更新不再通过遍历全部的数据集,而是选择其中的一个样本即可(对于程序员来说你的第一反应一定是:在这里需要一个随机函数来选择一个样本,不是吗?),一般来说其步长的选择比梯度下降法的步长要小一点,因为梯度下降法使用的是准确梯度,所以它可以朝着全局最优解(当问题为凸问题时)较大幅度的迭代下去,但是随机梯度法不行,因为它使用的是近似梯度,或者对于全局来说有时候它走的也许根本不是梯度下降的方向,故而它走的比较缓,同样这样带来的好处就是相比于梯度下降法,它不是那么容易陷入到局部最优解中去。

三、权值更新方式

python实现随机梯度下降法

(i表示样本标号下标,j表示样本维数下标)

四、代码实现(大体与梯度下降法相同,不同在于while循环中的内容)

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=1000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def SGD(samples, y, step_size=2, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.01 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1]) 
 #样本维数下标
 rand1 = np.random.randint(0,m,1)
 h = np.dot(theta,samples[rand1].T)
 #关键点,只需要一个样本点来更新权值
 for i in range(len(theta)):
 theta[i] =theta[i] - step_size*(1/m)*(h - y[rand1])*samples[rand1,i]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
 h = np.dot(theta.T, samples[i])
 #每组样本点损失的精度
 every_loss = (1/(var*m))*np.power((h - y[i]), 2)
 loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
 
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = SGD(samples, y)
 print(theta) # 会很接近[5, 7]
 
 painter3D(theta1,theta2,loss_list)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中的strftime()方法的使用
May 22 Python
Python Sqlite3以字典形式返回查询结果的实现方法
Oct 03 Python
Pycharm学习教程(7)虚拟机VM的配置教程
May 04 Python
virtualenv实现多个版本Python共存
Aug 21 Python
Python发送http请求解析返回json的实例
Mar 26 Python
对pandas进行数据预处理的实例讲解
Apr 20 Python
python爬虫获取小区经纬度以及结构化地址
Dec 30 Python
将python安装信息加入注册表的示例
Nov 20 Python
Python Unittest原理及基本使用方法
Nov 06 Python
Django怎么在admin后台注册数据库表
Nov 14 Python
Django多个app urls配置代码实例
Nov 26 Python
python实现发送QQ邮件(可加附件)
Dec 23 Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
Python实现的读取/更改/写入xml文件操作示例
Aug 30 #Python
You might like
短波的认识
2021/03/01 无线电
杏林同学录(五)
2006/10/09 PHP
php下intval()和(int)转换使用与区别
2008/07/18 PHP
php htmlspecialchars加强版
2010/02/16 PHP
php smarty 二级分类代码和模版循环例子
2011/06/16 PHP
phpmyadmin config.inc.php配置示例
2013/08/27 PHP
php获取qq用户昵称和在线状态(实例分析)
2013/10/27 PHP
基于CI框架的微信网页授权库示例
2016/11/25 PHP
实例讲解PHP中使用命名空间
2019/01/27 PHP
PHP PDOStatement::errorCode讲解
2019/01/31 PHP
JS特殊函数(Function()构造函数、函数直接量)区别介绍
2013/05/19 Javascript
捕获键盘事件(且兼容各浏览器)
2013/07/03 Javascript
为何JS操作的href都是javascript:void(0);呢
2015/11/12 Javascript
Js 获取当前函数参数对象的实现代码
2016/06/20 Javascript
原生js获取iframe中dom元素--父子页面相互获取对方dom元素的方法
2016/08/05 Javascript
JavaScript仿支付宝6位数字密码输入框
2016/12/29 Javascript
解决URL地址中的中文乱码问题的办法
2017/02/10 Javascript
防止重复发送 Ajax 请求
2017/02/15 Javascript
jQuery日程管理控件glDatePicker用法详解
2017/03/29 jQuery
详解vue.js全局组件和局部组件
2017/04/10 Javascript
jquery ztree实现右键收藏功能
2017/11/20 jQuery
layui form表单提交后实现自动刷新
2019/10/25 Javascript
vue+Element中table表格实现可编辑(select下拉框)
2020/05/21 Javascript
如何在postman测试用例中实现断言过程解析
2020/07/09 Javascript
[53:23]Secret vs Liquid 2018国际邀请赛淘汰赛BO3 第二场 8.25
2018/08/29 DOTA
python登录pop3邮件服务器接收邮件的方法
2015/04/30 Python
python装饰器实例大详解
2017/10/25 Python
Python numpy实现二维数组和一维数组拼接的方法
2018/06/05 Python
pycharm 批量修改变量名称的方法
2019/08/01 Python
Python之Class&amp;Object用法详解
2019/12/25 Python
大专生自我鉴定范文
2013/10/01 职场文书
精通CAD能手自荐书
2014/01/31 职场文书
根叔历年演讲稿
2014/05/20 职场文书
红色影片观后感
2015/06/18 职场文书
党员干部学法用法心得体会
2016/01/21 职场文书
导游词之介休绵山
2019/12/31 职场文书