python梯度下降算法的实现


Posted in Python onFebruary 24, 2020

本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下

简介

本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4

代码

'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys

# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)


class GradientDescent(object):
 eps = 1.0e-8
 max_iter = 1000000 # 暂时不需要
 dim = 1
 func_args = [2.1, 2.7] # [w_0, .., w_dim, b]

 def __init__(self, func_arg=None, N=1000):
 self.data_num = N
 if func_arg is not None:
 self.FuncArgs = func_arg
 self._getData()

 def _getData(self):
 x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
 b_1 = np.ones((self.data_num, 1), dtype=np.float)
 # x = np.concatenate((x, b_1), axis=1)
 self.x = np.concatenate((x, b_1), axis=1)

 def func(self, x):
 # noise太大的话, 梯度下降法失去作用
 noise = 0.01 * np.random.randn(self.data_num) + 0
 w = np.array(self.func_args)
 # y1 = w * self.x[0, ] # 直接相乘
 y = np.dot(self.x, w) # 矩阵乘法
 y += noise
 return y

 @property
 def FuncArgs(self):
 return self.func_args

 @FuncArgs.setter
 def FuncArgs(self, args):
 if not isinstance(args, list):
 raise Exception(
 'args is not list, it should be like [w_0, ..., w_dim, b]')
 if len(args) == 0:
 raise Exception('args is empty list!!')
 if len(args) == 1:
 args.append(0.0)
 self.func_args = args
 self.dim = len(args) - 1
 self._getData()

 @property
 def EPS(self):
 return self.eps

 @EPS.setter
 def EPS(self, value):
 if not isinstance(value, float) and not isinstance(value, int):
 raise Exception("The type of eps should be an float number")
 self.eps = value

 def plotFunc(self):
 # 一维画图
 if self.dim == 1:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 fig, ax = plt.subplots()
 ax.plot(x, y, 'o')
 ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
 ax.grid()
 plt.show()
 # 二维画图
 if self.dim == 2:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 xs = x[:, 0]
 ys = x[:, 1]
 zs = y
 fig = plt.figure()
 ax = fig.add_subplot(111, projection='3d')
 ax.scatter(xs, ys, zs, c='r', marker='o')

 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')
 ax.set_zlabel('Z Label')
 plt.show()
 else:
 # plt.axis('off')
 plt.text(
 0.5,
 0.5,
 "The dimension(x.dim > 2) \n is too high to draw",
 size=17,
 rotation=0.,
 ha="center",
 va="center",
 bbox=dict(
 boxstyle="round",
 ec=(1., 0.5, 0.5),
 fc=(1., 0.8, 0.8), ))
 plt.draw()
 plt.show()
 # print('The dimension(x.dim > 2) is too high to draw')

 # 梯度下降法只能求解凸函数
 def _gradient_descent(self, bs, lr, epoch):
 x = self.x
 # shuffle数据集没有必要
 # np.random.shuffle(x)
 y = self.func(x)
 w = np.ones((self.dim + 1, 1), dtype=float)
 for e in range(epoch):
 print('epoch:' + str(e), end=',')
 # 批量梯度下降,bs为1时 等价单样本梯度下降
 for i in range(0, self.data_num, bs):
 y_ = np.dot(x[i:i + bs], w)
 loss = y_ - y[i:i + bs].reshape(-1, 1)
 d = loss * x[i:i + bs]
 d = d.sum(axis=0) / bs
 d = lr * d
 d.shape = (-1, 1)
 w = w - d

 y_ = np.dot(self.x, w)
 loss_ = abs((y_ - y).sum())
 print('\tLoss = ' + str(loss_))
 print('拟合的结果为:', end=',')
 print(sum(w.tolist(), []))
 print()
 if loss_ < self.eps:
 print('The Gradient Descent algorithm has converged!!\n')
 break
 pass

 def __call__(self, bs=1, lr=0.1, epoch=10):
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')
 if not isinstance(bs, int) or not isinstance(epoch, int):
 raise Exception(
 "The type of BatchSize/Epoch should be an integer number")
 self._gradient_descent(bs, lr, epoch)
 pass

 pass


if __name__ == "__main__":
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')

 gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
 # gd = GradientDescent([1.2, 1.4, 2.1])
 print("要拟合的参数结果是: ")
 print(gd.FuncArgs)
 print("===================\n\n")
 # gd.EPS = 0.0
 gd.plotFunc()
 gd(10, 0.01)
 print("Finished!")

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现DES加密解密方法实例详解
Jun 30 Python
Python爬虫之xlml解析库(全面了解)
Aug 08 Python
Python3实现简单可学习的手写体识别(实例讲解)
Oct 21 Python
python实现读取excel写入mysql的小工具详解
Nov 20 Python
python按行读取文件,去掉每行的换行符\n的实例
Apr 19 Python
CentOS7下python3.7.0安装教程
Jul 30 Python
python学习之hook钩子的原理和使用
Oct 25 Python
Python气泡提示与标签的实现
Apr 01 Python
TensorFlow使用Graph的基本操作的实现
Apr 22 Python
Python3 搭建Qt5 环境的方法示例
Jul 16 Python
Python调用Redis的示例代码
Nov 24 Python
详解win10下pytorch-gpu安装以及CUDA详细安装过程
Jan 28 Python
利用python实现逐步回归
Feb 24 #Python
python数据分析:关键字提取方式
Feb 24 #Python
python数据预处理 :数据共线性处理详解
Feb 24 #Python
使用python实现多维数据降维操作
Feb 24 #Python
python数据预处理 :数据抽样解析
Feb 24 #Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 #Python
Python流程控制常用工具详解
Feb 24 #Python
You might like
php入门学习知识点五 关于php数组的几个基本操作
2011/07/14 PHP
基于initPHP的框架介绍
2013/04/18 PHP
解析php类的注册与自动加载
2013/07/05 PHP
php使用Jpgraph绘制饼状图的方法
2015/06/10 PHP
PHP实现基于mysqli的Model基类完整实例
2016/04/08 PHP
《JavaScript函数式编程》读后感
2015/08/07 Javascript
基于Javascript实现弹出页面效果
2016/01/01 Javascript
js实现的二分查找算法实例
2016/01/21 Javascript
JS组件Bootstrap ContextMenu右键菜单使用方法
2016/04/17 Javascript
[原创]Javascript 实现广告后加载 可加载百度谷歌联盟广告
2016/05/11 Javascript
JavaScript &amp; jQuery完美判断图片是否加载完毕
2017/01/08 Javascript
javascript中this用法实例详解
2017/04/06 Javascript
最常用的jQuery表单验证(简单)
2017/05/23 jQuery
jQuery实现动态显示select下拉列表数据的方法
2018/02/05 jQuery
优雅的elementUI table单元格可编辑实现方法详解
2018/12/23 Javascript
vue实现鼠标经过动画
2019/10/16 Javascript
jQuery 常用特效实例小结【显示与隐藏、淡入淡出、滑动、动画等】
2020/05/19 jQuery
[01:12]快闪回顾DOTA2亚洲邀请赛(DAC) 静候2018新征程开启
2018/03/11 DOTA
浅析Python 中整型对象存储的位置
2016/05/16 Python
Python使用正则表达式过滤或替换HTML标签的方法详解
2017/09/25 Python
利用TensorFlow训练简单的二分类神经网络模型的方法
2018/03/05 Python
浅谈Python批处理文件夹中的txt文件
2019/03/11 Python
python提取照片坐标信息的实例代码
2019/08/14 Python
Python 多线程搜索txt文件的内容,并写入搜到的内容(Lock)方法
2019/08/23 Python
python获取响应某个字段值的3种实现方法
2020/04/30 Python
Python新手学习raise用法
2020/06/03 Python
德国街头和运动文化高品质商店:BSTN Store
2017/08/26 全球购物
卡西欧B级产品官方网站:Casio Outlet
2018/05/22 全球购物
TCP/IP的分层模型
2013/10/27 面试题
SOA的常见陷阱或者误解是什么
2014/10/05 面试题
乡镇消防安全责任书
2014/07/23 职场文书
家庭教育的心得体会
2014/09/01 职场文书
农村党支部书记四风问题个人对照检查材料
2014/09/21 职场文书
2014年设计师工作总结
2014/11/25 职场文书
为什么代码规范要求SQL语句不要过多的join
2021/06/23 MySQL
解决Laravel使用验证时跳转到首页的问题
2021/11/17 PHP