python梯度下降算法的实现


Posted in Python onFebruary 24, 2020

本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下

简介

本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4

代码

'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys

# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)


class GradientDescent(object):
 eps = 1.0e-8
 max_iter = 1000000 # 暂时不需要
 dim = 1
 func_args = [2.1, 2.7] # [w_0, .., w_dim, b]

 def __init__(self, func_arg=None, N=1000):
 self.data_num = N
 if func_arg is not None:
 self.FuncArgs = func_arg
 self._getData()

 def _getData(self):
 x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
 b_1 = np.ones((self.data_num, 1), dtype=np.float)
 # x = np.concatenate((x, b_1), axis=1)
 self.x = np.concatenate((x, b_1), axis=1)

 def func(self, x):
 # noise太大的话, 梯度下降法失去作用
 noise = 0.01 * np.random.randn(self.data_num) + 0
 w = np.array(self.func_args)
 # y1 = w * self.x[0, ] # 直接相乘
 y = np.dot(self.x, w) # 矩阵乘法
 y += noise
 return y

 @property
 def FuncArgs(self):
 return self.func_args

 @FuncArgs.setter
 def FuncArgs(self, args):
 if not isinstance(args, list):
 raise Exception(
 'args is not list, it should be like [w_0, ..., w_dim, b]')
 if len(args) == 0:
 raise Exception('args is empty list!!')
 if len(args) == 1:
 args.append(0.0)
 self.func_args = args
 self.dim = len(args) - 1
 self._getData()

 @property
 def EPS(self):
 return self.eps

 @EPS.setter
 def EPS(self, value):
 if not isinstance(value, float) and not isinstance(value, int):
 raise Exception("The type of eps should be an float number")
 self.eps = value

 def plotFunc(self):
 # 一维画图
 if self.dim == 1:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 fig, ax = plt.subplots()
 ax.plot(x, y, 'o')
 ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
 ax.grid()
 plt.show()
 # 二维画图
 if self.dim == 2:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 xs = x[:, 0]
 ys = x[:, 1]
 zs = y
 fig = plt.figure()
 ax = fig.add_subplot(111, projection='3d')
 ax.scatter(xs, ys, zs, c='r', marker='o')

 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')
 ax.set_zlabel('Z Label')
 plt.show()
 else:
 # plt.axis('off')
 plt.text(
 0.5,
 0.5,
 "The dimension(x.dim > 2) \n is too high to draw",
 size=17,
 rotation=0.,
 ha="center",
 va="center",
 bbox=dict(
 boxstyle="round",
 ec=(1., 0.5, 0.5),
 fc=(1., 0.8, 0.8), ))
 plt.draw()
 plt.show()
 # print('The dimension(x.dim > 2) is too high to draw')

 # 梯度下降法只能求解凸函数
 def _gradient_descent(self, bs, lr, epoch):
 x = self.x
 # shuffle数据集没有必要
 # np.random.shuffle(x)
 y = self.func(x)
 w = np.ones((self.dim + 1, 1), dtype=float)
 for e in range(epoch):
 print('epoch:' + str(e), end=',')
 # 批量梯度下降,bs为1时 等价单样本梯度下降
 for i in range(0, self.data_num, bs):
 y_ = np.dot(x[i:i + bs], w)
 loss = y_ - y[i:i + bs].reshape(-1, 1)
 d = loss * x[i:i + bs]
 d = d.sum(axis=0) / bs
 d = lr * d
 d.shape = (-1, 1)
 w = w - d

 y_ = np.dot(self.x, w)
 loss_ = abs((y_ - y).sum())
 print('\tLoss = ' + str(loss_))
 print('拟合的结果为:', end=',')
 print(sum(w.tolist(), []))
 print()
 if loss_ < self.eps:
 print('The Gradient Descent algorithm has converged!!\n')
 break
 pass

 def __call__(self, bs=1, lr=0.1, epoch=10):
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')
 if not isinstance(bs, int) or not isinstance(epoch, int):
 raise Exception(
 "The type of BatchSize/Epoch should be an integer number")
 self._gradient_descent(bs, lr, epoch)
 pass

 pass


if __name__ == "__main__":
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')

 gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
 # gd = GradientDescent([1.2, 1.4, 2.1])
 print("要拟合的参数结果是: ")
 print(gd.FuncArgs)
 print("===================\n\n")
 # gd.EPS = 0.0
 gd.plotFunc()
 gd(10, 0.01)
 print("Finished!")

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python插入排序算法的实现代码
Nov 21 Python
Django自定义manage命令实例代码
Feb 11 Python
python如何压缩新文件到已有ZIP文件
Mar 14 Python
Python3爬虫学习之MySQL数据库存储爬取的信息详解
Dec 12 Python
Python 利用切片从列表中取出一部分使用的方法
Feb 01 Python
Pyqt5 基本界面组件之inputDialog的使用
Jun 25 Python
Python3批量移动指定文件到指定文件夹方法示例
Sep 02 Python
Python获取时间戳代码实例
Sep 24 Python
解决python 读取excel时 日期变成数字并加.0的问题
Oct 08 Python
python 类之间的参数传递方式
Dec 20 Python
python 工具 字符串转numpy浮点数组的实现
Mar 14 Python
python中os.path.join()函数实例用法
May 26 Python
利用python实现逐步回归
Feb 24 #Python
python数据分析:关键字提取方式
Feb 24 #Python
python数据预处理 :数据共线性处理详解
Feb 24 #Python
使用python实现多维数据降维操作
Feb 24 #Python
python数据预处理 :数据抽样解析
Feb 24 #Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 #Python
Python流程控制常用工具详解
Feb 24 #Python
You might like
浅谈php数组array_change_key_case() 函数和array_chunk()函数
2016/10/22 PHP
php菜单/评论数据递归分级算法的实现方法
2019/08/01 PHP
PHP常用函数之根据生日计算年龄功能示例
2019/10/21 PHP
用ASP将SQL搜索出来的内容导出为TXT的代码
2007/07/27 Javascript
小型js框架veryide.librar源代码
2009/03/05 Javascript
使用js+jquery实现无限极联动
2013/05/23 Javascript
Jquery操作radio的简单实例
2014/01/06 Javascript
javascript版的in_array函数(判断数组中是否存在特定值)
2014/05/09 Javascript
JQuery设置获取下拉菜单某个选项的值(比较全)
2014/08/05 Javascript
Javascript 创建类并动态添加属性及方法的简单实现
2016/10/20 Javascript
Bootstrap表单制作代码
2017/03/17 Javascript
在Vue中使用echarts的方法
2018/02/05 Javascript
KOA+egg.js集成kafka消息队列的示例
2018/11/09 Javascript
详解vue中v-model和v-bind绑定数据的异同
2020/08/10 Javascript
jquery实现图片放大镜效果
2020/12/23 jQuery
[06:07]DOTA2-DPC中国联赛3月5日Recap集锦
2021/03/11 DOTA
用Python实现一个简单的多线程TCP服务器的教程
2015/05/05 Python
分析Python中设计模式之Decorator装饰器模式的要点
2016/03/02 Python
python分布式环境下的限流器的示例
2017/10/26 Python
python爬虫使用cookie登录详解
2017/12/27 Python
python生成ppt的方法
2018/06/07 Python
pycharm创建一个python包方法图解
2019/04/10 Python
使用Python实现 学生学籍管理系统
2019/11/26 Python
Python 实现黑客帝国中的字符雨的示例代码
2020/02/20 Python
python virtualenv虚拟环境配置与使用教程详解
2020/07/13 Python
英国领先的电动可调床制造商:Laybrook
2019/12/26 全球购物
伊琍体标语
2014/06/25 职场文书
大学生社会实践活动总结
2014/07/03 职场文书
商场促销活动策划方案
2014/08/18 职场文书
小学生推普周国旗下讲话稿
2014/09/21 职场文书
教师党员自我剖析材料
2014/09/29 职场文书
2015自愿离婚协议书范本
2015/01/28 职场文书
《我的长生果》教学反思
2016/02/20 职场文书
2019通用版劳动合同范本!
2019/07/11 职场文书
nginx配置ssl实现https的方法示例
2021/03/31 Servers
JDBC连接的六步实例代码(与mysql连接)
2021/05/12 MySQL