python梯度下降算法的实现


Posted in Python onFebruary 24, 2020

本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下

简介

本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4

代码

'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys

# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)


class GradientDescent(object):
 eps = 1.0e-8
 max_iter = 1000000 # 暂时不需要
 dim = 1
 func_args = [2.1, 2.7] # [w_0, .., w_dim, b]

 def __init__(self, func_arg=None, N=1000):
 self.data_num = N
 if func_arg is not None:
 self.FuncArgs = func_arg
 self._getData()

 def _getData(self):
 x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
 b_1 = np.ones((self.data_num, 1), dtype=np.float)
 # x = np.concatenate((x, b_1), axis=1)
 self.x = np.concatenate((x, b_1), axis=1)

 def func(self, x):
 # noise太大的话, 梯度下降法失去作用
 noise = 0.01 * np.random.randn(self.data_num) + 0
 w = np.array(self.func_args)
 # y1 = w * self.x[0, ] # 直接相乘
 y = np.dot(self.x, w) # 矩阵乘法
 y += noise
 return y

 @property
 def FuncArgs(self):
 return self.func_args

 @FuncArgs.setter
 def FuncArgs(self, args):
 if not isinstance(args, list):
 raise Exception(
 'args is not list, it should be like [w_0, ..., w_dim, b]')
 if len(args) == 0:
 raise Exception('args is empty list!!')
 if len(args) == 1:
 args.append(0.0)
 self.func_args = args
 self.dim = len(args) - 1
 self._getData()

 @property
 def EPS(self):
 return self.eps

 @EPS.setter
 def EPS(self, value):
 if not isinstance(value, float) and not isinstance(value, int):
 raise Exception("The type of eps should be an float number")
 self.eps = value

 def plotFunc(self):
 # 一维画图
 if self.dim == 1:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 fig, ax = plt.subplots()
 ax.plot(x, y, 'o')
 ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
 ax.grid()
 plt.show()
 # 二维画图
 if self.dim == 2:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 xs = x[:, 0]
 ys = x[:, 1]
 zs = y
 fig = plt.figure()
 ax = fig.add_subplot(111, projection='3d')
 ax.scatter(xs, ys, zs, c='r', marker='o')

 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')
 ax.set_zlabel('Z Label')
 plt.show()
 else:
 # plt.axis('off')
 plt.text(
 0.5,
 0.5,
 "The dimension(x.dim > 2) \n is too high to draw",
 size=17,
 rotation=0.,
 ha="center",
 va="center",
 bbox=dict(
 boxstyle="round",
 ec=(1., 0.5, 0.5),
 fc=(1., 0.8, 0.8), ))
 plt.draw()
 plt.show()
 # print('The dimension(x.dim > 2) is too high to draw')

 # 梯度下降法只能求解凸函数
 def _gradient_descent(self, bs, lr, epoch):
 x = self.x
 # shuffle数据集没有必要
 # np.random.shuffle(x)
 y = self.func(x)
 w = np.ones((self.dim + 1, 1), dtype=float)
 for e in range(epoch):
 print('epoch:' + str(e), end=',')
 # 批量梯度下降,bs为1时 等价单样本梯度下降
 for i in range(0, self.data_num, bs):
 y_ = np.dot(x[i:i + bs], w)
 loss = y_ - y[i:i + bs].reshape(-1, 1)
 d = loss * x[i:i + bs]
 d = d.sum(axis=0) / bs
 d = lr * d
 d.shape = (-1, 1)
 w = w - d

 y_ = np.dot(self.x, w)
 loss_ = abs((y_ - y).sum())
 print('\tLoss = ' + str(loss_))
 print('拟合的结果为:', end=',')
 print(sum(w.tolist(), []))
 print()
 if loss_ < self.eps:
 print('The Gradient Descent algorithm has converged!!\n')
 break
 pass

 def __call__(self, bs=1, lr=0.1, epoch=10):
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')
 if not isinstance(bs, int) or not isinstance(epoch, int):
 raise Exception(
 "The type of BatchSize/Epoch should be an integer number")
 self._gradient_descent(bs, lr, epoch)
 pass

 pass


if __name__ == "__main__":
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')

 gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
 # gd = GradientDescent([1.2, 1.4, 2.1])
 print("要拟合的参数结果是: ")
 print(gd.FuncArgs)
 print("===================\n\n")
 # gd.EPS = 0.0
 gd.plotFunc()
 gd(10, 0.01)
 print("Finished!")

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Windows和Linux下Python输出彩色文字的方法教程
May 02 Python
python数据类型_元组、字典常用操作方法(介绍)
May 30 Python
Python与人工神经网络:使用神经网络识别手写图像介绍
Dec 19 Python
Python输入二维数组方法
Apr 13 Python
pytorch下大型数据集(大型图片)的导入方式
Jan 08 Python
Python正则表达式学习小例子
Mar 03 Python
Python 生成VOC格式的标签实例
Mar 10 Python
Python求凸包及多边形面积教程
Apr 12 Python
python框架flask入门之环境搭建及开启调试
Jun 07 Python
Python类class参数self原理解析
Nov 19 Python
如何设置PyCharm中的Python代码模版(推荐)
Nov 20 Python
上手简单,功能强大的Python爬虫框架——feapder
Apr 27 Python
利用python实现逐步回归
Feb 24 #Python
python数据分析:关键字提取方式
Feb 24 #Python
python数据预处理 :数据共线性处理详解
Feb 24 #Python
使用python实现多维数据降维操作
Feb 24 #Python
python数据预处理 :数据抽样解析
Feb 24 #Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 #Python
Python流程控制常用工具详解
Feb 24 #Python
You might like
PHP源代码数组统计count分析
2011/08/02 PHP
PHP禁止个别IP访问网站
2013/10/30 PHP
Zend Framework教程之Zend_Db_Table表关联实例详解
2016/03/23 PHP
Laravel 在views中加载公共页面的实现代码
2019/10/22 PHP
php使用Swoole实现毫秒级定时任务的方法
2020/09/04 PHP
Javascript 模式实例 观察者模式
2009/10/24 Javascript
JavaScript对象数组的排序处理方法
2015/10/21 Javascript
深入浅析JavaScript中数据共享和数据传递
2016/04/25 Javascript
15位和18位身份证JS校验的简单实例
2016/07/18 Javascript
BootStrap树状图显示功能
2016/11/24 Javascript
JS去除重复并统计数量的实现方法
2016/12/15 Javascript
纯JS实现弹性导航条效果
2017/03/06 Javascript
jQuery实现分页功能(含ajax请求、后台数据、附完整demo)
2017/04/03 jQuery
Angularjs自定义指令Directive详解
2017/05/27 Javascript
nodejs实现范围请求的实现代码
2018/10/12 NodeJs
JS双向链表实现与使用方法示例(增加一个previous属性实现)
2019/01/31 Javascript
基于JS实现视频上传显示进度条
2020/05/12 Javascript
JS实现多选框的操作
2020/06/24 Javascript
[06:33]DOTA2亚洲邀请赛小组赛第二日 TOP10精彩集锦
2015/01/31 DOTA
[01:07:57]DOTA2-DPC中国联赛 正赛 Ehome vs Magma BO3 第二场 1月19日
2021/03/11 DOTA
使用Python AIML搭建聊天机器人的方法示例
2018/07/09 Python
python实现连连看辅助之图像识别延伸
2019/07/17 Python
导入tensorflow时报错:cannot import name 'abs'的解决
2019/10/10 Python
Python 基于jwt实现认证机制流程解析
2020/06/22 Python
Python利用matplotlib绘制散点图的新手教程
2020/11/05 Python
详解如何使用rem或viewport进行移动端适配
2020/08/14 HTML / CSS
C#面试问题
2016/07/29 面试题
高中毕业生的个人自我评价
2014/02/21 职场文书
婚礼主持结束词
2014/03/13 职场文书
保护环境演讲稿
2014/05/10 职场文书
2014中学教师节广播稿
2014/09/10 职场文书
2014年药店店长工作总结
2014/11/17 职场文书
Go语言使用select{}阻塞main函数介绍
2021/04/25 Golang
Python爬取英雄联盟MSI直播间弹幕并生成词云图
2021/06/01 Python
详解Java实现数据结构之并查集
2021/06/23 Java/Android
python的html标准库
2022/04/29 Python