python梯度下降算法的实现


Posted in Python onFebruary 24, 2020

本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下

简介

本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4

代码

'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys

# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)


class GradientDescent(object):
 eps = 1.0e-8
 max_iter = 1000000 # 暂时不需要
 dim = 1
 func_args = [2.1, 2.7] # [w_0, .., w_dim, b]

 def __init__(self, func_arg=None, N=1000):
 self.data_num = N
 if func_arg is not None:
 self.FuncArgs = func_arg
 self._getData()

 def _getData(self):
 x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
 b_1 = np.ones((self.data_num, 1), dtype=np.float)
 # x = np.concatenate((x, b_1), axis=1)
 self.x = np.concatenate((x, b_1), axis=1)

 def func(self, x):
 # noise太大的话, 梯度下降法失去作用
 noise = 0.01 * np.random.randn(self.data_num) + 0
 w = np.array(self.func_args)
 # y1 = w * self.x[0, ] # 直接相乘
 y = np.dot(self.x, w) # 矩阵乘法
 y += noise
 return y

 @property
 def FuncArgs(self):
 return self.func_args

 @FuncArgs.setter
 def FuncArgs(self, args):
 if not isinstance(args, list):
 raise Exception(
 'args is not list, it should be like [w_0, ..., w_dim, b]')
 if len(args) == 0:
 raise Exception('args is empty list!!')
 if len(args) == 1:
 args.append(0.0)
 self.func_args = args
 self.dim = len(args) - 1
 self._getData()

 @property
 def EPS(self):
 return self.eps

 @EPS.setter
 def EPS(self, value):
 if not isinstance(value, float) and not isinstance(value, int):
 raise Exception("The type of eps should be an float number")
 self.eps = value

 def plotFunc(self):
 # 一维画图
 if self.dim == 1:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 fig, ax = plt.subplots()
 ax.plot(x, y, 'o')
 ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
 ax.grid()
 plt.show()
 # 二维画图
 if self.dim == 2:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 xs = x[:, 0]
 ys = x[:, 1]
 zs = y
 fig = plt.figure()
 ax = fig.add_subplot(111, projection='3d')
 ax.scatter(xs, ys, zs, c='r', marker='o')

 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')
 ax.set_zlabel('Z Label')
 plt.show()
 else:
 # plt.axis('off')
 plt.text(
 0.5,
 0.5,
 "The dimension(x.dim > 2) \n is too high to draw",
 size=17,
 rotation=0.,
 ha="center",
 va="center",
 bbox=dict(
 boxstyle="round",
 ec=(1., 0.5, 0.5),
 fc=(1., 0.8, 0.8), ))
 plt.draw()
 plt.show()
 # print('The dimension(x.dim > 2) is too high to draw')

 # 梯度下降法只能求解凸函数
 def _gradient_descent(self, bs, lr, epoch):
 x = self.x
 # shuffle数据集没有必要
 # np.random.shuffle(x)
 y = self.func(x)
 w = np.ones((self.dim + 1, 1), dtype=float)
 for e in range(epoch):
 print('epoch:' + str(e), end=',')
 # 批量梯度下降,bs为1时 等价单样本梯度下降
 for i in range(0, self.data_num, bs):
 y_ = np.dot(x[i:i + bs], w)
 loss = y_ - y[i:i + bs].reshape(-1, 1)
 d = loss * x[i:i + bs]
 d = d.sum(axis=0) / bs
 d = lr * d
 d.shape = (-1, 1)
 w = w - d

 y_ = np.dot(self.x, w)
 loss_ = abs((y_ - y).sum())
 print('\tLoss = ' + str(loss_))
 print('拟合的结果为:', end=',')
 print(sum(w.tolist(), []))
 print()
 if loss_ < self.eps:
 print('The Gradient Descent algorithm has converged!!\n')
 break
 pass

 def __call__(self, bs=1, lr=0.1, epoch=10):
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')
 if not isinstance(bs, int) or not isinstance(epoch, int):
 raise Exception(
 "The type of BatchSize/Epoch should be an integer number")
 self._gradient_descent(bs, lr, epoch)
 pass

 pass


if __name__ == "__main__":
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')

 gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
 # gd = GradientDescent([1.2, 1.4, 2.1])
 print("要拟合的参数结果是: ")
 print(gd.FuncArgs)
 print("===================\n\n")
 # gd.EPS = 0.0
 gd.plotFunc()
 gd(10, 0.01)
 print("Finished!")

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之集合(set)
Sep 24 Python
python Matplotlib画图之调整字体大小的示例
Nov 20 Python
十分钟利用Python制作属于你自己的个性logo
May 07 Python
Python嵌套列表转一维的方法(压平嵌套列表)
Jul 03 Python
详解django.contirb.auth-认证
Jul 16 Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 Python
python matplotlib包图像配色方案分享
Mar 14 Python
keras分类模型中的输入数据与标签的维度实例
Jul 03 Python
pycharm全局搜索的具体步骤
Jul 28 Python
聊聊python中的异常嵌套
Sep 01 Python
详解Python遍历列表时删除元素的正确做法
Jan 07 Python
python超详细实现完整学生成绩管理系统
Mar 17 Python
利用python实现逐步回归
Feb 24 #Python
python数据分析:关键字提取方式
Feb 24 #Python
python数据预处理 :数据共线性处理详解
Feb 24 #Python
使用python实现多维数据降维操作
Feb 24 #Python
python数据预处理 :数据抽样解析
Feb 24 #Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 #Python
Python流程控制常用工具详解
Feb 24 #Python
You might like
PHP使用GIFEncoder类生成gif动态滚动字幕
2014/07/01 PHP
PHP实现动态执行代码的方法
2016/03/25 PHP
javascript 24小时弹出一次的代码(利用cookies)
2009/09/03 Javascript
javascript下arguments,caller,callee,call,apply示例及理解
2009/12/24 Javascript
javascript面向对象之定义成员方法实例分析
2015/01/13 Javascript
JavaScript DOM进阶方法
2015/04/13 Javascript
使用AngularJS实现可伸缩的页面切换的方法
2015/06/19 Javascript
修改jquery中dialog的title属性方法(推荐)
2016/08/26 Javascript
利用forever和pm2部署node.js项目过程
2017/05/10 Javascript
实例分析js事件循环机制
2017/12/13 Javascript
JS实现的抛物线运动效果示例
2018/01/30 Javascript
JavaScript中的事件与异常捕获详析
2019/02/24 Javascript
微信小程序中target和currentTarget的区别小结
2020/11/06 Javascript
[31:00]2014 DOTA2华西杯精英邀请赛5 24 NewBee VS iG
2014/05/25 DOTA
python中sets模块的用法实例
2014/09/30 Python
Python 自动补全(vim)
2014/11/30 Python
Python中用format函数格式化字符串的用法
2015/04/08 Python
Python正则表达式完全指南
2017/05/25 Python
python实现判断一个字符串是否是合法IP地址的示例
2018/06/04 Python
vue.js实现输入框输入值内容实时响应变化示例
2018/07/07 Python
Python Django给admin添加Action的方法实例详解
2019/04/29 Python
基于Keras中Conv1D和Conv2D的区别说明
2020/06/19 Python
Python docutils文档编译过程方法解析
2020/06/23 Python
玛蒂尔达简服装:Matilda Jane Clothing
2019/02/13 全球购物
入党思想汇报
2014/01/05 职场文书
社区工作者先进事迹
2014/01/18 职场文书
ktv总经理岗位职责
2014/02/17 职场文书
心理健康活动总结
2014/04/30 职场文书
2014年教师党员自我评议
2014/09/19 职场文书
2015年挂职锻炼工作总结
2014/12/12 职场文书
幼儿园父亲节活动总结
2015/02/12 职场文书
穆斯林的葬礼读书笔记
2015/06/26 职场文书
Python基于Tkinter开发一个爬取B站直播弹幕的工具
2021/05/06 Python
Redisson实现Redis分布式锁的几种方式
2021/08/07 Redis
详细介绍MySQL中limit和offset的用法
2022/05/06 MySQL
python playwrigh框架入门安装使用
2022/07/23 Python