python梯度下降算法的实现


Posted in Python onFebruary 24, 2020

本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下

简介

本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4

代码

'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys

# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)


class GradientDescent(object):
 eps = 1.0e-8
 max_iter = 1000000 # 暂时不需要
 dim = 1
 func_args = [2.1, 2.7] # [w_0, .., w_dim, b]

 def __init__(self, func_arg=None, N=1000):
 self.data_num = N
 if func_arg is not None:
 self.FuncArgs = func_arg
 self._getData()

 def _getData(self):
 x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
 b_1 = np.ones((self.data_num, 1), dtype=np.float)
 # x = np.concatenate((x, b_1), axis=1)
 self.x = np.concatenate((x, b_1), axis=1)

 def func(self, x):
 # noise太大的话, 梯度下降法失去作用
 noise = 0.01 * np.random.randn(self.data_num) + 0
 w = np.array(self.func_args)
 # y1 = w * self.x[0, ] # 直接相乘
 y = np.dot(self.x, w) # 矩阵乘法
 y += noise
 return y

 @property
 def FuncArgs(self):
 return self.func_args

 @FuncArgs.setter
 def FuncArgs(self, args):
 if not isinstance(args, list):
 raise Exception(
 'args is not list, it should be like [w_0, ..., w_dim, b]')
 if len(args) == 0:
 raise Exception('args is empty list!!')
 if len(args) == 1:
 args.append(0.0)
 self.func_args = args
 self.dim = len(args) - 1
 self._getData()

 @property
 def EPS(self):
 return self.eps

 @EPS.setter
 def EPS(self, value):
 if not isinstance(value, float) and not isinstance(value, int):
 raise Exception("The type of eps should be an float number")
 self.eps = value

 def plotFunc(self):
 # 一维画图
 if self.dim == 1:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 fig, ax = plt.subplots()
 ax.plot(x, y, 'o')
 ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
 ax.grid()
 plt.show()
 # 二维画图
 if self.dim == 2:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 xs = x[:, 0]
 ys = x[:, 1]
 zs = y
 fig = plt.figure()
 ax = fig.add_subplot(111, projection='3d')
 ax.scatter(xs, ys, zs, c='r', marker='o')

 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')
 ax.set_zlabel('Z Label')
 plt.show()
 else:
 # plt.axis('off')
 plt.text(
 0.5,
 0.5,
 "The dimension(x.dim > 2) \n is too high to draw",
 size=17,
 rotation=0.,
 ha="center",
 va="center",
 bbox=dict(
 boxstyle="round",
 ec=(1., 0.5, 0.5),
 fc=(1., 0.8, 0.8), ))
 plt.draw()
 plt.show()
 # print('The dimension(x.dim > 2) is too high to draw')

 # 梯度下降法只能求解凸函数
 def _gradient_descent(self, bs, lr, epoch):
 x = self.x
 # shuffle数据集没有必要
 # np.random.shuffle(x)
 y = self.func(x)
 w = np.ones((self.dim + 1, 1), dtype=float)
 for e in range(epoch):
 print('epoch:' + str(e), end=',')
 # 批量梯度下降,bs为1时 等价单样本梯度下降
 for i in range(0, self.data_num, bs):
 y_ = np.dot(x[i:i + bs], w)
 loss = y_ - y[i:i + bs].reshape(-1, 1)
 d = loss * x[i:i + bs]
 d = d.sum(axis=0) / bs
 d = lr * d
 d.shape = (-1, 1)
 w = w - d

 y_ = np.dot(self.x, w)
 loss_ = abs((y_ - y).sum())
 print('\tLoss = ' + str(loss_))
 print('拟合的结果为:', end=',')
 print(sum(w.tolist(), []))
 print()
 if loss_ < self.eps:
 print('The Gradient Descent algorithm has converged!!\n')
 break
 pass

 def __call__(self, bs=1, lr=0.1, epoch=10):
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')
 if not isinstance(bs, int) or not isinstance(epoch, int):
 raise Exception(
 "The type of BatchSize/Epoch should be an integer number")
 self._gradient_descent(bs, lr, epoch)
 pass

 pass


if __name__ == "__main__":
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')

 gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
 # gd = GradientDescent([1.2, 1.4, 2.1])
 print("要拟合的参数结果是: ")
 print(gd.FuncArgs)
 print("===================\n\n")
 # gd.EPS = 0.0
 gd.plotFunc()
 gd(10, 0.01)
 print("Finished!")

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python操作MySQL数据库9个实用实例
Dec 11 Python
Python搭建HTTP服务器和FTP服务器
Mar 09 Python
python+django+sql学生信息管理后台开发
Jan 11 Python
Python查找文件中包含中文的行方法
Dec 19 Python
使用Python制作简单的小程序IP查看器功能
Apr 16 Python
Python项目 基于Scapy实现SYN泛洪攻击的方法
Jul 23 Python
Python3 requests文件下载 期间显示文件信息和下载进度代码实例
Aug 16 Python
PyTorch使用cpu加载模型运算方式
Jan 13 Python
python 如何快速复制序列
Sep 07 Python
python opencv人脸识别考勤系统的完整源码
Apr 26 Python
PyTorch中permute的使用方法
Apr 26 Python
pytorch实现加载保存查看checkpoint文件
Jul 15 Python
利用python实现逐步回归
Feb 24 #Python
python数据分析:关键字提取方式
Feb 24 #Python
python数据预处理 :数据共线性处理详解
Feb 24 #Python
使用python实现多维数据降维操作
Feb 24 #Python
python数据预处理 :数据抽样解析
Feb 24 #Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 #Python
Python流程控制常用工具详解
Feb 24 #Python
You might like
php图片验证码代码
2008/03/27 PHP
php 解压rar文件及zip文件的方法
2014/05/05 PHP
PHP实现文件下载详解
2014/11/27 PHP
jQuery实现表头固定效果的实例代码
2013/05/24 Javascript
鼠标选择动态改变网页背景颜色的JS代码
2013/12/10 Javascript
JavaScript实现穷举排列(permutation)算法谜题解答
2014/12/29 Javascript
JavaScript插件化开发教程(五)
2015/02/01 Javascript
jQuery实现当前页面标签高亮显示的方法
2015/03/10 Javascript
对象题目的一个坑 理解Javascript对象
2015/12/22 Javascript
jquery获取所有选中的checkbox实现代码
2016/05/26 Javascript
jQuery简单实现tab选项卡切换效果
2016/06/20 Javascript
jstree创建无限分级树的方法【基于ajax动态创建子节点】
2016/10/25 Javascript
jquery延迟对象解析
2016/10/26 Javascript
Javascript 闭包详解及实例代码
2016/11/30 Javascript
JS实现物体带缓冲的间歇运动效果示例
2016/12/22 Javascript
详解javascript常用工具类的封装
2018/01/30 Javascript
如何利用ES6进行Promise封装总结
2019/02/11 Javascript
JavaScript鼠标拖拽事件详解
2020/04/03 Javascript
[30:00]完美世界DOTA2联赛PWL S2 Rebirth vs LBZS 第二场 11.28
2020/12/01 DOTA
python正则表达式去掉数字中的逗号(python正则匹配逗号)
2013/12/25 Python
Centos5.x下升级python到python2.7版本教程
2015/02/14 Python
在Python中使用成员运算符的示例
2015/05/13 Python
python实时监控cpu小工具
2018/06/21 Python
使用Python如何测试InnoDB与MyISAM的读写性能
2018/09/18 Python
浅谈python中频繁的print到底能浪费多长时间
2020/02/21 Python
英国网络托管和域名领导者:Web Hosting UK
2017/10/15 全球购物
英国123鲜花网站:123 Flowers
2019/07/07 全球购物
澳大利亚波希米亚风时尚品牌:Tree of Life
2019/09/15 全球购物
英国复古服装购物网站:Collectif
2019/10/30 全球购物
L’Artisan Parfumeur官网:法国香水品牌
2020/08/11 全球购物
《唯一的听众》教学反思
2014/02/20 职场文书
小学教师暑期培训心得体会
2016/01/09 职场文书
MySQL官方导出工具mysqlpump的使用
2021/05/21 MySQL
浅谈react useEffect闭包的坑
2021/06/08 Javascript
PHP使用QR Code生成二维码实例
2021/07/07 PHP
SQL语句多表联合查询的方法示例
2022/04/18 MySQL