python+numpy+matplotalib实现梯度下降法


Posted in Python onAugust 31, 2018

这个阶段一直在做和梯度一类算法相关的东西,索性在这儿做个汇总:

一、算法论述

梯度下降法(gradient  descent)别名最速下降法(曾经我以为这是两个不同的算法-.-),是用来求解无约束最优化问题的一种常用算法。下面以求解线性回归为题来叙述:

设:一般的线性回归方程(拟合函数)为:(其中python+numpy+matplotalib实现梯度下降法的值为1)

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法这一组向量参数选择的好与坏就需要一个机制来评估,据此我们提出了其损失函数为(选择均方误差):

python+numpy+matplotalib实现梯度下降法

我们现在的目的就是使得损失函数python+numpy+matplotalib实现梯度下降法取得最小值,即目标函数为:

python+numpy+matplotalib实现梯度下降法

如果python+numpy+matplotalib实现梯度下降法的值取到了0,意味着我们构造出了极好的拟合函数,也即选择出了最好的python+numpy+matplotalib实现梯度下降法值,但这基本是达不到的,我们只能使得其无限的接近于0,当满足一定精度时停止迭代。

那么问题来了如何调整python+numpy+matplotalib实现梯度下降法使得python+numpy+matplotalib实现梯度下降法取得的值越来越小呢?方法很多,此处以梯度下降法为例:

分为两步:(1)初始化python+numpy+matplotalib实现梯度下降法的值。

(2)改变python+numpy+matplotalib实现梯度下降法的值,使得python+numpy+matplotalib实现梯度下降法按梯度下降的方向减少。

python+numpy+matplotalib实现梯度下降法值的更新使用如下的方式来完成:

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

其中python+numpy+matplotalib实现梯度下降法为步长因子,这里我们取定值,但注意如果python+numpy+matplotalib实现梯度下降法取得过小会导致收敛速度过慢,python+numpy+matplotalib实现梯度下降法过大则损失函数可能不会收敛,甚至逐渐变大,可以在下述的代码中修改python+numpy+matplotalib实现梯度下降法的值来进行验证。后面我会再写一篇关于随机梯度下降法的文章,其实与梯度下降法最大的不同就在于一个求和符号。

二、代码实现

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=10000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def GD(samples, y, step_size=0.01, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 print(samples)
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.001 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1])
 for i in range(m):
  h = np.dot(theta,samples[i].T) 
 #更新theta的值,需要的参量有:步长,梯度
  for j in range(len(theta)):
  theta[j] = theta[j] - step_size*(1/m)*(h - y[i])*samples[i,j]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
  h = np.dot(theta.T, samples[i])
  #每组样本点损失的精度
  every_loss = (1/(var*m))*np.power((h - y[i]), 2)
  loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
def predict(x, theta):
 y = np.dot(theta, x.T)
 return y 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = GD(samples, y)
 print(theta) # 会很接近[5, 7] 
 painter3D(theta1,theta2,loss_list)
 predict_y = predict(theta, [7,8])
 print(predict_y)

三、绘制的图像如下:

迭代次数与损失精度间的关系图如下:步长为0.01

python+numpy+matplotalib实现梯度下降法

变量python+numpy+matplotalib实现梯度下降法python+numpy+matplotalib实现梯度下降法与损失函数loss之间的关系:(从初始化之后会一步步收敛到loss满足精度,之后python+numpy+matplotalib实现梯度下降法python+numpy+matplotalib实现梯度下降法会变的稳定下来)

python+numpy+matplotalib实现梯度下降法

下面我们来看一副当步长因子变大后的图像:步长因子为0.5(很明显其收敛速度变缓了)

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

当步长因子设置为1.8左右时,其损失值已经开始震荡

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入理解python函数递归和生成器
Jun 06 Python
Windows下安装python MySQLdb遇到的问题及解决方法
Mar 16 Python
Apache如何部署django项目
May 21 Python
Python基于贪心算法解决背包问题示例
Nov 27 Python
使用python爬虫实现网络股票信息爬取的demo
Jan 05 Python
Python用imghdr模块识别图片格式实例解析
Jan 11 Python
python设置值及NaN值处理方法
Jul 03 Python
python控制nao机器人身体动作实例详解
Apr 29 Python
Python切图九宫格的实现方法
Oct 10 Python
Python搭建Keras CNN模型破解网站验证码的实现
Apr 07 Python
Python实现简单的猜单词小游戏
Oct 28 Python
pyqt5打包成exe可执行文件的方法
May 14 Python
python实现随机梯度下降法
Mar 24 #Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
You might like
php 获取客户端的真实ip
2009/11/30 PHP
PHP GD库生成图像的几个函数总结
2014/11/19 PHP
玩转jQuery按钮 请告诉我你最喜欢哪些?
2012/01/08 Javascript
Array栈方法和队列方法的特点说明
2014/01/24 Javascript
JavaScript中扩展Array contains方法实例
2020/08/23 Javascript
js实现省份下拉菜单效果
2017/02/15 Javascript
JavaScript使用FileReader实现图片上传预览效果
2020/03/27 Javascript
vue使用$emit时,父组件无法监听到子组件的事件实例
2018/02/26 Javascript
Vue.js 表单控件操作小结
2018/03/29 Javascript
vue项目base64字符串转图片的实现代码
2018/07/13 Javascript
webpack4+react多页面架构的实现
2018/10/25 Javascript
基于vue框架手写一个notify插件实现通知功能的方法
2019/03/31 Javascript
Vue 理解之白话 getter/setter详解
2019/04/16 Javascript
laravel-admin 与 vue 结合使用实例代码详解
2019/06/04 Javascript
解决layui-table单元格设置为百分比在ie8下不能自适应的问题
2019/09/28 Javascript
Vue关于组件化开发知识点详解
2020/05/13 Javascript
小程序选项卡以及swiper套用(跨页面)
2020/06/19 Javascript
[05:13]2018DOTA2亚洲邀请赛主赛事第二日战况回顾 LGD、VG双雄携手晋级
2018/04/05 DOTA
Python语言描述最大连续子序列和
2017/12/05 Python
Python3.6.0+opencv3.3.0人脸检测示例
2018/05/25 Python
Python把csv数据写入list和字典类型的变量脚本方法
2018/06/15 Python
Python2.7版os.path.isdir中文路径返回false的解决方法
2019/06/21 Python
django云端留言板实例详解
2019/07/22 Python
django 类视图的使用方法详解
2019/07/24 Python
opencv 图像加法与图像融合的实现代码
2020/07/08 Python
HTML5 Web 存储详解
2016/09/16 HTML / CSS
日本快乐生活方式购物网站:Shop Japan
2018/07/17 全球购物
我的画教学反思
2014/04/28 职场文书
文员求职信
2014/07/15 职场文书
立志成才演讲稿
2014/09/04 职场文书
深入开展党的群众路线教育实践活动心得体会
2014/11/05 职场文书
2014年六五普法工作总结
2014/11/25 职场文书
学校元旦晚会开场白
2014/12/14 职场文书
公司仓库管理制度
2015/08/04 职场文书
windows11怎么查看wifi密码? win11查看wifi密码的技巧
2021/11/21 数码科技
使用CSS实现六边形的图片效果
2022/08/05 HTML / CSS