Pytorch反向求导更新网络参数的方法


Posted in Python onAugust 17, 2019

方法一:手动计算变量的梯度,然后更新梯度

import torch
from torch.autograd import Variable
# 定义参数
w1 = Variable(torch.FloatTensor([1,2,3]),requires_grad = True)
# 定义输出
d = torch.mean(w1)
# 反向求导
d.backward()
# 定义学习率等参数
lr = 0.001
# 手动更新参数
w1.data.zero_() # BP求导更新参数之前,需先对导数置0
w1.data.sub_(lr*w1.grad.data)

一个网络中通常有很多变量,如果按照上述的方法手动求导,然后更新参数,是很麻烦的,这个时候可以调用torch.optim

方法二:使用torch.optim

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
# 这里假设我们定义了一个网络,为net
steps = 10000
# 定义一个optim对象
optimizer = optim.SGD(net.parameters(), lr = 0.01)
# 在for循环中更新参数
for i in range(steps):
 optimizer.zero_grad() # 对网络中参数当前的导数置0
 output = net(input) # 网络前向计算
 loss = criterion(output, target) # 计算损失
 loss.backward() # 得到模型中参数对当前输入的梯度
 optimizer.step() # 更新参数

注意:torch.optim只用于参数更新和对参数的梯度置0,不能计算参数的梯度,在使用torch.optim进行参数更新之前,需要写前向与反向传播求导的代码

以上这篇Pytorch反向求导更新网络参数的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python读写Json涉及到中文的处理方法
Sep 12 Python
python中defaultdict的用法详解
Jun 07 Python
Python实现上下班抢个顺风单脚本
Feb 07 Python
python中的for循环
Sep 28 Python
python3.6下Numpy库下载与安装图文教程
Apr 02 Python
python根据多个文件名批量查找文件
Aug 13 Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 Python
python和js交互调用的方法
Jun 23 Python
Python爬虫小例子——爬取51job发布的工作职位
Jul 10 Python
Python实现JS解密并爬取某音漫客网站
Oct 23 Python
Python web框架(django,flask)实现mysql数据库读写分离的示例
Nov 18 Python
python实现不同数据库间数据同步功能
Feb 25 Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
You might like
PHP+MYSQL会员系统的开发实例教程
2014/08/23 PHP
php实现的任意进制互转类分享
2015/07/07 PHP
PHP信号处理机制的操作代码讲解
2019/04/19 PHP
PHP抽象类与接口的区别实例详解
2019/05/09 PHP
php用wangeditor3实现图片上传功能
2019/08/22 PHP
Ext 表单布局实例代码
2009/04/30 Javascript
每天一篇javascript学习小结(Boolean对象)
2015/11/12 Javascript
基于jQuery Tipso插件实现消息提示框特效
2016/03/16 Javascript
Bootstrap3学习笔记(三)之表格
2016/05/20 Javascript
浅谈jQuery hover(over, out)事件函数
2016/12/03 Javascript
Javascript中return的使用与闭包详解
2017/01/11 Javascript
微信小程序开发教程之增加mixin扩展
2017/08/09 Javascript
vue2.0使用swiper组件实现轮播效果
2017/11/27 Javascript
Koa2 之文件上传下载的示例代码
2018/03/29 Javascript
AngularJS中ng-options实现下拉列表的数据绑定方法
2018/08/13 Javascript
vue 对象添加或删除成员时无法实时更新的解决方法
2019/05/01 Javascript
javascript实现电商放大镜效果
2020/11/23 Javascript
python文件名和文件路径操作实例
2017/09/29 Python
python如何拆分含有多种分隔符的字符串
2018/03/20 Python
python队列queue模块详解
2018/04/27 Python
Python绘制的二项分布概率图示例
2018/08/22 Python
Python OpenCV对本地视频文件进行分帧保存的实例
2019/01/08 Python
Python面向对象之类的定义与继承用法示例
2019/01/14 Python
Django配置MySQL数据库的完整步骤
2019/09/07 Python
Python2 与Python3的版本区别实例分析
2020/03/30 Python
使用tensorflow实现VGG网络,训练mnist数据集方式
2020/05/26 Python
利用CSS3的checked伪类实现OL的隐藏显示的方法
2010/12/18 HTML / CSS
美国最流行的男士时尚网站:Touch of Modern
2018/02/05 全球购物
日常奢侈品,轻松购物:Verishop
2019/08/20 全球购物
公交公司毕业生求职信
2014/02/15 职场文书
拓展训练激励口号
2014/06/17 职场文书
介绍信格式样本
2015/05/05 职场文书
学校运动会开幕词
2016/03/03 职场文书
简短的36句中秋节祝福信息语句
2019/09/09 职场文书
Python绘制分类图的方法
2021/04/20 Python
Python中文纠错的简单实现
2021/07/07 Python