python+numpy+matplotalib实现梯度下降法


Posted in Python onAugust 31, 2018

这个阶段一直在做和梯度一类算法相关的东西,索性在这儿做个汇总:

一、算法论述

梯度下降法(gradient  descent)别名最速下降法(曾经我以为这是两个不同的算法-.-),是用来求解无约束最优化问题的一种常用算法。下面以求解线性回归为题来叙述:

设:一般的线性回归方程(拟合函数)为:(其中python+numpy+matplotalib实现梯度下降法的值为1)

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法这一组向量参数选择的好与坏就需要一个机制来评估,据此我们提出了其损失函数为(选择均方误差):

python+numpy+matplotalib实现梯度下降法

我们现在的目的就是使得损失函数python+numpy+matplotalib实现梯度下降法取得最小值,即目标函数为:

python+numpy+matplotalib实现梯度下降法

如果python+numpy+matplotalib实现梯度下降法的值取到了0,意味着我们构造出了极好的拟合函数,也即选择出了最好的python+numpy+matplotalib实现梯度下降法值,但这基本是达不到的,我们只能使得其无限的接近于0,当满足一定精度时停止迭代。

那么问题来了如何调整python+numpy+matplotalib实现梯度下降法使得python+numpy+matplotalib实现梯度下降法取得的值越来越小呢?方法很多,此处以梯度下降法为例:

分为两步:(1)初始化python+numpy+matplotalib实现梯度下降法的值。

(2)改变python+numpy+matplotalib实现梯度下降法的值,使得python+numpy+matplotalib实现梯度下降法按梯度下降的方向减少。

python+numpy+matplotalib实现梯度下降法值的更新使用如下的方式来完成:

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

其中python+numpy+matplotalib实现梯度下降法为步长因子,这里我们取定值,但注意如果python+numpy+matplotalib实现梯度下降法取得过小会导致收敛速度过慢,python+numpy+matplotalib实现梯度下降法过大则损失函数可能不会收敛,甚至逐渐变大,可以在下述的代码中修改python+numpy+matplotalib实现梯度下降法的值来进行验证。后面我会再写一篇关于随机梯度下降法的文章,其实与梯度下降法最大的不同就在于一个求和符号。

二、代码实现

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=10000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def GD(samples, y, step_size=0.01, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 print(samples)
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.001 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1])
 for i in range(m):
  h = np.dot(theta,samples[i].T) 
 #更新theta的值,需要的参量有:步长,梯度
  for j in range(len(theta)):
  theta[j] = theta[j] - step_size*(1/m)*(h - y[i])*samples[i,j]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
  h = np.dot(theta.T, samples[i])
  #每组样本点损失的精度
  every_loss = (1/(var*m))*np.power((h - y[i]), 2)
  loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
def predict(x, theta):
 y = np.dot(theta, x.T)
 return y 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = GD(samples, y)
 print(theta) # 会很接近[5, 7] 
 painter3D(theta1,theta2,loss_list)
 predict_y = predict(theta, [7,8])
 print(predict_y)

三、绘制的图像如下:

迭代次数与损失精度间的关系图如下:步长为0.01

python+numpy+matplotalib实现梯度下降法

变量python+numpy+matplotalib实现梯度下降法python+numpy+matplotalib实现梯度下降法与损失函数loss之间的关系:(从初始化之后会一步步收敛到loss满足精度,之后python+numpy+matplotalib实现梯度下降法python+numpy+matplotalib实现梯度下降法会变的稳定下来)

python+numpy+matplotalib实现梯度下降法

下面我们来看一副当步长因子变大后的图像:步长因子为0.5(很明显其收敛速度变缓了)

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

当步长因子设置为1.8左右时,其损失值已经开始震荡

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中文编码那些事
Jun 25 Python
python在windows和linux下获得本机本地ip地址方法小结
Mar 20 Python
21行Python代码实现拼写检查器
Jan 25 Python
一个基于flask的web应用诞生 组织结构调整(7)
Apr 11 Python
pip matplotlib报错equired packages can not be built解决
Jan 06 Python
Python实现全排列的打印
Aug 18 Python
Scrapy框架使用的基本知识
Oct 21 Python
python腾讯语音合成实现过程解析
Aug 01 Python
Python2比较当前图片跟图库哪个图片相似的方法示例
Sep 28 Python
150行Python代码实现带界面的数独游戏
Apr 04 Python
python读取hdfs上的parquet文件方式
Jun 06 Python
如何在Anaconda中打开python自带idle
Sep 21 Python
python实现随机梯度下降法
Mar 24 #Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
You might like
php中限制ip段访问、禁止ip提交表单的代码分享
2014/08/22 PHP
PHP实现浏览器中直接输出图片的方法示例
2018/03/14 PHP
控制打印时页眉角的代码
2007/02/08 Javascript
编写js扩展方法判断一个数组中是否包含某个元素
2013/11/08 Javascript
js数组的操作指南
2014/12/28 Javascript
JavaScript基础知识点归纳(推荐)
2016/07/09 Javascript
jQuery实现简洁的轮播图效果实例
2016/09/07 Javascript
javascript函数中的3个高级技巧
2016/09/22 Javascript
JavaScript实现通过select标签跳转网页的方法
2016/09/29 Javascript
vue.js表格组件开发的实例详解
2016/10/12 Javascript
详解webpack打包vue时提取css
2017/05/26 Javascript
vue-router路由与页面间导航实例解析
2017/11/07 Javascript
详解Vue中的基本语法和常用指令
2019/07/23 Javascript
用Python编程实现语音控制电脑
2014/04/01 Python
python读取html中指定元素生成excle文件示例
2014/04/03 Python
Python编写一个闹钟功能
2017/07/11 Python
利用信号如何监控Django模型对象字段值的变化详解
2017/11/27 Python
Python中时间datetime的处理与转换用法总结
2019/02/18 Python
Python实现非正太分布的异常值检测方式
2019/12/09 Python
如何使用python3获取当前路径及os.path.dirname的使用
2019/12/13 Python
Django {{ MEDIA_URL }}无法显示图片的解决方式
2020/04/07 Python
如何使用python切换hosts文件
2020/04/29 Python
前后端结合实现amazeUI分页效果
2020/08/21 HTML / CSS
日常奢侈品,轻松购物:Verishop
2019/08/20 全球购物
触发器(trigger)的功能都有哪些?写出一个触发器的例子
2012/09/17 面试题
应届生个人求职信模板
2013/11/26 职场文书
数学高效课堂实施方案
2014/03/29 职场文书
入党积极分子批评与自我批评思想汇报
2014/09/14 职场文书
企业办公室主任岗位职责
2015/04/01 职场文书
教研活动主持词
2015/07/03 职场文书
公司人力资源管理制度
2015/08/05 职场文书
干部考核工作总结
2015/08/12 职场文书
《狼牙山五壮士》教学反思
2016/02/17 职场文书
Python爬虫实战之爬取京东商品数据并实实现数据可视化
2021/06/07 Python
使用Python解决图表与画布的间距问题
2022/04/11 Python
Python序列化模块JSON与Pickle
2022/06/05 Python