python梯度下降算法的实现


Posted in Python onFebruary 24, 2020

本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下

简介

本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4

代码

'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys

# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)


class GradientDescent(object):
 eps = 1.0e-8
 max_iter = 1000000 # 暂时不需要
 dim = 1
 func_args = [2.1, 2.7] # [w_0, .., w_dim, b]

 def __init__(self, func_arg=None, N=1000):
 self.data_num = N
 if func_arg is not None:
 self.FuncArgs = func_arg
 self._getData()

 def _getData(self):
 x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
 b_1 = np.ones((self.data_num, 1), dtype=np.float)
 # x = np.concatenate((x, b_1), axis=1)
 self.x = np.concatenate((x, b_1), axis=1)

 def func(self, x):
 # noise太大的话, 梯度下降法失去作用
 noise = 0.01 * np.random.randn(self.data_num) + 0
 w = np.array(self.func_args)
 # y1 = w * self.x[0, ] # 直接相乘
 y = np.dot(self.x, w) # 矩阵乘法
 y += noise
 return y

 @property
 def FuncArgs(self):
 return self.func_args

 @FuncArgs.setter
 def FuncArgs(self, args):
 if not isinstance(args, list):
 raise Exception(
 'args is not list, it should be like [w_0, ..., w_dim, b]')
 if len(args) == 0:
 raise Exception('args is empty list!!')
 if len(args) == 1:
 args.append(0.0)
 self.func_args = args
 self.dim = len(args) - 1
 self._getData()

 @property
 def EPS(self):
 return self.eps

 @EPS.setter
 def EPS(self, value):
 if not isinstance(value, float) and not isinstance(value, int):
 raise Exception("The type of eps should be an float number")
 self.eps = value

 def plotFunc(self):
 # 一维画图
 if self.dim == 1:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 fig, ax = plt.subplots()
 ax.plot(x, y, 'o')
 ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
 ax.grid()
 plt.show()
 # 二维画图
 if self.dim == 2:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 xs = x[:, 0]
 ys = x[:, 1]
 zs = y
 fig = plt.figure()
 ax = fig.add_subplot(111, projection='3d')
 ax.scatter(xs, ys, zs, c='r', marker='o')

 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')
 ax.set_zlabel('Z Label')
 plt.show()
 else:
 # plt.axis('off')
 plt.text(
 0.5,
 0.5,
 "The dimension(x.dim > 2) \n is too high to draw",
 size=17,
 rotation=0.,
 ha="center",
 va="center",
 bbox=dict(
 boxstyle="round",
 ec=(1., 0.5, 0.5),
 fc=(1., 0.8, 0.8), ))
 plt.draw()
 plt.show()
 # print('The dimension(x.dim > 2) is too high to draw')

 # 梯度下降法只能求解凸函数
 def _gradient_descent(self, bs, lr, epoch):
 x = self.x
 # shuffle数据集没有必要
 # np.random.shuffle(x)
 y = self.func(x)
 w = np.ones((self.dim + 1, 1), dtype=float)
 for e in range(epoch):
 print('epoch:' + str(e), end=',')
 # 批量梯度下降,bs为1时 等价单样本梯度下降
 for i in range(0, self.data_num, bs):
 y_ = np.dot(x[i:i + bs], w)
 loss = y_ - y[i:i + bs].reshape(-1, 1)
 d = loss * x[i:i + bs]
 d = d.sum(axis=0) / bs
 d = lr * d
 d.shape = (-1, 1)
 w = w - d

 y_ = np.dot(self.x, w)
 loss_ = abs((y_ - y).sum())
 print('\tLoss = ' + str(loss_))
 print('拟合的结果为:', end=',')
 print(sum(w.tolist(), []))
 print()
 if loss_ < self.eps:
 print('The Gradient Descent algorithm has converged!!\n')
 break
 pass

 def __call__(self, bs=1, lr=0.1, epoch=10):
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')
 if not isinstance(bs, int) or not isinstance(epoch, int):
 raise Exception(
 "The type of BatchSize/Epoch should be an integer number")
 self._gradient_descent(bs, lr, epoch)
 pass

 pass


if __name__ == "__main__":
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')

 gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
 # gd = GradientDescent([1.2, 1.4, 2.1])
 print("要拟合的参数结果是: ")
 print(gd.FuncArgs)
 print("===================\n\n")
 # gd.EPS = 0.0
 gd.plotFunc()
 gd(10, 0.01)
 print("Finished!")

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 命令行参数sys.argv
Sep 06 Python
在Python程序中操作文件之flush()方法的使用教程
May 24 Python
简单讲解Python编程中namedtuple类的用法
Jun 21 Python
Python最火、R极具潜力 2017机器学习调查报告
Dec 11 Python
Python3实现的画图及加载图片动画效果示例
Jan 19 Python
关于python下cv.waitKey无响应的原因及解决方法
Jan 10 Python
Python中常用的内置方法
Jan 28 Python
利用pyinstaller打包exe文件的基本教程
May 02 Python
使用PyInstaller将Pygame库编写的小游戏程序打包为exe文件及出现问题解决方法
Sep 06 Python
python爬虫模拟浏览器访问-User-Agent过程解析
Dec 28 Python
基于梯度爆炸的解决方法:clip gradient
Feb 04 Python
python实现无边框进度条的实例代码
Dec 30 Python
利用python实现逐步回归
Feb 24 #Python
python数据分析:关键字提取方式
Feb 24 #Python
python数据预处理 :数据共线性处理详解
Feb 24 #Python
使用python实现多维数据降维操作
Feb 24 #Python
python数据预处理 :数据抽样解析
Feb 24 #Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 #Python
Python流程控制常用工具详解
Feb 24 #Python
You might like
在PHP模板引擎smarty生成随机数的方法和math函数详解
2014/04/24 PHP
一键生成各种尺寸Icon的php脚本(实例)
2017/02/08 PHP
jQuery 树形结构的选择器
2010/02/15 Javascript
javascript对象之内置对象Math使用方法
2010/04/16 Javascript
javascript ready和load事件的区别示例介绍
2013/08/30 Javascript
js的正则test,match,exec详细解析
2014/01/29 Javascript
js获取当前页面的url网址信息
2014/06/12 Javascript
详谈JavaScript 匿名函数及闭包
2014/11/14 Javascript
js使用DOM操作实现简单留言板的方法
2015/04/10 Javascript
NodeJS整合银联网关支付(DEMO)
2016/11/09 NodeJs
BootStrap学习系列之布局组件(下拉,按钮组[toolbar],上拉)
2017/01/03 Javascript
js时间戳和c#时间戳互转方法(推荐)
2017/02/15 Javascript
jquery之基本选择器practice(实例讲解)
2017/09/30 jQuery
nuxt+axios解决前后端分离SSR的示例代码
2017/10/24 Javascript
vue中$refs的用法及作用详解
2018/04/24 Javascript
VuePress 静态网站生成方法步骤
2019/02/14 Javascript
解决antd Form 表单校验方法无响应的问题
2020/10/27 Javascript
Vue如何实现验证码输入交互
2020/12/07 Vue.js
[01:35:53]完美世界DOTA2联赛PWL S3 Magma vs GXR 第二场 12.13
2020/12/17 DOTA
python访问系统环境变量的方法
2015/04/29 Python
[原创]使用豆瓣提供的国内pypi源
2017/07/02 Python
Python实现线性判别分析(LDA)的MATLAB方式
2019/12/09 Python
基于python实现上传文件到OSS代码实例
2020/05/09 Python
Python正则re模块使用步骤及原理解析
2020/08/18 Python
python报错TypeError: ‘NoneType‘ object is not subscriptable的解决方法
2020/11/05 Python
python RSA加密的示例
2020/12/09 Python
俄罗斯街头服装品牌:Black Star Wear
2017/03/01 全球购物
简单说说tomcat的配置
2013/05/28 面试题
单位在职证明范本
2014/01/09 职场文书
实习生求职自荐信
2014/02/07 职场文书
低碳日宣传活动总结
2014/07/09 职场文书
售后服务承诺函格式
2015/01/21 职场文书
行政助理岗位职责
2015/02/10 职场文书
中学语文教学反思
2016/02/16 职场文书
MySQL Router的安装部署
2021/04/24 MySQL
微信小程序APP的生命周期及页面的生命周期
2022/04/19 Javascript