python梯度下降算法的实现


Posted in Python onFebruary 24, 2020

本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下

简介

本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4

代码

'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys

# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)


class GradientDescent(object):
 eps = 1.0e-8
 max_iter = 1000000 # 暂时不需要
 dim = 1
 func_args = [2.1, 2.7] # [w_0, .., w_dim, b]

 def __init__(self, func_arg=None, N=1000):
 self.data_num = N
 if func_arg is not None:
 self.FuncArgs = func_arg
 self._getData()

 def _getData(self):
 x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
 b_1 = np.ones((self.data_num, 1), dtype=np.float)
 # x = np.concatenate((x, b_1), axis=1)
 self.x = np.concatenate((x, b_1), axis=1)

 def func(self, x):
 # noise太大的话, 梯度下降法失去作用
 noise = 0.01 * np.random.randn(self.data_num) + 0
 w = np.array(self.func_args)
 # y1 = w * self.x[0, ] # 直接相乘
 y = np.dot(self.x, w) # 矩阵乘法
 y += noise
 return y

 @property
 def FuncArgs(self):
 return self.func_args

 @FuncArgs.setter
 def FuncArgs(self, args):
 if not isinstance(args, list):
 raise Exception(
 'args is not list, it should be like [w_0, ..., w_dim, b]')
 if len(args) == 0:
 raise Exception('args is empty list!!')
 if len(args) == 1:
 args.append(0.0)
 self.func_args = args
 self.dim = len(args) - 1
 self._getData()

 @property
 def EPS(self):
 return self.eps

 @EPS.setter
 def EPS(self, value):
 if not isinstance(value, float) and not isinstance(value, int):
 raise Exception("The type of eps should be an float number")
 self.eps = value

 def plotFunc(self):
 # 一维画图
 if self.dim == 1:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 fig, ax = plt.subplots()
 ax.plot(x, y, 'o')
 ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
 ax.grid()
 plt.show()
 # 二维画图
 if self.dim == 2:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 xs = x[:, 0]
 ys = x[:, 1]
 zs = y
 fig = plt.figure()
 ax = fig.add_subplot(111, projection='3d')
 ax.scatter(xs, ys, zs, c='r', marker='o')

 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')
 ax.set_zlabel('Z Label')
 plt.show()
 else:
 # plt.axis('off')
 plt.text(
 0.5,
 0.5,
 "The dimension(x.dim > 2) \n is too high to draw",
 size=17,
 rotation=0.,
 ha="center",
 va="center",
 bbox=dict(
 boxstyle="round",
 ec=(1., 0.5, 0.5),
 fc=(1., 0.8, 0.8), ))
 plt.draw()
 plt.show()
 # print('The dimension(x.dim > 2) is too high to draw')

 # 梯度下降法只能求解凸函数
 def _gradient_descent(self, bs, lr, epoch):
 x = self.x
 # shuffle数据集没有必要
 # np.random.shuffle(x)
 y = self.func(x)
 w = np.ones((self.dim + 1, 1), dtype=float)
 for e in range(epoch):
 print('epoch:' + str(e), end=',')
 # 批量梯度下降,bs为1时 等价单样本梯度下降
 for i in range(0, self.data_num, bs):
 y_ = np.dot(x[i:i + bs], w)
 loss = y_ - y[i:i + bs].reshape(-1, 1)
 d = loss * x[i:i + bs]
 d = d.sum(axis=0) / bs
 d = lr * d
 d.shape = (-1, 1)
 w = w - d

 y_ = np.dot(self.x, w)
 loss_ = abs((y_ - y).sum())
 print('\tLoss = ' + str(loss_))
 print('拟合的结果为:', end=',')
 print(sum(w.tolist(), []))
 print()
 if loss_ < self.eps:
 print('The Gradient Descent algorithm has converged!!\n')
 break
 pass

 def __call__(self, bs=1, lr=0.1, epoch=10):
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')
 if not isinstance(bs, int) or not isinstance(epoch, int):
 raise Exception(
 "The type of BatchSize/Epoch should be an integer number")
 self._gradient_descent(bs, lr, epoch)
 pass

 pass


if __name__ == "__main__":
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')

 gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
 # gd = GradientDescent([1.2, 1.4, 2.1])
 print("要拟合的参数结果是: ")
 print(gd.FuncArgs)
 print("===================\n\n")
 # gd.EPS = 0.0
 gd.plotFunc()
 gd(10, 0.01)
 print("Finished!")

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python的id()函数解密过程
Dec 25 Python
python实现在sqlite动态创建表的方法
May 08 Python
使用Python来编写HTTP服务器的超级指南
Feb 18 Python
简单讲解Python中的字符串与字符串的输入输出
Mar 13 Python
Python递归函数定义与用法示例
Jun 02 Python
Python正则表达式常用函数总结
Jun 24 Python
python实现log日志的示例代码
Apr 28 Python
Python3.5 处理文本txt,删除不需要的行方法
Dec 10 Python
python实现简易学生信息管理系统
Apr 05 Python
完美解决keras 读取多个hdf5文件进行训练的问题
Jul 01 Python
详解基于Facecognition+Opencv快速搭建人脸识别及跟踪应用
Jan 21 Python
python ConfigParser库的使用及遇到的坑
Feb 12 Python
利用python实现逐步回归
Feb 24 #Python
python数据分析:关键字提取方式
Feb 24 #Python
python数据预处理 :数据共线性处理详解
Feb 24 #Python
使用python实现多维数据降维操作
Feb 24 #Python
python数据预处理 :数据抽样解析
Feb 24 #Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 #Python
Python流程控制常用工具详解
Feb 24 #Python
You might like
分页显示Oracle数据库记录的类之二
2006/10/09 PHP
php zlib压缩和解压缩swf文件的代码
2008/12/30 PHP
php学习之数据类型之间的转换代码
2011/05/29 PHP
PHP中获取文件扩展名的N种方法小结
2012/02/27 PHP
php注册审核重点解析(数据访问)
2017/05/23 PHP
javascript同步Import,同步调用外部js的方法
2008/07/08 Javascript
Javascript 设计模式(二) 闭包
2010/05/26 Javascript
ajax异步刷新实现更新数据库
2012/12/03 Javascript
JS+JSP checkBox 全选具体实现
2014/01/02 Javascript
href下载文件根据id取url并下载
2014/05/28 Javascript
jquery性能优化高级技巧
2015/08/24 Javascript
使用jQuery Ajax 请求webservice来实现更简练的Ajax
2016/08/04 Javascript
BootStrap下拉框在firefox浏览器界面不友好的解决方案
2016/08/18 Javascript
easyui-edatagrid.js实现回车键结束编辑功能的实例
2017/04/12 Javascript
jQuery实现的页面遮罩层功能示例【测试可用】
2017/10/14 jQuery
微信小程序上传图片实例
2018/05/28 Javascript
微信小程序使用gitee进行版本管理
2018/09/20 Javascript
Vue注册组件命名时不能用大写的原因浅析
2019/04/25 Javascript
[01:03:41]完美世界DOTA2联赛PWL S3 DLG vs Phoenix 第一场 12.17
2020/12/19 DOTA
手动实现把python项目发布为exe可执行程序过程分享
2014/10/23 Python
python通过get,post方式发送http请求和接收http响应的方法
2015/05/26 Python
Python中的深拷贝和浅拷贝详解
2015/06/03 Python
分析用Python脚本关闭文件操作的机制
2015/06/28 Python
python机器学习案例教程——K最近邻算法的实现
2017/12/28 Python
python 申请内存空间,用于创建多维数组的实例
2019/12/02 Python
Python关于反射的实例代码分享
2020/02/20 Python
Pandas的Apply函数具体使用
2020/07/21 Python
详解基于python的全局与局部序列比对的实现(DNA)
2020/10/07 Python
纯CSS实现预加载动画效果
2017/09/06 HTML / CSS
Meli Melo官网:名媛们钟爱的英国奢侈手包品牌
2017/04/17 全球购物
Doyoueven官网:澳大利亚健身服饰和配饰品牌
2019/03/24 全球购物
StubHub希腊:购买体育赛事、音乐会和剧院门票
2019/08/03 全球购物
执行力心得体会范文
2016/01/11 职场文书
php TP5框架生成二维码链接
2021/04/01 PHP
python使用tkinter实现透明窗体上绘制随机出现的小球(实例代码)
2021/05/17 Python
Python数据可视化之用Matplotlib绘制常用图形
2021/06/03 Python