python梯度下降算法的实现


Posted in Python onFebruary 24, 2020

本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下

简介

本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4

代码

'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys

# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)


class GradientDescent(object):
 eps = 1.0e-8
 max_iter = 1000000 # 暂时不需要
 dim = 1
 func_args = [2.1, 2.7] # [w_0, .., w_dim, b]

 def __init__(self, func_arg=None, N=1000):
 self.data_num = N
 if func_arg is not None:
 self.FuncArgs = func_arg
 self._getData()

 def _getData(self):
 x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
 b_1 = np.ones((self.data_num, 1), dtype=np.float)
 # x = np.concatenate((x, b_1), axis=1)
 self.x = np.concatenate((x, b_1), axis=1)

 def func(self, x):
 # noise太大的话, 梯度下降法失去作用
 noise = 0.01 * np.random.randn(self.data_num) + 0
 w = np.array(self.func_args)
 # y1 = w * self.x[0, ] # 直接相乘
 y = np.dot(self.x, w) # 矩阵乘法
 y += noise
 return y

 @property
 def FuncArgs(self):
 return self.func_args

 @FuncArgs.setter
 def FuncArgs(self, args):
 if not isinstance(args, list):
 raise Exception(
 'args is not list, it should be like [w_0, ..., w_dim, b]')
 if len(args) == 0:
 raise Exception('args is empty list!!')
 if len(args) == 1:
 args.append(0.0)
 self.func_args = args
 self.dim = len(args) - 1
 self._getData()

 @property
 def EPS(self):
 return self.eps

 @EPS.setter
 def EPS(self, value):
 if not isinstance(value, float) and not isinstance(value, int):
 raise Exception("The type of eps should be an float number")
 self.eps = value

 def plotFunc(self):
 # 一维画图
 if self.dim == 1:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 fig, ax = plt.subplots()
 ax.plot(x, y, 'o')
 ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
 ax.grid()
 plt.show()
 # 二维画图
 if self.dim == 2:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 xs = x[:, 0]
 ys = x[:, 1]
 zs = y
 fig = plt.figure()
 ax = fig.add_subplot(111, projection='3d')
 ax.scatter(xs, ys, zs, c='r', marker='o')

 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')
 ax.set_zlabel('Z Label')
 plt.show()
 else:
 # plt.axis('off')
 plt.text(
 0.5,
 0.5,
 "The dimension(x.dim > 2) \n is too high to draw",
 size=17,
 rotation=0.,
 ha="center",
 va="center",
 bbox=dict(
 boxstyle="round",
 ec=(1., 0.5, 0.5),
 fc=(1., 0.8, 0.8), ))
 plt.draw()
 plt.show()
 # print('The dimension(x.dim > 2) is too high to draw')

 # 梯度下降法只能求解凸函数
 def _gradient_descent(self, bs, lr, epoch):
 x = self.x
 # shuffle数据集没有必要
 # np.random.shuffle(x)
 y = self.func(x)
 w = np.ones((self.dim + 1, 1), dtype=float)
 for e in range(epoch):
 print('epoch:' + str(e), end=',')
 # 批量梯度下降,bs为1时 等价单样本梯度下降
 for i in range(0, self.data_num, bs):
 y_ = np.dot(x[i:i + bs], w)
 loss = y_ - y[i:i + bs].reshape(-1, 1)
 d = loss * x[i:i + bs]
 d = d.sum(axis=0) / bs
 d = lr * d
 d.shape = (-1, 1)
 w = w - d

 y_ = np.dot(self.x, w)
 loss_ = abs((y_ - y).sum())
 print('\tLoss = ' + str(loss_))
 print('拟合的结果为:', end=',')
 print(sum(w.tolist(), []))
 print()
 if loss_ < self.eps:
 print('The Gradient Descent algorithm has converged!!\n')
 break
 pass

 def __call__(self, bs=1, lr=0.1, epoch=10):
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')
 if not isinstance(bs, int) or not isinstance(epoch, int):
 raise Exception(
 "The type of BatchSize/Epoch should be an integer number")
 self._gradient_descent(bs, lr, epoch)
 pass

 pass


if __name__ == "__main__":
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')

 gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
 # gd = GradientDescent([1.2, 1.4, 2.1])
 print("要拟合的参数结果是: ")
 print(gd.FuncArgs)
 print("===================\n\n")
 # gd.EPS = 0.0
 gd.plotFunc()
 gd(10, 0.01)
 print("Finished!")

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中列表元素连接方法join用法实例
Apr 07 Python
Python使用django获取用户IP地址的方法
May 11 Python
浅谈flask截获所有访问及before/after_request修饰器
Jan 18 Python
django与小程序实现登录验证功能的示例代码
Feb 19 Python
python获取Linux发行版名称
Aug 30 Python
Python爬取知乎图片代码实现解析
Sep 17 Python
Python换行与不换行的输出实例
Feb 19 Python
Django choices下拉列表绑定实例
Mar 13 Python
python3.8.3安装教程及环境配置的详细教程(64-bit)
Nov 28 Python
python中yield的用法详解
Jan 13 Python
解决PDF 转图片时丢文字的一种可能方式
Mar 04 Python
Python使用PyYAML库读写yaml文件的方法
Apr 06 Python
利用python实现逐步回归
Feb 24 #Python
python数据分析:关键字提取方式
Feb 24 #Python
python数据预处理 :数据共线性处理详解
Feb 24 #Python
使用python实现多维数据降维操作
Feb 24 #Python
python数据预处理 :数据抽样解析
Feb 24 #Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 #Python
Python流程控制常用工具详解
Feb 24 #Python
You might like
收音机术语解释
2021/03/01 无线电
PHP实现加密的几种方式介绍
2015/02/22 PHP
简单了解PHP编程中数组的指针的使用
2015/11/30 PHP
php in_array() 检查数组中是否存在某个值详解
2016/11/23 PHP
PHP模版引擎原理、定义与用法实例
2019/03/29 PHP
不间断滚动JS打包类,基本可以实现所有的滚动效果,太强了
2007/12/08 Javascript
复制Input内容的js代码_支持所有浏览器,修正了Firefox3.5以上的问题
2010/06/21 Javascript
javascript获取函数名称、函数参数、对象属性名称的代码实例
2014/04/12 Javascript
jQuery截取指定长度字符串代码
2014/08/21 Javascript
jQuery实现鼠标滚轮动态改变样式或效果
2015/01/05 Javascript
js实现跨域的几种方法汇总(图片ping、JSONP和CORS)
2015/10/25 Javascript
通过学习bootstrop导航条学会修改bootstrop颜色基调
2017/06/11 Javascript
EasyUI在Panel上动态添加LinkButton按钮
2017/08/11 Javascript
web3.js增加eth.getRawTransactionByHash(txhash)方法步骤
2018/03/15 Javascript
webpack多入口多出口的实现方法
2018/08/17 Javascript
jdk1.8+vue elementui实现多级菜单功能
2020/09/24 Javascript
[12:29]《一刀刀一天》之DOTA全时刻19:蝙蝠骑士田伯光再度不举
2014/06/10 DOTA
[01:03]DOTA2新的征程 你的脚印值得踏上
2014/08/13 DOTA
Python中利用xpath解析HTML的方法
2018/05/14 Python
Python实现数据结构线性链表(单链表)算法示例
2019/05/04 Python
Pycharm新手教程(只需要看这篇就够了)
2019/06/18 Python
使用virtualenv创建Python环境及PyQT5环境配置的方法
2019/09/10 Python
python多进程 主进程和子进程间共享和不共享全局变量实例
2020/04/25 Python
Anaconda3中的Jupyter notebook添加目录插件的实现
2020/05/18 Python
Python用requests库爬取返回为空的解决办法
2021/02/21 Python
值类型与引用类型有什么不同?请举例说明?并分别列举几种相应的数据类型
2015/10/24 面试题
安全标准化实施方案
2014/02/20 职场文书
学前班评语大全
2014/05/04 职场文书
农村党支部书记四风问题个人对照检查材料
2014/09/21 职场文书
2014年实习期工作总结
2014/11/27 职场文书
2015年个人实习工作总结
2014/12/12 职场文书
党员示范岗材料
2014/12/19 职场文书
嘉宾邀请函
2015/01/31 职场文书
冲出亚马逊观后感
2015/06/03 职场文书
python实现股票历史数据可视化分析案例
2021/06/10 Python
Mysql调整优化之四种分区方式以及组合分区
2022/04/13 MySQL