pytorch如何冻结某层参数的实现


Posted in Python onJanuary 10, 2020

在迁移学习finetune时我们通常需要冻结前几层的参数不参与训练,在Pytorch中的实现如下:

class Model(nn.Module):
 def __init__(self):
  super(Transfer_model, self).__init__()
  self.linear1 = nn.Linear(20, 50)
  self.linear2 = nn.Linear(50, 20)
  self.linear3 = nn.Linear(20, 2)

 def forward(self, x):
 pass

假如我们想要冻结linear1层,需要做如下操作:

model = Model()
# 这里是一般情况,共享层往往不止一层,所以做一个for循环
for para in model.linear1.parameters():
 para.requires_grad = False
# 假如真的只有一层也可以这样操作:
# model.linear1.weight.requires_grad = False

 最后我们需要将需要优化的参数传入优化器,不需要传入的参数过滤掉,所以要用到filter()函数。

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)

其它的博客中都没有讲解filter()函数的作用,在这里我简单讲一下有助于更好的理解。

filter(function, iterable)

  • function: 判断函数
  • iterable: 可迭代对象

filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回一个迭代器对象,如果要转换为列表,可以使用 list() 来转换。

该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。

filter()函数将requires_grad = True的参数传入优化器进行反向传播,requires_grad = False的则被过滤掉。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现栈的方法
May 26 Python
Python的Django框架中的URL配置与松耦合
Jul 15 Python
python编程实现12306的一个小爬虫实例
Dec 27 Python
Python线程创建和终止实例代码
Jan 20 Python
Python解析并读取PDF文件内容的方法
May 08 Python
异步任务队列Celery在Django中的使用方法
Jun 07 Python
python版本单链表实现代码
Sep 28 Python
python消除序列的重复值并保持顺序不变的实例
Nov 08 Python
Django跨域请求CSRF的方法示例
Nov 11 Python
解决python中显示图片的plt.imshow plt.show()内存泄漏问题
Apr 24 Python
Python urllib2运行过程原理解析
Jun 04 Python
python爬取企查查企业信息之selenium自动模拟登录企查查
Apr 08 Python
python标识符命名规范原理解析
Jan 10 #Python
pytorch1.0中torch.nn.Conv2d用法详解
Jan 10 #Python
pytorch 利用lstm做mnist手写数字识别分类的实例
Jan 10 #Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
You might like
php编程实现获取excel文档内容的代码实例
2011/06/28 PHP
PHP中生成UUID自定义函数分享
2015/06/10 PHP
php原生导出excel文件的两种方法(推荐)
2016/11/19 PHP
对laravel的csrf 防御机制详解,及form中csrf_token()的存在介绍
2019/10/24 PHP
javascript String 对象
2008/04/25 Javascript
对JavaScript中this指针的新理解分享
2015/01/31 Javascript
jQuery实现html表格动态添加新行的方法
2015/05/28 Javascript
JS实现的网页背景闪电闪烁效果代码
2015/10/17 Javascript
微信小程序 选择器(时间,日期,地区)实例详解
2016/11/16 Javascript
浅谈MUI框架中加载外部网页或服务器数据的方法
2018/01/31 Javascript
快速了解vue-cli 3.0 新特性
2018/02/28 Javascript
微信小程序之判断页面滚动方向的示例代码
2018/08/30 Javascript
Node.js模拟发起http请求从异步转同步的5种用法
2018/09/26 Javascript
AngularJs1.x自定义指令独立作用域的函数传入参数方法
2018/10/09 Javascript
微信小程序外卖选购页实现切换分类与数量加减功能案例
2019/01/15 Javascript
Android 自定义view仿微信相机单击拍照长按录视频按钮
2019/07/19 Javascript
JavaScript实现网页计算器功能
2020/10/29 Javascript
详解python之简单主机批量管理工具
2017/01/27 Python
Python数据集切分实例
2018/12/08 Python
python区分不同数据类型的方法
2019/10/14 Python
Python程序暂停的正常处理方法
2019/11/07 Python
keras自定义损失函数并且模型加载的写法介绍
2020/06/15 Python
解决Keras 自定义层时遇到版本的问题
2020/06/16 Python
python 使用OpenCV进行简单的人像分割与合成
2021/02/02 Python
你应该知道的30个css选择器
2014/03/19 HTML / CSS
英国最大的汽车交易网站:Auto Trader UK
2016/09/23 全球购物
The North Face北面德国官网:美国著名户外品牌
2018/12/12 全球购物
有趣的流行文化T恤、马克杯、手机壳和更多:Look Human
2019/01/07 全球购物
Sarenza德国:法国最大的时尚鞋和包包网上商店
2019/06/08 全球购物
JAVA软件工程师测试题
2014/07/25 面试题
秋季运动会广播稿大全
2014/02/17 职场文书
委托书的格式
2014/08/01 职场文书
2014向国旗敬礼网上签名活动总结
2014/09/27 职场文书
结婚通知短信怎么写
2015/04/17 职场文书
python 算法题——快乐数的多种解法
2021/05/27 Python
DjangoRestFramework 使用 simpleJWT 登陆认证完整记录
2021/06/22 Python