python实现随机梯度下降法


Posted in Python onMarch 24, 2020

看这篇文章前强烈建议你看看上一篇python实现梯度下降法:

一、为什么要提出随机梯度下降算法

注意看梯度下降法权值的更新方式(推导过程在上一篇文章中有)

python实现随机梯度下降法

 也就是说每次更新权值python实现随机梯度下降法都需要遍历整个数据集(注意那个求和符号),当数据量小的时候,我们还能够接受这种算法,一旦数据量过大,那么使用该方法会使得收敛过程极度缓慢,并且当存在多个局部极小值时,无法保证搜索到全局最优解。为了解决这样的问题,引入了梯度下降法的进阶形式:随机梯度下降法。

二、核心思想

对于权值的更新不再通过遍历全部的数据集,而是选择其中的一个样本即可(对于程序员来说你的第一反应一定是:在这里需要一个随机函数来选择一个样本,不是吗?),一般来说其步长的选择比梯度下降法的步长要小一点,因为梯度下降法使用的是准确梯度,所以它可以朝着全局最优解(当问题为凸问题时)较大幅度的迭代下去,但是随机梯度法不行,因为它使用的是近似梯度,或者对于全局来说有时候它走的也许根本不是梯度下降的方向,故而它走的比较缓,同样这样带来的好处就是相比于梯度下降法,它不是那么容易陷入到局部最优解中去。

三、权值更新方式

python实现随机梯度下降法

(i表示样本标号下标,j表示样本维数下标)

四、代码实现(大体与梯度下降法相同,不同在于while循环中的内容)

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=1000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def SGD(samples, y, step_size=2, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.01 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1]) 
 #样本维数下标
 rand1 = np.random.randint(0,m,1)
 h = np.dot(theta,samples[rand1].T)
 #关键点,只需要一个样本点来更新权值
 for i in range(len(theta)):
 theta[i] =theta[i] - step_size*(1/m)*(h - y[rand1])*samples[rand1,i]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
 h = np.dot(theta.T, samples[i])
 #每组样本点损失的精度
 every_loss = (1/(var*m))*np.power((h - y[i]), 2)
 loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
 
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = SGD(samples, y)
 print(theta) # 会很接近[5, 7]
 
 painter3D(theta1,theta2,loss_list)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现调用其他python脚本的方法
Oct 05 Python
Python编程之序列操作实例详解
Jul 22 Python
Python实现简单求解给定整数的质因数算法示例
Mar 25 Python
解决python 输出是省略号的问题
Apr 19 Python
tensorflow使用神经网络实现mnist分类
Sep 08 Python
python 使用sys.stdin和fileinput读入标准输入的方法
Oct 17 Python
Python图像的增强处理操作示例【基于ImageEnhance类】
Jan 03 Python
使用Python进行体育竞技分析(预测球队成绩)
May 16 Python
Python利用pandas处理Excel数据的应用详解
Jun 18 Python
python 利用pyttsx3文字转语音过程详解
Sep 25 Python
Jupyter notebook 远程配置及SSL加密教程
Apr 14 Python
Python实现异步IO的示例
Nov 05 Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
Python实现的读取/更改/写入xml文件操作示例
Aug 30 #Python
You might like
WindowsXP中快速配置Apache+PHP5+Mysql
2008/06/05 PHP
phpMyAdmin 安装及问题总结
2009/05/28 PHP
zend framework多模块多布局配置
2011/02/26 PHP
php cURL和Rolling cURL并发方式比较
2013/10/30 PHP
php中sprintf与printf函数用法区别解析
2014/02/17 PHP
php从文件夹随机读取文件的方法
2015/06/01 PHP
php简单实现单态设计模式的方法分析
2017/07/28 PHP
用js生产批量批处理执行命令
2008/07/28 Javascript
jQuery 常见开发使用技巧总结
2009/12/26 Javascript
js跟随滚动条滚动浮动代码
2009/12/31 Javascript
E3 tree 1.6在Firefox下显示问题的修复方法
2013/01/30 Javascript
json数据处理技巧(字段带空格、增加字段、排序等等)
2013/06/14 Javascript
Javascript学习笔记之相等符号与严格相等符号
2014/11/23 Javascript
nodejs事件的监听与触发的理解分析
2015/02/12 NodeJs
JavaScript中关联原型链属性特性
2016/02/13 Javascript
浅谈Sticky组件的改进实现
2016/03/22 Javascript
Vue渲染函数详解
2017/09/15 Javascript
node.js用fs.rename强制重命名或移动文件夹的方法
2017/12/27 Javascript
js将键值对字符串转为json字符串的方法
2018/03/30 Javascript
JavaScript反射与依赖注入实例详解
2018/05/29 Javascript
Vux+Axios拦截器增加loading的问题及实现方法
2018/11/08 Javascript
微信小程序实现判断是分享到群还是个人功能示例
2019/05/03 Javascript
jQuery属性选择器用法实例分析
2019/06/28 jQuery
深入浅析vue中cross-env的使用
2019/09/12 Javascript
[02:12]DOTA2英雄基础教程 变体精灵
2013/12/16 DOTA
Python爬虫的两套解析方法和四种爬虫实现过程
2018/07/20 Python
Django如何自定义分页
2018/09/25 Python
在python中实现调用可执行文件.exe的3种方法
2019/07/07 Python
Python 二叉树的层序建立与三种遍历实现详解
2019/07/29 Python
python中的函数递归和迭代原理解析
2019/11/14 Python
python实现ping命令小程序
2020/12/28 Python
美国婚戒购物网站:Anjays Designs
2017/06/28 全球购物
人力资源管理毕业求职信
2014/08/05 职场文书
2014年绩效考核工作总结
2014/12/11 职场文书
小学运动会通讯稿
2015/07/18 职场文书
2016大学迎新晚会开场白
2015/11/24 职场文书