python梯度下降算法的实现


Posted in Python onFebruary 24, 2020

本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下

简介

本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4

代码

'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys

# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)


class GradientDescent(object):
 eps = 1.0e-8
 max_iter = 1000000 # 暂时不需要
 dim = 1
 func_args = [2.1, 2.7] # [w_0, .., w_dim, b]

 def __init__(self, func_arg=None, N=1000):
 self.data_num = N
 if func_arg is not None:
 self.FuncArgs = func_arg
 self._getData()

 def _getData(self):
 x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
 b_1 = np.ones((self.data_num, 1), dtype=np.float)
 # x = np.concatenate((x, b_1), axis=1)
 self.x = np.concatenate((x, b_1), axis=1)

 def func(self, x):
 # noise太大的话, 梯度下降法失去作用
 noise = 0.01 * np.random.randn(self.data_num) + 0
 w = np.array(self.func_args)
 # y1 = w * self.x[0, ] # 直接相乘
 y = np.dot(self.x, w) # 矩阵乘法
 y += noise
 return y

 @property
 def FuncArgs(self):
 return self.func_args

 @FuncArgs.setter
 def FuncArgs(self, args):
 if not isinstance(args, list):
 raise Exception(
 'args is not list, it should be like [w_0, ..., w_dim, b]')
 if len(args) == 0:
 raise Exception('args is empty list!!')
 if len(args) == 1:
 args.append(0.0)
 self.func_args = args
 self.dim = len(args) - 1
 self._getData()

 @property
 def EPS(self):
 return self.eps

 @EPS.setter
 def EPS(self, value):
 if not isinstance(value, float) and not isinstance(value, int):
 raise Exception("The type of eps should be an float number")
 self.eps = value

 def plotFunc(self):
 # 一维画图
 if self.dim == 1:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 fig, ax = plt.subplots()
 ax.plot(x, y, 'o')
 ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
 ax.grid()
 plt.show()
 # 二维画图
 if self.dim == 2:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 xs = x[:, 0]
 ys = x[:, 1]
 zs = y
 fig = plt.figure()
 ax = fig.add_subplot(111, projection='3d')
 ax.scatter(xs, ys, zs, c='r', marker='o')

 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')
 ax.set_zlabel('Z Label')
 plt.show()
 else:
 # plt.axis('off')
 plt.text(
 0.5,
 0.5,
 "The dimension(x.dim > 2) \n is too high to draw",
 size=17,
 rotation=0.,
 ha="center",
 va="center",
 bbox=dict(
 boxstyle="round",
 ec=(1., 0.5, 0.5),
 fc=(1., 0.8, 0.8), ))
 plt.draw()
 plt.show()
 # print('The dimension(x.dim > 2) is too high to draw')

 # 梯度下降法只能求解凸函数
 def _gradient_descent(self, bs, lr, epoch):
 x = self.x
 # shuffle数据集没有必要
 # np.random.shuffle(x)
 y = self.func(x)
 w = np.ones((self.dim + 1, 1), dtype=float)
 for e in range(epoch):
 print('epoch:' + str(e), end=',')
 # 批量梯度下降,bs为1时 等价单样本梯度下降
 for i in range(0, self.data_num, bs):
 y_ = np.dot(x[i:i + bs], w)
 loss = y_ - y[i:i + bs].reshape(-1, 1)
 d = loss * x[i:i + bs]
 d = d.sum(axis=0) / bs
 d = lr * d
 d.shape = (-1, 1)
 w = w - d

 y_ = np.dot(self.x, w)
 loss_ = abs((y_ - y).sum())
 print('\tLoss = ' + str(loss_))
 print('拟合的结果为:', end=',')
 print(sum(w.tolist(), []))
 print()
 if loss_ < self.eps:
 print('The Gradient Descent algorithm has converged!!\n')
 break
 pass

 def __call__(self, bs=1, lr=0.1, epoch=10):
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')
 if not isinstance(bs, int) or not isinstance(epoch, int):
 raise Exception(
 "The type of BatchSize/Epoch should be an integer number")
 self._gradient_descent(bs, lr, epoch)
 pass

 pass


if __name__ == "__main__":
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')

 gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
 # gd = GradientDescent([1.2, 1.4, 2.1])
 print("要拟合的参数结果是: ")
 print(gd.FuncArgs)
 print("===================\n\n")
 # gd.EPS = 0.0
 gd.plotFunc()
 gd(10, 0.01)
 print("Finished!")

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python下使用Psyco模块优化运行速度
Apr 05 Python
PYTHON压平嵌套列表的简单实现
Jun 08 Python
Python面向对象编程中关于类和方法的学习笔记
Jun 30 Python
Python OpenCV实现图片上输出中文
Jan 22 Python
python 输出上个月的月末日期实例
Apr 11 Python
基于Python 装饰器装饰类中的方法实例
Apr 21 Python
int在python中的含义以及用法
Jun 27 Python
解决tensorflow添加ptb库的问题
Feb 10 Python
浅谈opencv自动光学检测、目标分割和检测(连通区域和findContours)
Jun 04 Python
Pandas缺失值2种处理方式代码实例
Jun 13 Python
总结Pyinstaller打包的高级用法
Jun 28 Python
使用opencv-python如何打开USB或者笔记本前置摄像头
Jun 21 Python
利用python实现逐步回归
Feb 24 #Python
python数据分析:关键字提取方式
Feb 24 #Python
python数据预处理 :数据共线性处理详解
Feb 24 #Python
使用python实现多维数据降维操作
Feb 24 #Python
python数据预处理 :数据抽样解析
Feb 24 #Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 #Python
Python流程控制常用工具详解
Feb 24 #Python
You might like
服务器端解压缩zip的脚本
2006/12/22 PHP
adodb与adodb_lite之比较
2006/12/31 PHP
解决dede生成静态页和动态页转换的一些问题,及火车采集入库生成动态的办法
2007/03/29 PHP
在win7中搭建Linux+PHP 开发环境
2014/10/08 PHP
php身份证号码检查类实例
2015/06/18 PHP
Thinkphp和onethink实现微信支付插件
2016/04/13 PHP
Nginx环境下PHP flush失效的解决方法
2016/10/19 PHP
JavaScript Undefined,Null类型和NaN值区别
2008/10/22 Javascript
jquery蒙版控件实现代码
2010/12/08 Javascript
JavaScript中检查对象property的存在性方法介绍
2014/12/30 Javascript
jquery UI Datepicker时间控件的使用方法(基础版)
2015/11/07 Javascript
JavaScript 消息框效果【实现代码】
2016/04/27 Javascript
[原创]Javascript 实现广告后加载 可加载百度谷歌联盟广告
2016/05/11 Javascript
Javascript 基础---Ajax入门必看
2016/07/06 Javascript
浅谈Javascript数据属性与访问器属性
2016/07/26 Javascript
原生js验证简洁注册登录页面
2016/12/17 Javascript
js实现手机发送验证码功能
2017/03/13 Javascript
JavaScript获取URL参数的方法之一
2017/03/24 Javascript
JS正则验证多个邮箱完整实例【邮箱用分号隔开】
2017/04/19 Javascript
webpack学习--webpack经典7分钟入门教程
2017/06/28 Javascript
vue 组件中slot插口的具体用法
2018/04/03 Javascript
JS数组求和的常用方法总结【5种方法】
2019/01/14 Javascript
vue开发chrome插件,实现获取界面数据和保存到数据库功能
2020/12/01 Vue.js
[07:38]2014DOTA2国际邀请赛 Newbee顺利挺进胜者组赛后专访
2014/07/15 DOTA
Python进阶学习之特殊方法实例详析
2017/12/01 Python
解决Ubuntu pip 安装 mysql-python包出错的问题
2018/06/11 Python
python3发送request请求及查看返回结果实例
2020/04/30 Python
欧洲领先的电子和电信零售商和服务提供商:Currys PC World Business
2017/12/05 全球购物
Java里面如何把一个Array数组转换成Collection, List
2013/07/26 面试题
四年级下册教学反思
2014/02/01 职场文书
设计师个人求职信范文
2014/02/02 职场文书
化学教学随笔感言
2014/02/19 职场文书
村主任“四风”问题个人整改措施
2014/10/04 职场文书
区域经理岗位职责
2015/02/02 职场文书
mongodb的安装和开机自启动详细讲解
2021/08/02 MongoDB
【js设计模式】SOLID五大设计原则
2022/03/24 Javascript