python+numpy+matplotalib实现梯度下降法


Posted in Python onAugust 31, 2018

这个阶段一直在做和梯度一类算法相关的东西,索性在这儿做个汇总:

一、算法论述

梯度下降法(gradient  descent)别名最速下降法(曾经我以为这是两个不同的算法-.-),是用来求解无约束最优化问题的一种常用算法。下面以求解线性回归为题来叙述:

设:一般的线性回归方程(拟合函数)为:(其中python+numpy+matplotalib实现梯度下降法的值为1)

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法这一组向量参数选择的好与坏就需要一个机制来评估,据此我们提出了其损失函数为(选择均方误差):

python+numpy+matplotalib实现梯度下降法

我们现在的目的就是使得损失函数python+numpy+matplotalib实现梯度下降法取得最小值,即目标函数为:

python+numpy+matplotalib实现梯度下降法

如果python+numpy+matplotalib实现梯度下降法的值取到了0,意味着我们构造出了极好的拟合函数,也即选择出了最好的python+numpy+matplotalib实现梯度下降法值,但这基本是达不到的,我们只能使得其无限的接近于0,当满足一定精度时停止迭代。

那么问题来了如何调整python+numpy+matplotalib实现梯度下降法使得python+numpy+matplotalib实现梯度下降法取得的值越来越小呢?方法很多,此处以梯度下降法为例:

分为两步:(1)初始化python+numpy+matplotalib实现梯度下降法的值。

(2)改变python+numpy+matplotalib实现梯度下降法的值,使得python+numpy+matplotalib实现梯度下降法按梯度下降的方向减少。

python+numpy+matplotalib实现梯度下降法值的更新使用如下的方式来完成:

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

其中python+numpy+matplotalib实现梯度下降法为步长因子,这里我们取定值,但注意如果python+numpy+matplotalib实现梯度下降法取得过小会导致收敛速度过慢,python+numpy+matplotalib实现梯度下降法过大则损失函数可能不会收敛,甚至逐渐变大,可以在下述的代码中修改python+numpy+matplotalib实现梯度下降法的值来进行验证。后面我会再写一篇关于随机梯度下降法的文章,其实与梯度下降法最大的不同就在于一个求和符号。

二、代码实现

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=10000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def GD(samples, y, step_size=0.01, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 print(samples)
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.001 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1])
 for i in range(m):
  h = np.dot(theta,samples[i].T) 
 #更新theta的值,需要的参量有:步长,梯度
  for j in range(len(theta)):
  theta[j] = theta[j] - step_size*(1/m)*(h - y[i])*samples[i,j]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
  h = np.dot(theta.T, samples[i])
  #每组样本点损失的精度
  every_loss = (1/(var*m))*np.power((h - y[i]), 2)
  loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
def predict(x, theta):
 y = np.dot(theta, x.T)
 return y 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = GD(samples, y)
 print(theta) # 会很接近[5, 7] 
 painter3D(theta1,theta2,loss_list)
 predict_y = predict(theta, [7,8])
 print(predict_y)

三、绘制的图像如下:

迭代次数与损失精度间的关系图如下:步长为0.01

python+numpy+matplotalib实现梯度下降法

变量python+numpy+matplotalib实现梯度下降法python+numpy+matplotalib实现梯度下降法与损失函数loss之间的关系:(从初始化之后会一步步收敛到loss满足精度,之后python+numpy+matplotalib实现梯度下降法python+numpy+matplotalib实现梯度下降法会变的稳定下来)

python+numpy+matplotalib实现梯度下降法

下面我们来看一副当步长因子变大后的图像:步长因子为0.5(很明显其收敛速度变缓了)

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

当步长因子设置为1.8左右时,其损失值已经开始震荡

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python2.x和3.x下maketrans与translate函数使用上的不同
Apr 13 Python
在Python的Django框架中显示对象子集的方法
Jul 21 Python
python采用django框架实现支付宝即时到帐接口
May 17 Python
python在html中插入简单的代码并加上时间戳的方法
Oct 16 Python
python在TXT文件中按照某一字符串取出该字符串所在的行方法
Dec 10 Python
Django ORM 聚合查询和分组查询实现详解
Aug 09 Python
Django Channel实时推送与聊天的示例代码
Apr 30 Python
Python如何在循环内使用list.remove()
Jun 01 Python
pycharm 添加解释器的方法步骤
Aug 31 Python
selenium自动化测试入门实战
Dec 21 Python
python pyg2plot的原理知识点总结
Feb 28 Python
Python进行区间取值案例讲解
Aug 02 Python
python实现随机梯度下降法
Mar 24 #Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
You might like
php变量作用域的深入解析
2013/06/03 PHP
php socket实现的聊天室代码分享
2014/08/16 PHP
Sublime里直接运行PHP配置方法
2014/11/28 PHP
PHP实现数据库的增删查改功能及完整代码
2018/04/18 PHP
PHP使用SMTP邮件服务器发送邮件示例
2018/08/28 PHP
JQuery动态创建DOM、表单元素的实现代码
2011/08/09 Javascript
Js实现双击鼠标自动滚动屏幕的示例代码
2013/12/14 Javascript
js实现鼠标点击左上角滑动菜单效果代码
2015/09/06 Javascript
基于javascript实现的购物商城商品倒计时实例
2016/12/11 Javascript
基于JS递归函数细化认识及实用实例(推荐)
2017/08/07 Javascript
Vue中使用clipboard实现复制功能
2018/09/05 Javascript
vue单页面应用打开新窗口显示跳转页面的实例
2018/09/21 Javascript
利用jquery和BootStrap实现动态滚动条效果
2018/12/03 jQuery
js实现跟随鼠标移动的小球
2019/08/26 Javascript
python提示No module named images的解决方法
2014/09/29 Python
在Python中操作字典之clear()方法的使用
2015/05/21 Python
Python卸载模块的方法汇总
2016/06/07 Python
Pycharm学习教程(3) 代码运行调试
2017/05/03 Python
Python实现简单http服务器
2018/04/12 Python
Python Pandas实现数据分组求平均值并填充nan的示例
2019/07/04 Python
Django应用程序入口WSGIHandler源码解析
2019/08/05 Python
python使用pygame实现笑脸乒乓球弹珠球游戏
2019/11/25 Python
Python爬取新型冠状病毒“谣言”新闻进行数据分析
2020/02/16 Python
Python3爬虫RedisDump的安装步骤
2021/02/20 Python
html5新特性与用法大全
2018/09/13 HTML / CSS
英国最大的在线蜡烛商店:Candles Direct
2019/03/26 全球购物
师范生实习自我鉴定
2013/11/01 职场文书
小学国庆节活动方案
2014/02/11 职场文书
《诺贝尔》教学反思
2014/02/17 职场文书
店面销售职位的职责
2014/03/09 职场文书
联欢晚会主持词
2014/03/25 职场文书
优秀的应届生自荐信
2014/05/23 职场文书
买卖合同协议书范本
2014/10/18 职场文书
复试通知单模板
2015/04/24 职场文书
关于antd tree 和父子组件之间的传值问题(react 总结)
2021/06/02 Javascript
uniapp 微信小程序 自定义tabBar 导航
2022/04/22 Javascript