pytorch如何冻结某层参数的实现


Posted in Python onJanuary 10, 2020

在迁移学习finetune时我们通常需要冻结前几层的参数不参与训练,在Pytorch中的实现如下:

class Model(nn.Module):
 def __init__(self):
  super(Transfer_model, self).__init__()
  self.linear1 = nn.Linear(20, 50)
  self.linear2 = nn.Linear(50, 20)
  self.linear3 = nn.Linear(20, 2)

 def forward(self, x):
 pass

假如我们想要冻结linear1层,需要做如下操作:

model = Model()
# 这里是一般情况,共享层往往不止一层,所以做一个for循环
for para in model.linear1.parameters():
 para.requires_grad = False
# 假如真的只有一层也可以这样操作:
# model.linear1.weight.requires_grad = False

 最后我们需要将需要优化的参数传入优化器,不需要传入的参数过滤掉,所以要用到filter()函数。

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)

其它的博客中都没有讲解filter()函数的作用,在这里我简单讲一下有助于更好的理解。

filter(function, iterable)

  • function: 判断函数
  • iterable: 可迭代对象

filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回一个迭代器对象,如果要转换为列表,可以使用 list() 来转换。

该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。

filter()函数将requires_grad = True的参数传入优化器进行反向传播,requires_grad = False的则被过滤掉。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python字符串特性及常用字符串方法的简单笔记
Jan 04 Python
如何使用python爬取csdn博客访问量
Feb 14 Python
Python数据类型之List列表实例详解
May 08 Python
python地震数据可视化详解
Jun 18 Python
Python线上环境使用日志的及配置文件
Jul 28 Python
使用Pyinstaller转换.py文件为.exe可执行程序过程详解
Aug 06 Python
python多进程并发demo实例解析
Dec 13 Python
Python figure参数及subplot子图绘制代码
Apr 18 Python
全网首秀之Pycharm十大实用技巧(推荐)
Apr 27 Python
Python实现王者荣耀自动刷金币的完整步骤
Jan 22 Python
如何利用Matlab制作一款真正的拼图小游戏
May 11 Python
Python数据结构之队列详解
Mar 21 Python
python标识符命名规范原理解析
Jan 10 #Python
pytorch1.0中torch.nn.Conv2d用法详解
Jan 10 #Python
pytorch 利用lstm做mnist手写数字识别分类的实例
Jan 10 #Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
You might like
实现dedecms全站URL静态化改造的代码
2007/03/29 PHP
php实现多城市切换特效
2015/08/09 PHP
Ajax提交表单时验证码自动验证 php后端验证码检测
2016/07/20 PHP
PHP实现二维数组根据key进行排序的方法
2016/12/30 PHP
php+redis实现多台服务器内网存储session并读取示例
2017/01/12 PHP
Javascript 解疑
2009/11/11 Javascript
在Firefox下js select标签点击无法弹出
2014/03/06 Javascript
js使用onmousemove和onmouseout获取鼠标坐标的方法
2015/03/31 Javascript
jQuery实现自定义事件的方法
2015/04/17 Javascript
jQuery基于图层模仿五星星评价功能的方法
2015/05/07 Javascript
jQuery插件pagewalkthrough实现引导页效果
2015/07/05 Javascript
Webpack+Vue如何导入Jquery和Jquery的第三方插件
2017/02/20 Javascript
angular.js中解决跨域问题的三种方式
2017/07/12 Javascript
vue 虚拟dom的patch源码分析
2018/03/01 Javascript
一文看懂如何简单实现节流函数和防抖函数
2019/09/05 Javascript
es6函数之箭头函数用法实例详解
2020/04/25 Javascript
解决antd 下拉框 input [defaultValue] 的值的问题
2020/10/31 Javascript
微信小程序反编译的实现
2020/12/10 Javascript
Python中map和列表推导效率比较实例分析
2015/06/17 Python
Python通过paramiko远程下载Linux服务器上的文件实例
2018/12/27 Python
python实现截取屏幕保存文件,删除N天前截图的例子
2019/08/27 Python
详解Python Opencv和PIL读取图像文件的差别
2019/12/27 Python
使用keras和tensorflow保存为可部署的pb格式
2020/05/25 Python
python为什么会环境变量设置不成功
2020/06/23 Python
用Python制作mini翻译器的实现示例
2020/08/17 Python
如何用Matlab和Python读取Netcdf文件
2021/02/19 Python
车间班组长岗位职责
2013/11/13 职场文书
运动会广播稿30字
2014/01/21 职场文书
代理商会议邀请函
2014/01/27 职场文书
秋季运动会广播稿大全
2014/02/17 职场文书
毕业生面试求职信
2014/06/23 职场文书
入伍通知书
2015/04/23 职场文书
土木工程毕业答辩开场白
2015/05/29 职场文书
听课评课活动心得体会
2016/01/15 职场文书
matlab xlabel位置的设置方式
2021/05/21 Python
简单总结SpringMVC拦截器的使用方法
2021/06/28 Java/Android