浅谈matplotlib 绘制梯度下降求解过程


Posted in Python onJuly 12, 2020

机器学习过程中经常需要可视化,有助于加强对模型和参数的理解。

下面对梯度下降过程进行动图演示,可以修改不同的学习率,观看效果。

import numpy as np
import matplotlib.pyplot as plt
from IPython import display

X = 2*np.random.rand(100,1)
y = 4+3*X+np.random.randn(100,1) # randn正态分布
X_b = np.c_[np.ones((100,1)),X] # c_行数相等,左右拼接

eta = 0.1 # 学习率
n_iter = 1000 # 迭代次数
m = 100 # 样本点个数
theta = np.random.randn(2,1) # 参数初始值

plt.figure(figsize=(8,6))
mngr = plt.get_current_fig_manager() # 获取当前figure manager
mngr.window.wm_geometry("+520+520") # 调整窗口在屏幕上弹出的位置,注意写在打开交互模式之前
# 上面固定窗口,方便screentogif定位录制,只会这种弱弱的方法
plt.ion()# 打开交互模式
plt.rcParams["font.sans-serif"] = "SimHei"# 消除中文乱码

for iter in range(n_iter):
  plt.cla() # 清除原图像

  gradients = 2/m*X_b.T.dot(X_b.dot(theta)-y)
  theta = theta - eta*gradients
  X_new = np.array([[0],[2]])
  X_new_b = np.c_[np.ones((2,1)),X_new]
  y_pred = X_new_b.dot(theta)

  plt.axis([0,2,0,15])
  plt.plot(X,y,"b.")
  plt.plot(X_new,y_pred,"r-")
  plt.title("学习率:{:.2f}".format(eta))
  plt.pause(0.3) # 暂停一会
  display.clear_output(wait=True)# 刷新图像


plt.ioff()# 关闭交互模式  
plt.show()

浅谈matplotlib 绘制梯度下降求解过程

学习率:0.1,较合适

浅谈matplotlib 绘制梯度下降求解过程

学习率:0.02,收敛变慢了

浅谈matplotlib 绘制梯度下降求解过程

学习率:0.45,在最佳参数附近震荡

浅谈matplotlib 绘制梯度下降求解过程

学习率:0.5,不收敛

到此这篇关于浅谈matplotlib 绘制梯度下降求解过程的文章就介绍到这了,更多相关matplotlib 梯度下降内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python处理大数字的方法
May 27 Python
python实现Flappy Bird源码
Dec 24 Python
Python中时间datetime的处理与转换用法总结
Feb 18 Python
python在新的图片窗口显示图片(图像)的方法
Jul 11 Python
Python上下文管理器类和上下文管理器装饰器contextmanager用法实例分析
Nov 07 Python
TensorFlow加载模型时出错的解决方式
Feb 06 Python
Python流程控制常用工具详解
Feb 24 Python
python使用信号量动态更新配置文件的操作
Apr 01 Python
python -v 报错问题的解决方法
Sep 15 Python
python 写一个性能测试工具(一)
Oct 24 Python
Python用dilb提取照片上人脸的示例
Oct 26 Python
Opencv+Python识别PCB板图片的步骤
Jan 07 Python
使用matplotlib的pyplot模块绘图的实现示例
Jul 12 #Python
django template实现定义临时变量,自定义赋值、自增实例
Jul 12 #Python
Django后端分离 使用element-ui文件上传方式
Jul 12 #Python
PyQt5-QDateEdit的简单使用操作
Jul 12 #Python
Python logging日志模块 配置文件方式
Jul 12 #Python
django rest framework 过滤时间操作
Jul 12 #Python
使用python脚本自动生成K8S-YAML的方法示例
Jul 12 #Python
You might like
为什么夜间收到的中波电台比白天多
2021/03/01 无线电
用PHP实现将GB编码转换为UTF8
2006/11/25 PHP
修改php.ini实现Mysql导入数据库文件最大限制的修改方法
2007/12/11 PHP
攻克CakePHP系列二 表单数据显示
2008/10/22 PHP
PHP 读取文件内容代码(txt,js等)
2009/12/06 PHP
深入探讨:Nginx 502 Bad Gateway错误的解决方法
2013/06/03 PHP
详解Yii实现分页的两种方法
2017/01/14 PHP
PHP中error_reporting函数用法详细介绍
2017/06/11 PHP
PHP实现阿里大鱼短信验证的实例代码
2017/07/10 PHP
PHP中抽象类,接口功能、定义方法示例
2019/02/26 PHP
Prototype使用指南之hash.js
2007/01/10 Javascript
jQuery ajax在GBK编码下表单提交终极解决方案(非二次编码方法)
2010/10/20 Javascript
jquery常用技巧及常用方法列表集合
2011/04/06 Javascript
利用谷歌地图API获取点与点的距离的js代码
2012/10/11 Javascript
JavaScript代码性能优化总结(推荐)
2016/05/16 Javascript
BootStrap智能表单实战系列(三)分块表单配置详解
2016/06/13 Javascript
Angular设置title信息解决SEO方面存在问题
2016/08/19 Javascript
ES6中参数的默认值语法介绍
2017/05/03 Javascript
Node.js微信 access_token ( jsapi_ticket ) 存取与刷新的示例
2017/09/30 Javascript
vue双花括号的使用方法 附练习题
2017/11/07 Javascript
AngularJS基于http请求实现下载php生成的excel文件功能示例
2018/01/23 Javascript
基于three.js实现的3D粒子动效实例代码
2019/04/09 Javascript
微信小程序实现类似微信点击语音播放效果
2020/03/30 Javascript
vue中uni-app 实现小程序登录注册功能
2019/10/12 Javascript
python读取Android permission文件
2013/11/01 Python
Python调用C++,通过Pybind11制作Python接口
2018/10/16 Python
导入tensorflow:ImportError: libcublas.so.9.0 报错
2020/01/06 Python
Python 分布式缓存之Reids数据类型操作详解
2020/06/24 Python
阿根廷在线宠物商店:Puppis
2018/03/23 全球购物
中间件分为哪几类
2012/03/14 面试题
大班亲子运动会方案
2014/06/10 职场文书
社团活动总结书
2014/06/27 职场文书
教师节学生演讲稿
2014/09/03 职场文书
大学生党校培训心得体会
2014/09/11 职场文书
单位考核鉴定意见
2015/06/05 职场文书
初中政治教学反思
2016/02/23 职场文书