Pytorch反向求导更新网络参数的方法


Posted in Python onAugust 17, 2019

方法一:手动计算变量的梯度,然后更新梯度

import torch
from torch.autograd import Variable
# 定义参数
w1 = Variable(torch.FloatTensor([1,2,3]),requires_grad = True)
# 定义输出
d = torch.mean(w1)
# 反向求导
d.backward()
# 定义学习率等参数
lr = 0.001
# 手动更新参数
w1.data.zero_() # BP求导更新参数之前,需先对导数置0
w1.data.sub_(lr*w1.grad.data)

一个网络中通常有很多变量,如果按照上述的方法手动求导,然后更新参数,是很麻烦的,这个时候可以调用torch.optim

方法二:使用torch.optim

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
# 这里假设我们定义了一个网络,为net
steps = 10000
# 定义一个optim对象
optimizer = optim.SGD(net.parameters(), lr = 0.01)
# 在for循环中更新参数
for i in range(steps):
 optimizer.zero_grad() # 对网络中参数当前的导数置0
 output = net(input) # 网络前向计算
 loss = criterion(output, target) # 计算损失
 loss.backward() # 得到模型中参数对当前输入的梯度
 optimizer.step() # 更新参数

注意:torch.optim只用于参数更新和对参数的梯度置0,不能计算参数的梯度,在使用torch.optim进行参数更新之前,需要写前向与反向传播求导的代码

以上这篇Pytorch反向求导更新网络参数的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python函数参数*args**kwargs用法实例
Dec 04 Python
python多线程编程中的join函数使用心得
Sep 02 Python
基于python实现微信模板消息
Dec 21 Python
python运行时间的几种方法
Jun 17 Python
利用Python3分析sitemap.xml并抓取导出全站链接详解
Jul 04 Python
浅析Python函数式编程
Oct 06 Python
python版DDOS攻击脚本
Jun 12 Python
python内置函数sorted()用法深入分析
Oct 08 Python
使用Python FastAPI构建Web服务的实现
Jun 08 Python
Python基于yaml文件配置logging日志过程解析
Jun 23 Python
python实现图像随机裁剪的示例代码
Dec 10 Python
用Python提取PDF表格的方法
Apr 11 Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
You might like
印尼林东PWN黄金曼特宁咖啡豆:怎么冲世界上最醇厚的咖啡冲煮教程
2021/03/03 冲泡冲煮
修改Zend引擎实现PHP源码加密的原理及实践
2008/04/14 PHP
php网站地图生成类示例
2014/01/13 PHP
php采用curl访问域名返回405 method not allowed提示的解决方法
2014/06/26 PHP
PHP SPL标准库之文件操作(SplFileInfo和SplFileObject)实例
2015/05/11 PHP
yii2 RBAC使用DbManager实现后台权限判断的方法
2016/07/23 PHP
php mongodb操作类 带几个简单的例子
2016/08/25 PHP
PHP利用Mysql锁解决高并发的方法
2018/09/04 PHP
JavaScript 动态改变图片大小
2009/06/11 Javascript
dotopAlert 提示用户需安装播放器的代码
2012/09/17 Javascript
屏蔽网页右键复制和ctrl+c复制的js代码
2013/01/04 Javascript
多种方法实现360浏览器下禁止自动填写用户名密码
2014/06/16 Javascript
jQuery CSS()方法改变现有的CSS样式表
2014/09/09 Javascript
js格式化时间的方法
2015/12/18 Javascript
Javascript数组Array基础介绍
2016/03/13 Javascript
jquery跟随屏幕滚动效果的实现代码
2016/04/13 Javascript
Jquery插件仿百度搜索关键字自动匹配功能
2016/05/11 Javascript
node.js实现博客小爬虫的实例代码
2016/10/08 Javascript
深入学习js瀑布流布局
2016/10/14 Javascript
AngularJS入门教程之多视图切换用法示例
2016/11/02 Javascript
Bootstrap模态框水平垂直居中与增加拖拽功能
2016/11/09 Javascript
jquery实现全选、全不选以及单选功能
2017/03/23 jQuery
React Native中TabBarIOS的简单使用方法示例
2017/10/13 Javascript
vue脚手架搭建项目的兼容性配置详解
2018/07/17 Javascript
python复制与引用用法分析
2015/04/08 Python
Python3匿名函数用法示例
2018/07/25 Python
python如何查看微信消息撤回
2018/11/27 Python
Python 通过调用接口获取公交信息的实例
2018/12/17 Python
对Pycharm创建py文件时自定义头部模板的方法详解
2019/02/12 Python
关于Tensorflow使用CPU报错的解决方式
2020/02/05 Python
员工拾金不昧表扬信
2014/01/09 职场文书
2014年机关植树节活动方案
2014/02/27 职场文书
党支部三会一课计划
2014/09/24 职场文书
三下乡个人总结
2015/03/04 职场文书
2016年学校招生广告语
2016/01/28 职场文书
详解SQL的窗口函数
2022/04/21 Oracle