python梯度下降算法的实现


Posted in Python onFebruary 24, 2020

本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下

简介

本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4

代码

'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys

# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)


class GradientDescent(object):
 eps = 1.0e-8
 max_iter = 1000000 # 暂时不需要
 dim = 1
 func_args = [2.1, 2.7] # [w_0, .., w_dim, b]

 def __init__(self, func_arg=None, N=1000):
 self.data_num = N
 if func_arg is not None:
 self.FuncArgs = func_arg
 self._getData()

 def _getData(self):
 x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
 b_1 = np.ones((self.data_num, 1), dtype=np.float)
 # x = np.concatenate((x, b_1), axis=1)
 self.x = np.concatenate((x, b_1), axis=1)

 def func(self, x):
 # noise太大的话, 梯度下降法失去作用
 noise = 0.01 * np.random.randn(self.data_num) + 0
 w = np.array(self.func_args)
 # y1 = w * self.x[0, ] # 直接相乘
 y = np.dot(self.x, w) # 矩阵乘法
 y += noise
 return y

 @property
 def FuncArgs(self):
 return self.func_args

 @FuncArgs.setter
 def FuncArgs(self, args):
 if not isinstance(args, list):
 raise Exception(
 'args is not list, it should be like [w_0, ..., w_dim, b]')
 if len(args) == 0:
 raise Exception('args is empty list!!')
 if len(args) == 1:
 args.append(0.0)
 self.func_args = args
 self.dim = len(args) - 1
 self._getData()

 @property
 def EPS(self):
 return self.eps

 @EPS.setter
 def EPS(self, value):
 if not isinstance(value, float) and not isinstance(value, int):
 raise Exception("The type of eps should be an float number")
 self.eps = value

 def plotFunc(self):
 # 一维画图
 if self.dim == 1:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 fig, ax = plt.subplots()
 ax.plot(x, y, 'o')
 ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
 ax.grid()
 plt.show()
 # 二维画图
 if self.dim == 2:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 xs = x[:, 0]
 ys = x[:, 1]
 zs = y
 fig = plt.figure()
 ax = fig.add_subplot(111, projection='3d')
 ax.scatter(xs, ys, zs, c='r', marker='o')

 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')
 ax.set_zlabel('Z Label')
 plt.show()
 else:
 # plt.axis('off')
 plt.text(
 0.5,
 0.5,
 "The dimension(x.dim > 2) \n is too high to draw",
 size=17,
 rotation=0.,
 ha="center",
 va="center",
 bbox=dict(
 boxstyle="round",
 ec=(1., 0.5, 0.5),
 fc=(1., 0.8, 0.8), ))
 plt.draw()
 plt.show()
 # print('The dimension(x.dim > 2) is too high to draw')

 # 梯度下降法只能求解凸函数
 def _gradient_descent(self, bs, lr, epoch):
 x = self.x
 # shuffle数据集没有必要
 # np.random.shuffle(x)
 y = self.func(x)
 w = np.ones((self.dim + 1, 1), dtype=float)
 for e in range(epoch):
 print('epoch:' + str(e), end=',')
 # 批量梯度下降,bs为1时 等价单样本梯度下降
 for i in range(0, self.data_num, bs):
 y_ = np.dot(x[i:i + bs], w)
 loss = y_ - y[i:i + bs].reshape(-1, 1)
 d = loss * x[i:i + bs]
 d = d.sum(axis=0) / bs
 d = lr * d
 d.shape = (-1, 1)
 w = w - d

 y_ = np.dot(self.x, w)
 loss_ = abs((y_ - y).sum())
 print('\tLoss = ' + str(loss_))
 print('拟合的结果为:', end=',')
 print(sum(w.tolist(), []))
 print()
 if loss_ < self.eps:
 print('The Gradient Descent algorithm has converged!!\n')
 break
 pass

 def __call__(self, bs=1, lr=0.1, epoch=10):
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')
 if not isinstance(bs, int) or not isinstance(epoch, int):
 raise Exception(
 "The type of BatchSize/Epoch should be an integer number")
 self._gradient_descent(bs, lr, epoch)
 pass

 pass


if __name__ == "__main__":
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')

 gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
 # gd = GradientDescent([1.2, 1.4, 2.1])
 print("要拟合的参数结果是: ")
 print(gd.FuncArgs)
 print("===================\n\n")
 # gd.EPS = 0.0
 gd.plotFunc()
 gd(10, 0.01)
 print("Finished!")

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
windows下python模拟鼠标点击和键盘输示例
Feb 28 Python
python计数排序和基数排序算法实例
Apr 25 Python
在Python中操作字符串之rstrip()方法的使用
May 19 Python
Python cookbook(数据结构与算法)将序列分解为单独变量的方法
Feb 13 Python
Django 重写用户模型的实现
Jul 29 Python
Python队列RabbitMQ 使用方法实例记录
Aug 05 Python
Django使用中间件解决前后端同源策略问题
Sep 02 Python
Python中的list与tuple集合区别解析
Oct 12 Python
Python文件操作基础流程解析
Mar 19 Python
带你学习Python如何实现回归树模型
Jul 16 Python
python使用Word2Vec进行情感分析解析
Jul 31 Python
Python爬取网站图片并保存的实现示例
Feb 26 Python
利用python实现逐步回归
Feb 24 #Python
python数据分析:关键字提取方式
Feb 24 #Python
python数据预处理 :数据共线性处理详解
Feb 24 #Python
使用python实现多维数据降维操作
Feb 24 #Python
python数据预处理 :数据抽样解析
Feb 24 #Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 #Python
Python流程控制常用工具详解
Feb 24 #Python
You might like
简单的php数据库操作类代码(增,删,改,查)
2013/04/08 PHP
解析php中die(),exit(),return的区别
2013/06/20 PHP
thinkPHP简单实现多个子查询语句的方法
2016/12/05 PHP
总结PHP内存释放以及垃圾回收
2018/03/29 PHP
Laravel框架使用Redis的方法详解
2018/05/30 PHP
CSS+JS构建的图片查看器
2006/07/22 Javascript
jQuery 技巧小结
2010/04/02 Javascript
JS脚本defer的作用示例介绍
2014/01/02 Javascript
js创建jsonArray传输至后台及后台全面解析
2016/04/11 Javascript
用JS动态设置CSS样式常见方法小结(推荐)
2016/11/10 Javascript
jQuery如何跳转到另一个网页 就这么简单
2016/12/28 Javascript
JS实现的自动打字效果示例
2017/03/10 Javascript
JS计算输出100元钱买100只鸡问题的解决方法
2018/01/04 Javascript
JS使用Date对象实时显示当前系统时间简单示例
2018/08/23 Javascript
解决vue路由后界面没有变化,但是链接有的问题
2018/09/01 Javascript
通过JS深度判断两个对象字段相同
2019/06/14 Javascript
vue的注意规范之v-if 与 v-for 一起使用教程
2019/08/04 Javascript
JavaScript实现省份城市的三级联动
2020/02/11 Javascript
深入分析JavaScript 事件循环(Event Loop)
2020/06/19 Javascript
js实现简单图片拖拽效果
2021/02/22 Javascript
[01:20]PWL S2开团时刻第三期——团战可以输 蝙蝠必须死
2020/11/26 DOTA
在Python中使用成员运算符的示例
2015/05/13 Python
selenium python浏览器多窗口处理代码示例
2018/01/15 Python
python 多线程中子线程和主线程相互通信方法
2018/11/09 Python
浅谈Python批处理文件夹中的txt文件
2019/03/11 Python
Python操作MySQL数据库的两种方式实例分析【pymysql和pandas】
2019/03/18 Python
Python利用imshow制作自定义渐变填充柱状图(colorbar)
2020/12/10 Python
详解python的变量缓存机制
2021/01/24 Python
Stefania Mode英国:奢华设计师和时尚服装
2017/10/23 全球购物
波兰最早的运动鞋精品店之一:Street Supply
2019/08/29 全球购物
小学老师寄语大全
2014/04/04 职场文书
重大事项社会稳定风险评估方案
2014/06/15 职场文书
旅行社优秀创业计划书
2014/08/16 职场文书
租房安全协议书
2014/08/20 职场文书
退休劳动合同怎么写?
2019/10/25 职场文书
pytest实现多进程与多线程运行超好用的插件
2022/07/15 Python