python+numpy+matplotalib实现梯度下降法


Posted in Python onAugust 31, 2018

这个阶段一直在做和梯度一类算法相关的东西,索性在这儿做个汇总:

一、算法论述

梯度下降法(gradient  descent)别名最速下降法(曾经我以为这是两个不同的算法-.-),是用来求解无约束最优化问题的一种常用算法。下面以求解线性回归为题来叙述:

设:一般的线性回归方程(拟合函数)为:(其中python+numpy+matplotalib实现梯度下降法的值为1)

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法这一组向量参数选择的好与坏就需要一个机制来评估,据此我们提出了其损失函数为(选择均方误差):

python+numpy+matplotalib实现梯度下降法

我们现在的目的就是使得损失函数python+numpy+matplotalib实现梯度下降法取得最小值,即目标函数为:

python+numpy+matplotalib实现梯度下降法

如果python+numpy+matplotalib实现梯度下降法的值取到了0,意味着我们构造出了极好的拟合函数,也即选择出了最好的python+numpy+matplotalib实现梯度下降法值,但这基本是达不到的,我们只能使得其无限的接近于0,当满足一定精度时停止迭代。

那么问题来了如何调整python+numpy+matplotalib实现梯度下降法使得python+numpy+matplotalib实现梯度下降法取得的值越来越小呢?方法很多,此处以梯度下降法为例:

分为两步:(1)初始化python+numpy+matplotalib实现梯度下降法的值。

(2)改变python+numpy+matplotalib实现梯度下降法的值,使得python+numpy+matplotalib实现梯度下降法按梯度下降的方向减少。

python+numpy+matplotalib实现梯度下降法值的更新使用如下的方式来完成:

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

其中python+numpy+matplotalib实现梯度下降法为步长因子,这里我们取定值,但注意如果python+numpy+matplotalib实现梯度下降法取得过小会导致收敛速度过慢,python+numpy+matplotalib实现梯度下降法过大则损失函数可能不会收敛,甚至逐渐变大,可以在下述的代码中修改python+numpy+matplotalib实现梯度下降法的值来进行验证。后面我会再写一篇关于随机梯度下降法的文章,其实与梯度下降法最大的不同就在于一个求和符号。

二、代码实现

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=10000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def GD(samples, y, step_size=0.01, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 print(samples)
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.001 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1])
 for i in range(m):
  h = np.dot(theta,samples[i].T) 
 #更新theta的值,需要的参量有:步长,梯度
  for j in range(len(theta)):
  theta[j] = theta[j] - step_size*(1/m)*(h - y[i])*samples[i,j]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
  h = np.dot(theta.T, samples[i])
  #每组样本点损失的精度
  every_loss = (1/(var*m))*np.power((h - y[i]), 2)
  loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
def predict(x, theta):
 y = np.dot(theta, x.T)
 return y 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = GD(samples, y)
 print(theta) # 会很接近[5, 7] 
 painter3D(theta1,theta2,loss_list)
 predict_y = predict(theta, [7,8])
 print(predict_y)

三、绘制的图像如下:

迭代次数与损失精度间的关系图如下:步长为0.01

python+numpy+matplotalib实现梯度下降法

变量python+numpy+matplotalib实现梯度下降法python+numpy+matplotalib实现梯度下降法与损失函数loss之间的关系:(从初始化之后会一步步收敛到loss满足精度,之后python+numpy+matplotalib实现梯度下降法python+numpy+matplotalib实现梯度下降法会变的稳定下来)

python+numpy+matplotalib实现梯度下降法

下面我们来看一副当步长因子变大后的图像:步长因子为0.5(很明显其收敛速度变缓了)

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

当步长因子设置为1.8左右时,其损失值已经开始震荡

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python协程用法实例分析
Jun 04 Python
使用Python中的tkinter模块作图的方法
Feb 07 Python
Python实现爬虫设置代理IP和伪装成浏览器的方法分享
May 07 Python
python实现猜数字小游戏
Mar 24 Python
Python关于excel和shp的使用在matplotlib
Jan 03 Python
python numpy中cumsum的用法详解
Oct 17 Python
python已协程方式处理任务实现过程
Dec 27 Python
基于python实现文件加密功能
Jan 06 Python
From CSV to SQLite3 by python 导入csv到sqlite实例
Feb 14 Python
Django-celery-beat动态添加周期性任务实现过程解析
Nov 26 Python
python中的random模块和相关函数详解
Apr 22 Python
python内置模块之上下文管理contextlib
Jun 14 Python
python实现随机梯度下降法
Mar 24 #Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
You might like
Smarty模板快速入门
2007/01/04 PHP
利用ThinkPHP内置的ThinkAjax实现异步传输技术的实现方法
2011/12/19 PHP
使用array mutisort 实现按某字段对数据排序
2013/06/18 PHP
ThinkPHP基于PHPExcel导入Excel文件的方法
2014/10/15 PHP
php与Mysql的一些简单的操作
2015/02/26 PHP
详解PHP中的mb_detect_encoding函数使用方法
2015/08/18 PHP
ThinkPHP进程计数类Process用法实例详解
2015/09/25 PHP
jQuery探测位置的提示弹窗(toolTip box)详细解析
2013/11/14 Javascript
用javascript替换URL中的参数值示例代码
2014/01/27 Javascript
js文件Cookie存取值示例代码
2014/02/20 Javascript
一个JavaScript用逗号分割字符串实例
2014/09/22 Javascript
JavaScript数据结构和算法之二叉树详解
2015/02/11 Javascript
深入浅析JavaScript系列(13):This? Yes,this!
2016/01/05 Javascript
js注入 黑客之路必备!
2016/09/14 Javascript
关于JS中二维数组的声明方法
2016/09/24 Javascript
js判断文件格式及大小的简单实例(必看)
2016/10/11 Javascript
angularJS模态框$modal实例代码
2017/05/27 Javascript
详解A标签中href=&quot;&quot;的几种用法
2017/08/20 Javascript
EasyUI的DataGrid每行数据添加操作按钮的实现代码
2017/08/22 Javascript
微信小程序-API接口安全详解
2019/07/16 Javascript
基于JavaScript实现表格隔行换色
2020/05/08 Javascript
Vue+Element自定义纵向表格表头教程
2020/10/26 Javascript
[03:05]DOTA2英雄基础教程 嗜血狂魔
2013/12/10 DOTA
Windows下PyMongo下载及安装教程
2015/04/27 Python
python判断windows系统是32位还是64位的方法
2015/05/11 Python
Python探索之Metaclass初步了解
2017/10/28 Python
python3使用pandas获取股票数据的方法
2018/12/22 Python
详解Python装饰器
2019/03/25 Python
详解PyQt5中textBrowser显示print语句输出的简单方法
2020/08/07 Python
python exit出错原因整理
2020/08/31 Python
售前工程师职业生涯规划
2014/03/02 职场文书
国旗下的演讲稿
2014/05/08 职场文书
为什么阅读对所有年龄段的孩子都很重要?
2019/07/08 职场文书
在python中实现导入一个需要传参的模块
2021/05/12 Python
SpringBoot实现异步事件驱动的方法
2021/06/28 Java/Android
SpringBoot整合RabbitMQ的5种模式实战
2021/08/02 Java/Android