python+numpy+matplotalib实现梯度下降法


Posted in Python onAugust 31, 2018

这个阶段一直在做和梯度一类算法相关的东西,索性在这儿做个汇总:

一、算法论述

梯度下降法(gradient  descent)别名最速下降法(曾经我以为这是两个不同的算法-.-),是用来求解无约束最优化问题的一种常用算法。下面以求解线性回归为题来叙述:

设:一般的线性回归方程(拟合函数)为:(其中python+numpy+matplotalib实现梯度下降法的值为1)

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法这一组向量参数选择的好与坏就需要一个机制来评估,据此我们提出了其损失函数为(选择均方误差):

python+numpy+matplotalib实现梯度下降法

我们现在的目的就是使得损失函数python+numpy+matplotalib实现梯度下降法取得最小值,即目标函数为:

python+numpy+matplotalib实现梯度下降法

如果python+numpy+matplotalib实现梯度下降法的值取到了0,意味着我们构造出了极好的拟合函数,也即选择出了最好的python+numpy+matplotalib实现梯度下降法值,但这基本是达不到的,我们只能使得其无限的接近于0,当满足一定精度时停止迭代。

那么问题来了如何调整python+numpy+matplotalib实现梯度下降法使得python+numpy+matplotalib实现梯度下降法取得的值越来越小呢?方法很多,此处以梯度下降法为例:

分为两步:(1)初始化python+numpy+matplotalib实现梯度下降法的值。

(2)改变python+numpy+matplotalib实现梯度下降法的值,使得python+numpy+matplotalib实现梯度下降法按梯度下降的方向减少。

python+numpy+matplotalib实现梯度下降法值的更新使用如下的方式来完成:

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

其中python+numpy+matplotalib实现梯度下降法为步长因子,这里我们取定值,但注意如果python+numpy+matplotalib实现梯度下降法取得过小会导致收敛速度过慢,python+numpy+matplotalib实现梯度下降法过大则损失函数可能不会收敛,甚至逐渐变大,可以在下述的代码中修改python+numpy+matplotalib实现梯度下降法的值来进行验证。后面我会再写一篇关于随机梯度下降法的文章,其实与梯度下降法最大的不同就在于一个求和符号。

二、代码实现

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=10000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def GD(samples, y, step_size=0.01, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 print(samples)
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.001 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1])
 for i in range(m):
  h = np.dot(theta,samples[i].T) 
 #更新theta的值,需要的参量有:步长,梯度
  for j in range(len(theta)):
  theta[j] = theta[j] - step_size*(1/m)*(h - y[i])*samples[i,j]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
  h = np.dot(theta.T, samples[i])
  #每组样本点损失的精度
  every_loss = (1/(var*m))*np.power((h - y[i]), 2)
  loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
def predict(x, theta):
 y = np.dot(theta, x.T)
 return y 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = GD(samples, y)
 print(theta) # 会很接近[5, 7] 
 painter3D(theta1,theta2,loss_list)
 predict_y = predict(theta, [7,8])
 print(predict_y)

三、绘制的图像如下:

迭代次数与损失精度间的关系图如下:步长为0.01

python+numpy+matplotalib实现梯度下降法

变量python+numpy+matplotalib实现梯度下降法python+numpy+matplotalib实现梯度下降法与损失函数loss之间的关系:(从初始化之后会一步步收敛到loss满足精度,之后python+numpy+matplotalib实现梯度下降法python+numpy+matplotalib实现梯度下降法会变的稳定下来)

python+numpy+matplotalib实现梯度下降法

下面我们来看一副当步长因子变大后的图像:步长因子为0.5(很明显其收敛速度变缓了)

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

当步长因子设置为1.8左右时,其损失值已经开始震荡

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Windows下用py2exe将Python程序打包成exe程序的教程
Apr 08 Python
python执行等待程序直到第二天零点的方法
Apr 23 Python
windows及linux环境下永久修改pip镜像源的方法
Nov 28 Python
Windows安装Python、pip、easy_install的方法
Mar 05 Python
Python通过调用有道翻译api实现翻译功能示例
Jul 19 Python
Python图像的增强处理操作示例【基于ImageEnhance类】
Jan 03 Python
python在openstreetmap地图上绘制路线图的实现
Jul 11 Python
django连接mysql数据库及建表操作实例详解
Dec 10 Python
python 实现批量替换文本中的某部分内容
Dec 13 Python
在python中使用pymysql往mysql数据库中插入(insert)数据实例
Mar 02 Python
Python使用xpath实现图片爬取
Sep 16 Python
python3 os进行嵌套操作的实例讲解
Nov 19 Python
python实现随机梯度下降法
Mar 24 #Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
You might like
linux iconv方法的使用
2011/10/01 PHP
PHP5函数小全(分享)
2013/06/06 PHP
PHP设置图片文件上传大小的具体实现方法
2013/10/11 PHP
php处理restful请求的路由类分享
2014/02/27 PHP
php实现遍历目录并删除指定文件中指定内容
2015/01/21 PHP
PHP实现的蚂蚁爬杆路径算法代码
2015/12/03 PHP
php结合mysql与mysqli扩展处理事务的方法
2016/06/29 PHP
浅谈php中的访问修饰符private、protected、public的作用范围
2016/11/20 PHP
PHP AjaxForm提交图片上传并显示图片源码
2016/11/29 PHP
javascript 获取页面的高度及滚动条的位置的代码
2010/05/06 Javascript
jquery一句话全选/取消全选
2011/03/01 Javascript
基于JQuery的类似新浪微博展示信息效果的代码
2012/07/23 Javascript
Jquery图片滚动与幻灯片的实例代码
2013/04/08 Javascript
对js关键字命名的疑问介绍
2014/04/25 Javascript
jquery常用操作小结
2014/07/21 Javascript
javascript判断移动端访问设备并解析对应CSS的方法
2015/02/05 Javascript
JS实现浏览器状态栏文字闪烁效果的方法
2015/10/27 Javascript
JS中对象与字符串的互相转换详解
2016/05/20 Javascript
JS获取checkbox的个数简单实例
2016/08/19 Javascript
JavaScript模板引擎Template.js使用详解
2016/12/15 Javascript
VUE中v-model和v-for指令详解
2017/06/23 Javascript
Vue2 模板template的四种写法总结
2018/02/23 Javascript
[01:09:23]KG vs TNC 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/16 DOTA
Python中的zipfile模块使用详解
2015/06/25 Python
详解MySQL数据类型int(M)中M的含义
2016/11/20 Python
django rest framework之请求与响应(详解)
2017/11/06 Python
python3解析库pyquery的深入讲解
2018/06/26 Python
python操作docx写入内容,并控制文本的字体颜色
2020/02/13 Python
深入理解Tensorflow中的masking和padding
2020/02/24 Python
详解python的super()的作用和原理
2020/10/29 Python
网上蛋糕店创业计划书
2014/01/24 职场文书
2014年基层党组织公开承诺书
2014/03/29 职场文书
怎样拟定创业计划书
2014/05/01 职场文书
新颖的化妆品活动方案
2014/08/21 职场文书
防灾减灾标语
2014/10/07 职场文书
电力企业职工培训心得体会
2016/01/11 职场文书