pytorch如何冻结某层参数的实现


Posted in Python onJanuary 10, 2020

在迁移学习finetune时我们通常需要冻结前几层的参数不参与训练,在Pytorch中的实现如下:

class Model(nn.Module):
 def __init__(self):
  super(Transfer_model, self).__init__()
  self.linear1 = nn.Linear(20, 50)
  self.linear2 = nn.Linear(50, 20)
  self.linear3 = nn.Linear(20, 2)

 def forward(self, x):
 pass

假如我们想要冻结linear1层,需要做如下操作:

model = Model()
# 这里是一般情况,共享层往往不止一层,所以做一个for循环
for para in model.linear1.parameters():
 para.requires_grad = False
# 假如真的只有一层也可以这样操作:
# model.linear1.weight.requires_grad = False

 最后我们需要将需要优化的参数传入优化器,不需要传入的参数过滤掉,所以要用到filter()函数。

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)

其它的博客中都没有讲解filter()函数的作用,在这里我简单讲一下有助于更好的理解。

filter(function, iterable)

  • function: 判断函数
  • iterable: 可迭代对象

filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回一个迭代器对象,如果要转换为列表,可以使用 list() 来转换。

该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。

filter()函数将requires_grad = True的参数传入优化器进行反向传播,requires_grad = False的则被过滤掉。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现多线程采集的2个代码例子
Jul 07 Python
Python的Flask框架中Flask-Admin库的简单入门指引
Apr 07 Python
浅谈python为什么不需要三目运算符和switch
Jun 17 Python
Python序列循环移位的3种方法推荐
Apr 09 Python
selenium+python自动化测试环境搭建步骤
Jun 03 Python
详解python中的time和datetime的常用方法
Jul 08 Python
python解析多层json操作示例
Dec 30 Python
如何在Python 游戏中模拟引力
Mar 27 Python
Python类中的装饰器在当前类中的声明与调用详解
Apr 15 Python
Python-jenkins 获取job构建信息方式
May 12 Python
Selenium alert 弹窗处理的示例代码
Aug 06 Python
Python 把两层列表展开平铺成一层(5种实现方式)
Apr 07 Python
python标识符命名规范原理解析
Jan 10 #Python
pytorch1.0中torch.nn.Conv2d用法详解
Jan 10 #Python
pytorch 利用lstm做mnist手写数字识别分类的实例
Jan 10 #Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
You might like
PHP配置心得包含MYSQL5乱码解决
2006/11/20 PHP
对象失去焦点时自己动提交数据的实现代码
2012/11/06 PHP
PHP实现支持加盐的图片加密解密
2016/09/09 PHP
简单谈谈PHP中的trait
2017/02/25 PHP
Laravel接收前端ajax传来的数据的实例代码
2017/07/20 PHP
PHP生成推广海报的方法分享
2018/04/22 PHP
详解php协程知识点
2018/09/21 PHP
二级域名转向类
2006/11/09 Javascript
jquery异步调用页面后台方法‏(asp.net)
2011/03/01 Javascript
分享js粘帖屏幕截图到web页面插件screenshot-paste
2020/08/21 Javascript
jQuery Ajax页面局部加载方法汇总
2016/06/02 Javascript
微信小程序 数组中的push与concat的区别
2017/01/05 Javascript
微信小程序新增的拖动组件movable-view使用教程
2017/05/20 Javascript
React Native之ListView实现九宫格效果的示例
2017/08/02 Javascript
JS中appendChild追加子节点无效的解决方法
2018/10/14 Javascript
js实现多个倒计时并行 js拼团倒计时
2019/02/25 Javascript
一些手写JavaScript常用的函数汇总
2019/04/16 Javascript
Python SQLite3数据库操作类分享
2014/06/10 Python
Python中的匿名函数使用简介
2015/04/27 Python
Python3.6正式版新特性预览
2016/12/15 Python
Flask数据库迁移简单介绍
2017/10/24 Python
Python 私有化操作实例分析
2019/11/21 Python
python找出列表中大于某个阈值的数据段示例
2019/11/24 Python
用HTML5实现手机摇一摇的功能的教程
2012/10/30 HTML / CSS
西班牙手机之家:Phone House
2018/10/18 全球购物
怀旧香味蜡烛:Homesick
2019/11/02 全球购物
27个经典Linux面试题及答案,你知道几个?
2014/03/11 面试题
介绍JAVA 中的Collection FrameWork(及如何写自己的数据结构)
2014/10/31 面试题
护士岗位职责
2014/02/16 职场文书
春节联欢会主持词
2014/03/24 职场文书
八项规定对照检查材料
2014/08/31 职场文书
经验交流材料格式
2014/12/30 职场文书
MySQL创建管理子分区
2022/04/13 MySQL
java高级用法JNA强大的Memory和Pointer
2022/04/19 Java/Android
Python开发五子棋小游戏
2022/04/28 Python
Fluentd搭建日志收集服务
2022/09/23 Servers