python应用Axes3D绘图(批量梯度下降算法)


Posted in Python onMarch 25, 2020

本文实例为大家分享了python批量梯度下降算法的具体代码,供大家参考,具体内容如下

问题:

将拥有两个自变量的二阶函数绘制到空间坐标系中,并通过批量梯度下降算法找到并绘制其极值点

大体思路:

首先,根据题意确定目标函数:f(w1,w2) = w1^2 + w2^2 + 2 w1 w2 + 500
然后,针对w1,w2分别求偏导,编写主方法求极值点
而后,创建三维坐标系绘制函数图像以及其极值点即可

具体代码实现以及成像结果如下:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.axes3d import Axes3D

#f(w1,w2) = w1^2 + w2^2 + 2*w1*w2 + 500
def targetFunction(W): #目标函数
 w1,w2 = W
 return w1 ** 2 + w2**2 + 2*w1*w2+500

def gradientFunction(W): #梯度函数:分别对w1,w2求偏导
 w1,w2 = W
 w1_grad = 2*w1+2*w2
 w2_grad = 2*w2 + 2*w1
 return np.array([w1_grad,w2_grad])

def batch_gradient_distance(targetFunc,gradientFunc,init_W,learning_rate = 0.01,tolerance = 0.0000001): #核心算法
 W = init_W
 target_value = targetFunc(W)
 counts = 0 #用于计算次数
 while counts<5000:
 gradient = gradientFunc(W)
 next_W = W-gradient*learning_rate
 next_target_value = targetFunc(next_W)
 if abs(next_target_value-target_value) <tolerance:
 print("此结果经过了", counts, "次循环")
 return next_W
 else:
 W,target_value = next_W,next_target_value
 counts += 1
 else:
 print("没有取到极值点")


if __name__ == '__main__':
 np.random.seed(0) #保证每次运行随机出来的结果一致
 init_W = np.array([np.random.random(),np.random.random()]) #随机初始的w1,w2
 w1,w2 = batch_gradient_distance(targetFunction,gradientFunction,init_W)
 print(w1,w2)
 #画图
 x1=np.arange(-10,11,1) #为了绘制函数的原图像
 x2=np.arange(-10,11,1)

 x1, x2 = np.meshgrid(x1, x2) # meshgrid :3D坐标系

 z=x1**2 + x2**2 + 2*x1*x2+500

 fig = plt.figure()
 ax = Axes3D(fig)
 ax.plot_surface(x1, x2, z) #绘制3D坐标系中的函数图像
 ax.scatter(w1,w2, targetFunction([w1,w2]), s=50, c='red') #绘制已经找到的极值点
 ax.legend() #使坐标系为网格状

 plt.show() #显示

函数以及其极值点成像如下(红点为极值点):

python应用Axes3D绘图(批量梯度下降算法)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 用户登录验证的小例子
Mar 06 Python
关于你不想知道的所有Python3 unicode特性
Nov 28 Python
python获取指定目录下所有文件名列表的方法
May 20 Python
批处理与python代码混合编程的方法
May 19 Python
Python实现网站注册验证码生成类
Jun 08 Python
python中如何正确使用正则表达式的详细模式(Verbose mode expression)
Nov 08 Python
Python设计模式之装饰模式实例详解
Jan 21 Python
Python切片操作去除字符串首尾的空格
Apr 22 Python
Django中的FBV和CBV用法详解
Sep 15 Python
Python操作Sonqube API获取检测结果并打印过程解析
Nov 27 Python
Python 实现日志同时输出到屏幕和文件
Feb 19 Python
Anaconda+Pycharm环境下的PyTorch配置方法
Mar 13 Python
2020新版本pycharm+anaconda+opencv+pyqt环境配置学习笔记,亲测可用
Mar 24 #Python
python实现梯度下降和逻辑回归
Mar 24 #Python
详解Python 实现 ZeroMQ 的三种基本工作模式
Mar 24 #Python
python使用梯度下降算法实现一个多线性回归
Mar 24 #Python
PyQt5+python3+pycharm开发环境配置教程
Mar 24 #Python
python实现最速下降法
Mar 24 #Python
python实现梯度法 python最速下降法
Mar 24 #Python
You might like
PHP采用get获取url汉字出现乱码的解决方法
2014/11/13 PHP
php基于mcrypt_encrypt和mcrypt_decrypt实现字符串加密解密的方法
2016/07/12 PHP
PHP实现支持加盐的图片加密解密
2016/09/09 PHP
php实现微信模拟登陆、获取用户列表及群发消息功能示例
2017/06/28 PHP
网页设计常用的一些技巧
2006/12/22 Javascript
js 优化次数过多的循环 考虑到性能问题
2011/03/05 Javascript
JQuery扩展插件Validate—6 radio、checkbox、select的验证
2011/09/05 Javascript
JS编程小常识很有用
2012/11/26 Javascript
关于IE BUG与字符串截取substr的解决办法
2013/04/10 Javascript
jquery cookie的用法总结
2013/11/18 Javascript
JS执行删除前的判断代码
2014/02/18 Javascript
js点击文本框后才加载验证码实例代码
2015/10/20 Javascript
AngularJS过滤器filter用法分析
2016/12/11 Javascript
jQuery实现动态文字搜索功能
2017/01/05 Javascript
详谈js使用in和hasOwnProperty获取对象属性的区别
2017/04/25 Javascript
Mac下安装vue
2018/04/11 Javascript
vue 中 beforeRouteEnter 死循环的问题
2019/04/23 Javascript
微信小程序实现禁止分享代码实例
2019/10/19 Javascript
d3.js实现图形拖拽
2019/12/19 Javascript
使用Python的内建模块collections的教程
2015/04/28 Python
python使用htmllib分析网页内容的方法
2015/05/08 Python
python简单实现刷新智联简历
2016/03/30 Python
浅谈Django REST Framework限速
2017/12/12 Python
pyqt 实现QlineEdit 输入密码显示成圆点的方法
2019/06/24 Python
python生成xml时规定dtd实例方法
2020/09/21 Python
pytorch 把图片数据转化成tensor的操作
2021/03/04 Python
DC Shoes荷兰官方网站:美国极限运动品牌
2019/10/22 全球购物
药学专业大学生个人的自我评价
2013/11/04 职场文书
汽车运用工程系毕业生自荐信
2013/12/27 职场文书
如何编写优秀的食品项目创业计划书
2014/01/23 职场文书
恐龙的灭绝教学反思
2014/02/12 职场文书
物业公司的岗位任命书
2014/06/06 职场文书
2015年中秋晚会主持词
2015/07/01 职场文书
使用pytorch实现线性回归
2021/04/11 Python
Vue全家桶入门基础教程
2021/05/14 Vue.js
使用tensorflow 实现反向传播求导
2021/05/26 Python