python梯度下降算法的实现


Posted in Python onFebruary 24, 2020

本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下

简介

本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4

代码

'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys

# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)


class GradientDescent(object):
 eps = 1.0e-8
 max_iter = 1000000 # 暂时不需要
 dim = 1
 func_args = [2.1, 2.7] # [w_0, .., w_dim, b]

 def __init__(self, func_arg=None, N=1000):
 self.data_num = N
 if func_arg is not None:
 self.FuncArgs = func_arg
 self._getData()

 def _getData(self):
 x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
 b_1 = np.ones((self.data_num, 1), dtype=np.float)
 # x = np.concatenate((x, b_1), axis=1)
 self.x = np.concatenate((x, b_1), axis=1)

 def func(self, x):
 # noise太大的话, 梯度下降法失去作用
 noise = 0.01 * np.random.randn(self.data_num) + 0
 w = np.array(self.func_args)
 # y1 = w * self.x[0, ] # 直接相乘
 y = np.dot(self.x, w) # 矩阵乘法
 y += noise
 return y

 @property
 def FuncArgs(self):
 return self.func_args

 @FuncArgs.setter
 def FuncArgs(self, args):
 if not isinstance(args, list):
 raise Exception(
 'args is not list, it should be like [w_0, ..., w_dim, b]')
 if len(args) == 0:
 raise Exception('args is empty list!!')
 if len(args) == 1:
 args.append(0.0)
 self.func_args = args
 self.dim = len(args) - 1
 self._getData()

 @property
 def EPS(self):
 return self.eps

 @EPS.setter
 def EPS(self, value):
 if not isinstance(value, float) and not isinstance(value, int):
 raise Exception("The type of eps should be an float number")
 self.eps = value

 def plotFunc(self):
 # 一维画图
 if self.dim == 1:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 fig, ax = plt.subplots()
 ax.plot(x, y, 'o')
 ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
 ax.grid()
 plt.show()
 # 二维画图
 if self.dim == 2:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 xs = x[:, 0]
 ys = x[:, 1]
 zs = y
 fig = plt.figure()
 ax = fig.add_subplot(111, projection='3d')
 ax.scatter(xs, ys, zs, c='r', marker='o')

 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')
 ax.set_zlabel('Z Label')
 plt.show()
 else:
 # plt.axis('off')
 plt.text(
 0.5,
 0.5,
 "The dimension(x.dim > 2) \n is too high to draw",
 size=17,
 rotation=0.,
 ha="center",
 va="center",
 bbox=dict(
 boxstyle="round",
 ec=(1., 0.5, 0.5),
 fc=(1., 0.8, 0.8), ))
 plt.draw()
 plt.show()
 # print('The dimension(x.dim > 2) is too high to draw')

 # 梯度下降法只能求解凸函数
 def _gradient_descent(self, bs, lr, epoch):
 x = self.x
 # shuffle数据集没有必要
 # np.random.shuffle(x)
 y = self.func(x)
 w = np.ones((self.dim + 1, 1), dtype=float)
 for e in range(epoch):
 print('epoch:' + str(e), end=',')
 # 批量梯度下降,bs为1时 等价单样本梯度下降
 for i in range(0, self.data_num, bs):
 y_ = np.dot(x[i:i + bs], w)
 loss = y_ - y[i:i + bs].reshape(-1, 1)
 d = loss * x[i:i + bs]
 d = d.sum(axis=0) / bs
 d = lr * d
 d.shape = (-1, 1)
 w = w - d

 y_ = np.dot(self.x, w)
 loss_ = abs((y_ - y).sum())
 print('\tLoss = ' + str(loss_))
 print('拟合的结果为:', end=',')
 print(sum(w.tolist(), []))
 print()
 if loss_ < self.eps:
 print('The Gradient Descent algorithm has converged!!\n')
 break
 pass

 def __call__(self, bs=1, lr=0.1, epoch=10):
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')
 if not isinstance(bs, int) or not isinstance(epoch, int):
 raise Exception(
 "The type of BatchSize/Epoch should be an integer number")
 self._gradient_descent(bs, lr, epoch)
 pass

 pass


if __name__ == "__main__":
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')

 gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
 # gd = GradientDescent([1.2, 1.4, 2.1])
 print("要拟合的参数结果是: ")
 print(gd.FuncArgs)
 print("===================\n\n")
 # gd.EPS = 0.0
 gd.plotFunc()
 gd(10, 0.01)
 print("Finished!")

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中expandtabs()方法的使用
May 18 Python
详解python脚本自动生成需要文件实例代码
Feb 04 Python
python中import reload __import__的区别详解
Oct 16 Python
Python3一行代码实现图片文字识别的示例
Jan 15 Python
Python处理菜单消息操作示例【基于win32ui模块】
May 09 Python
python清除字符串中间空格的实例讲解
May 11 Python
详解python3中的真值测试
Aug 13 Python
python实现RabbitMQ的消息队列的示例代码
Nov 08 Python
Python实现随机创建电话号码的方法示例
Dec 07 Python
python通过链接抓取网站详解
Nov 20 Python
Python爬虫爬取糗事百科段子实例分享
Jul 31 Python
python爬取代理ip的示例
Dec 18 Python
利用python实现逐步回归
Feb 24 #Python
python数据分析:关键字提取方式
Feb 24 #Python
python数据预处理 :数据共线性处理详解
Feb 24 #Python
使用python实现多维数据降维操作
Feb 24 #Python
python数据预处理 :数据抽样解析
Feb 24 #Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 #Python
Python流程控制常用工具详解
Feb 24 #Python
You might like
一个漂亮的php验证码类(分享)
2013/08/06 PHP
PHP实现获取第一个中文首字母并进行排序的方法
2017/05/09 PHP
php魔法函数与魔法常量使用介绍
2017/07/23 PHP
javascript 弹出窗口中是否显示地址栏的实现代码
2011/04/14 Javascript
40款非常棒的jQuery 插件和制作教程(系列一)
2011/10/26 Javascript
控制页面按钮在后台执行期间不重复提交的JS方法
2013/06/24 Javascript
setTimeout()递归调用不加引号出错的解决方法
2014/09/05 Javascript
javascript实现table选中的行以指定颜色高亮显示的方法
2015/05/13 Javascript
JavaScript组合模式学习要点
2016/08/26 Javascript
js实现非常棒的弹出div
2016/10/06 Javascript
jQuery动态添加与删除tr行实例代码
2016/10/18 Javascript
用js实现博客打赏功能
2016/10/24 Javascript
详解ECharts使用心得总结
2016/12/06 Javascript
ndm:NPM的桌面GUI应用程序
2018/10/15 Javascript
vue全屏事件开发详解
2020/06/17 Javascript
[02:17]《辉夜杯》TRG战队巡礼
2015/10/26 DOTA
[45:44]完美世界DOTA2联赛PWL S2 FTD vs PXG 第一场 11.27
2020/12/01 DOTA
python常见的格式化输出小结
2016/12/15 Python
Python的SimpleHTTPServer模块用处及使用方法简介
2018/01/22 Python
Windows系统Python直接调用C++ DLL的方法
2019/08/01 Python
Python利用for循环打印星号三角形的案例
2020/04/12 Python
Python之多进程与多线程的使用
2021/02/23 Python
测绘工程个人的自我评价
2013/11/10 职场文书
文秘大学生求职信
2014/02/25 职场文书
建筑工程技术专业求职信
2014/07/16 职场文书
经典演讲稿开场白
2014/08/25 职场文书
升国旗演讲稿
2014/09/05 职场文书
一份文言文检讨书
2014/09/13 职场文书
初中生庆国庆演讲稿范文2014
2014/09/25 职场文书
优秀共产党员推荐材料
2014/12/18 职场文书
新闻通讯稿模板
2015/07/22 职场文书
2016年学校十一国庆节活动总结
2016/04/01 职场文书
详细聊聊vue中组件的props属性
2021/11/02 Vue.js
Redis调用Lua脚本及使用场景快速掌握
2022/03/16 Redis
Redis监控工具RedisInsight安装与使用
2022/03/21 Redis
关于MySQL临时表为什么可以重名的问题
2022/03/22 MySQL