pytorch如何冻结某层参数的实现


Posted in Python onJanuary 10, 2020

在迁移学习finetune时我们通常需要冻结前几层的参数不参与训练,在Pytorch中的实现如下:

class Model(nn.Module):
 def __init__(self):
  super(Transfer_model, self).__init__()
  self.linear1 = nn.Linear(20, 50)
  self.linear2 = nn.Linear(50, 20)
  self.linear3 = nn.Linear(20, 2)

 def forward(self, x):
 pass

假如我们想要冻结linear1层,需要做如下操作:

model = Model()
# 这里是一般情况,共享层往往不止一层,所以做一个for循环
for para in model.linear1.parameters():
 para.requires_grad = False
# 假如真的只有一层也可以这样操作:
# model.linear1.weight.requires_grad = False

 最后我们需要将需要优化的参数传入优化器,不需要传入的参数过滤掉,所以要用到filter()函数。

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)

其它的博客中都没有讲解filter()函数的作用,在这里我简单讲一下有助于更好的理解。

filter(function, iterable)

  • function: 判断函数
  • iterable: 可迭代对象

filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回一个迭代器对象,如果要转换为列表,可以使用 list() 来转换。

该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。

filter()函数将requires_grad = True的参数传入优化器进行反向传播,requires_grad = False的则被过滤掉。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Windows下Python的Django框架环境部署及应用编写入门
Mar 10 Python
python与php实现分割文件代码
Mar 06 Python
Python3实现简单可学习的手写体识别(实例讲解)
Oct 21 Python
Python实现中一次读取多个值的方法
Apr 22 Python
tensorflow实现简单的卷积神经网络
May 24 Python
PYQT5设置textEdit自动滚屏的方法
Jun 14 Python
详解python解压压缩包的五种方法
Jul 05 Python
利用python list完成最简单的DB连接池方法
Aug 09 Python
如何利用Python开发一个简单的猜数字游戏
Sep 22 Python
matplotlib 对坐标的控制,加图例注释的操作
Apr 17 Python
PyCharm MySQL可视化Database配置过程图解
Jun 09 Python
keras实现theano和tensorflow训练的模型相互转换
Jun 19 Python
python标识符命名规范原理解析
Jan 10 #Python
pytorch1.0中torch.nn.Conv2d用法详解
Jan 10 #Python
pytorch 利用lstm做mnist手写数字识别分类的实例
Jan 10 #Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
You might like
基于Snoopy的PHP近似完美获取网站编码的代码
2011/10/23 PHP
PHP Class&Object -- PHP 自排序二叉树的深入解析
2013/06/25 PHP
合格的PHP程序员必备技能
2015/11/13 PHP
简单谈谈php延迟静态绑定
2016/01/26 PHP
php基于dom实现的图书xml格式数据示例
2017/02/03 PHP
索趣科技的答案
2007/02/07 Javascript
优秀js开源框架-jQuery使用手册(1)
2007/03/10 Javascript
javascript 关闭IE6、IE7
2009/06/01 Javascript
jquery+php实现滚动的数字特效
2015/11/29 Javascript
AngularJS实现表单手动验证和表单自动验证
2015/12/09 Javascript
AngularJS入门教程之Scope(作用域)
2016/07/27 Javascript
jQuery simpleModal插件的使用介绍
2016/08/30 Javascript
jQuery实现的图片轮播效果完整示例
2016/09/12 Javascript
基于JS组件实现拖动滑块验证功能(代码分享)
2016/11/18 Javascript
vue 2.0组件与v-model详解
2017/03/27 Javascript
Linux CentOS系统下安装node.js与express的方法
2017/04/01 Javascript
vue-cli项目根据线上环境分别打出测试包和生产包
2018/05/23 Javascript
Vue 框架之键盘事件、健值修饰符、双向数据绑定
2018/11/14 Javascript
解决微信小程序调用moveToLocation失效问题【超简单】
2019/04/12 Javascript
[04:26]2014DOTA2国际邀请赛-Newbee顺利进入胜者组决赛 独家专访战神7
2014/07/19 DOTA
Python批量转换文件编码格式
2015/05/17 Python
python命令行解析之parse_known_args()函数和parse_args()使用区别介绍
2018/01/24 Python
Python使用re模块正则提取字符串中括号内的内容示例
2018/06/01 Python
Python3.4 splinter(模拟填写表单)使用方法
2018/10/13 Python
Python opencv实现人眼/人脸识别以及实时打码处理
2019/04/29 Python
python对csv文件追加写入列的方法
2019/08/01 Python
python解析yaml文件过程详解
2019/08/30 Python
打包PyQt5应用时的注意事项
2020/02/14 Python
英国最大的经认证的有机超市:Planet Organic
2018/02/02 全球购物
Pam & Gela官网:美国性感前卫女装品牌
2018/07/19 全球购物
《可爱的动物》教学反思
2014/02/22 职场文书
超市活动计划书
2014/04/24 职场文书
广告宣传策划方案
2014/05/21 职场文书
夫妻忠诚协议范文
2014/11/16 职场文书
物业接待员岗位职责
2015/04/15 职场文书
Vue组件化(ref,props, mixin,.插件)详解
2022/05/15 Vue.js