python实现梯度下降算法的实例详解


Posted in Python onAugust 17, 2020

python版本选择

这里选的python版本是2.7,因为我之前用python3试了几次,发现在画3d图的时候会报错,所以改用了2.7。

数据集选择

数据集我选了一个包含两个变量,三个参数的数据集,这样可以画出3d图形对结果进行验证。

部分函数总结

symbols()函数:首先要安装sympy库才可以使用。用法:

>>> x1 = symbols('x2')
>>> x1 + 1
x2 + 1

在这个例子中,x1和x2是不一样的,x2代表的是一个函数的变量,而x1代表的是python中的一个变量,它可以表示函数的变量,也可以表示其他的任何量,它替代x2进行函数的计算。实际使用的时候我们可以将x1,x2都命名为x,但是我们要知道他们俩的区别。
再看看这个例子:

>>> x = symbols('x')
>>> expr = x + 1
>>> x = 2
>>> print(expr)
x + 1

作为python变量的x被2这个数值覆盖了,所以它现在不再表示函数变量x,而expr依然是函数变量x+1的别名,所以结果依然是x+1。
subs()函数:既然普通的方法无法为函数变量赋值,那就肯定有函数来实现这个功能,用法:

>>> (1 + x*y).subs(x, pi)#一个参数时的用法
pi*y + 1
>>> (1 + x*y).subs({x:pi, y:2})#多个参数时的用法
1 + 2*pi

diff()函数:求偏导数,用法:result=diff(fun,x),这个就是求fun函数对x变量的偏导数,结果result也是一个变量,需要赋值才能得到准确结果。

代码实现:

from __future__ import division
from sympy import symbols, diff, expand
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

data = {'x1': [100, 50, 100, 100, 50, 80, 75, 65, 90, 90],
        'x2': [4, 3, 4, 2, 2, 2, 3, 4, 3, 2],
        'y': [9.3, 4.8, 8.9, 6.5, 4.2, 6.2, 7.4, 6.0, 7.6, 6.1]}#初始化数据集
theta0, theta1, theta2 = symbols('theta0 theta1 theta2', real=True)  # y=theta0+theta1*x1+theta2*x2,定义参数
costfuc = 0 * theta0
for i in range(10):
    costfuc += (theta0 + theta1 * data['x1'][i] + theta2 * data['x2'][i] - data['y'][i]) ** 2
costfuc /= 20#初始化代价函数
dtheta0 = diff(costfuc, theta0)
dtheta1 = diff(costfuc, theta1)
dtheta2 = diff(costfuc, theta2)

rtheta0 = 1
rtheta1 = 1
rtheta2 = 1#为参数赋初始值

costvalue = costfuc.subs({theta0: rtheta0, theta1: rtheta1, theta2: rtheta2})
newcostvalue = 0#用cost的值的变化程度来判断是否已经到最小值了
count = 0
alpha = 0.0001#设置学习率,一定要设置的比较小,否则无法到达最小值
while (costvalue - newcostvalue > 0.00001 or newcostvalue - costvalue > 0.00001) and count < 1000:
    count += 1
    costvalue = newcostvalue
    rtheta0 = rtheta0 - alpha * dtheta0.subs({theta0: rtheta0, theta1: rtheta1, theta2: rtheta2})
    rtheta1 = rtheta1 - alpha * dtheta1.subs({theta0: rtheta0, theta1: rtheta1, theta2: rtheta2})
    rtheta2 = rtheta2 - alpha * dtheta2.subs({theta0: rtheta0, theta1: rtheta1, theta2: rtheta2})
    newcostvalue = costfuc.subs({theta0: rtheta0, theta1: rtheta1, theta2: rtheta2})
rtheta0 = round(rtheta0, 4)
rtheta1 = round(rtheta1, 4)
rtheta2 = round(rtheta2, 4)#给结果保留4位小数,防止数值溢出
print(rtheta0, rtheta1, rtheta2)

fig = plt.figure()
ax = Axes3D(fig)
ax.scatter(data['x1'], data['x2'], data['y'])  # 绘制散点图
xx = np.arange(20, 100, 1)
yy = np.arange(1, 5, 0.05)
X, Y = np.meshgrid(xx, yy)
Z = X * rtheta1 + Y * rtheta2 + rtheta0
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=plt.get_cmap('rainbow'))

plt.show()#绘制3d图进行验证

结果:

python实现梯度下降算法的实例详解

python实现梯度下降算法的实例详解

实例扩展:

'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys

# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)


class GradientDescent(object):
 eps = 1.0e-8
 max_iter = 1000000 # 暂时不需要
 dim = 1
 func_args = [2.1, 2.7] # [w_0, .., w_dim, b]

 def __init__(self, func_arg=None, N=1000):
 self.data_num = N
 if func_arg is not None:
 self.FuncArgs = func_arg
 self._getData()

 def _getData(self):
 x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
 b_1 = np.ones((self.data_num, 1), dtype=np.float)
 # x = np.concatenate((x, b_1), axis=1)
 self.x = np.concatenate((x, b_1), axis=1)

 def func(self, x):
 # noise太大的话, 梯度下降法失去作用
 noise = 0.01 * np.random.randn(self.data_num) + 0
 w = np.array(self.func_args)
 # y1 = w * self.x[0, ] # 直接相乘
 y = np.dot(self.x, w) # 矩阵乘法
 y += noise
 return y

 @property
 def FuncArgs(self):
 return self.func_args

 @FuncArgs.setter
 def FuncArgs(self, args):
 if not isinstance(args, list):
 raise Exception(
 'args is not list, it should be like [w_0, ..., w_dim, b]')
 if len(args) == 0:
 raise Exception('args is empty list!!')
 if len(args) == 1:
 args.append(0.0)
 self.func_args = args
 self.dim = len(args) - 1
 self._getData()

 @property
 def EPS(self):
 return self.eps

 @EPS.setter
 def EPS(self, value):
 if not isinstance(value, float) and not isinstance(value, int):
 raise Exception("The type of eps should be an float number")
 self.eps = value

 def plotFunc(self):
 # 一维画图
 if self.dim == 1:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 fig, ax = plt.subplots()
 ax.plot(x, y, 'o')
 ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
 ax.grid()
 plt.show()
 # 二维画图
 if self.dim == 2:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 xs = x[:, 0]
 ys = x[:, 1]
 zs = y
 fig = plt.figure()
 ax = fig.add_subplot(111, projection='3d')
 ax.scatter(xs, ys, zs, c='r', marker='o')

 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')
 ax.set_zlabel('Z Label')
 plt.show()
 else:
 # plt.axis('off')
 plt.text(
 0.5,
 0.5,
 "The dimension(x.dim > 2) \n is too high to draw",
 size=17,
 rotation=0.,
 ha="center",
 va="center",
 bbox=dict(
 boxstyle="round",
 ec=(1., 0.5, 0.5),
 fc=(1., 0.8, 0.8), ))
 plt.draw()
 plt.show()
 # print('The dimension(x.dim > 2) is too high to draw')

 # 梯度下降法只能求解凸函数
 def _gradient_descent(self, bs, lr, epoch):
 x = self.x
 # shuffle数据集没有必要
 # np.random.shuffle(x)
 y = self.func(x)
 w = np.ones((self.dim + 1, 1), dtype=float)
 for e in range(epoch):
 print('epoch:' + str(e), end=',')
 # 批量梯度下降,bs为1时 等价单样本梯度下降
 for i in range(0, self.data_num, bs):
 y_ = np.dot(x[i:i + bs], w)
 loss = y_ - y[i:i + bs].reshape(-1, 1)
 d = loss * x[i:i + bs]
 d = d.sum(axis=0) / bs
 d = lr * d
 d.shape = (-1, 1)
 w = w - d

 y_ = np.dot(self.x, w)
 loss_ = abs((y_ - y).sum())
 print('\tLoss = ' + str(loss_))
 print('拟合的结果为:', end=',')
 print(sum(w.tolist(), []))
 print()
 if loss_ < self.eps:
 print('The Gradient Descent algorithm has converged!!\n')
 break
 pass

 def __call__(self, bs=1, lr=0.1, epoch=10):
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')
 if not isinstance(bs, int) or not isinstance(epoch, int):
 raise Exception(
 "The type of BatchSize/Epoch should be an integer number")
 self._gradient_descent(bs, lr, epoch)
 pass

 pass


if __name__ == "__main__":
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')

 gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
 # gd = GradientDescent([1.2, 1.4, 2.1])
 print("要拟合的参数结果是: ")
 print(gd.FuncArgs)
 print("===================\n\n")
 # gd.EPS = 0.0
 gd.plotFunc()
 gd(10, 0.01)
 print("Finished!")

到此这篇关于python实现梯度下降算法的实例详解的文章就介绍到这了,更多相关教你用python实现梯度下降算法内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
pycharm 使用心得(九)解决No Python interpreter selected的问题
Jun 06 Python
python实现挑选出来100以内的质数
Mar 24 Python
Python HTMLParser模块解析html获取url实例
Apr 08 Python
总结用Pdb库调试Python的方式及常用的命令
Aug 18 Python
Python读取Excel表格,并同时画折线图和柱状图的方法
Oct 14 Python
pandas分别写入excel的不同sheet方法
Dec 11 Python
用Q-learning算法实现自动走迷宫机器人的方法示例
Jun 03 Python
python安装本地whl的实例步骤
Oct 12 Python
关于Python Tkinter Button控件command传参问题的解决方式
Mar 04 Python
Numpy一维线性插值函数的用法
Apr 22 Python
如何让python的运行速度得到提升
Jul 08 Python
python turtle绘制多边形和跳跃和改变速度特效
Mar 16 Python
python3.5的包存放的具体路径
Aug 16 #Python
python根据字典的键来删除元素的方法
Aug 16 #Python
python实现取余操作的简单实例
Aug 16 #Python
python属于哪种语言
Aug 16 #Python
python中sys模块是做什么用的
Aug 16 #Python
python3获取控制台输入的数据的具体实例
Aug 16 #Python
python在一个范围内取随机数的简单实例
Aug 16 #Python
You might like
php Smarty模板生成html文档的方法
2010/04/12 PHP
PHP学习之数组的定义和填充
2011/04/17 PHP
thinkphp的CURD和查询方式介绍
2013/12/19 PHP
php redis setnx分布式锁简单原理解析
2020/10/23 PHP
jquery 入门教程 [翻译] 推荐
2009/08/17 Javascript
js 获取坐标 通过JS得到当前焦点(鼠标)的坐标属性
2013/01/04 Javascript
jQuery之选择组件的深入解析
2013/06/19 Javascript
tangram框架响应式加载图片方法
2013/11/21 Javascript
Event altKey,ctrlKey,shiftKey属性解析
2013/12/18 Javascript
javascript的回调函数应用示例
2014/02/20 Javascript
JavaScript前端图片加载管理器imagepool使用详解
2014/12/29 Javascript
js获取数组的最后一个元素
2015/04/14 Javascript
JavaScript汉诺塔问题解决方法
2015/04/21 Javascript
node+experss实现爬取电影天堂爬虫
2016/11/20 Javascript
详解Node.js利用node-git-server快速搭建git服务器
2017/09/27 Javascript
简单了解vue中的v-if和v-show的区别
2019/10/08 Javascript
javascript实现拖拽碰撞检测
2020/03/12 Javascript
js删除对象中的某一个字段的方法实现
2021/01/11 Javascript
[03:07]DOTA2英雄基础教程 冰霜诅咒极寒幽魂
2013/12/06 DOTA
[16:56]heroes英雄教学 司夜刺客
2014/09/18 DOTA
Django Highcharts制作图表
2016/08/27 Python
python 打印对象的所有属性值的方法
2016/09/11 Python
详解pyppeteer(python版puppeteer)基本使用
2019/06/12 Python
Python实现微信中找回好友、群聊用户撤回的消息功能示例
2019/08/23 Python
Python 生成器,迭代,yield关键字,send()传参给yield语句操作示例
2019/10/12 Python
python双端队列原理、实现与使用方法分析
2019/11/27 Python
Python发送手机动态验证码代码实例
2020/02/28 Python
在django admin中配置搜索域是一个外键时的处理方法
2020/05/20 Python
Python request中文乱码问题解决方案
2020/09/17 Python
Anaconda使用IDLE的实现示例
2020/09/23 Python
初中数学教学反思
2014/01/16 职场文书
幼儿园小班评语
2014/04/18 职场文书
市场开发与营销专业求职信范文
2014/05/01 职场文书
建筑安全生产目标责任书
2014/07/23 职场文书
幼儿园2016年感恩节活动总结
2016/04/01 职场文书
Win11如何查看显卡型号 Win11查看显卡型号的方法
2022/08/14 数码科技