pytorch如何冻结某层参数的实现


Posted in Python onJanuary 10, 2020

在迁移学习finetune时我们通常需要冻结前几层的参数不参与训练,在Pytorch中的实现如下:

class Model(nn.Module):
 def __init__(self):
  super(Transfer_model, self).__init__()
  self.linear1 = nn.Linear(20, 50)
  self.linear2 = nn.Linear(50, 20)
  self.linear3 = nn.Linear(20, 2)

 def forward(self, x):
 pass

假如我们想要冻结linear1层,需要做如下操作:

model = Model()
# 这里是一般情况,共享层往往不止一层,所以做一个for循环
for para in model.linear1.parameters():
 para.requires_grad = False
# 假如真的只有一层也可以这样操作:
# model.linear1.weight.requires_grad = False

 最后我们需要将需要优化的参数传入优化器,不需要传入的参数过滤掉,所以要用到filter()函数。

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)

其它的博客中都没有讲解filter()函数的作用,在这里我简单讲一下有助于更好的理解。

filter(function, iterable)

  • function: 判断函数
  • iterable: 可迭代对象

filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回一个迭代器对象,如果要转换为列表,可以使用 list() 来转换。

该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。

filter()函数将requires_grad = True的参数传入优化器进行反向传播,requires_grad = False的则被过滤掉。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的一只从百度开始不断搜索的小爬虫
Aug 13 Python
Python实现的批量下载RFC文档
Mar 10 Python
python字典操作实例详解
Nov 16 Python
python3+PyQt5+Qt Designer实现扩展对话框
Apr 20 Python
Python递归函数实例讲解
Feb 27 Python
python批量修改图片尺寸,并保存指定路径的实现方法
Jul 04 Python
PowerBI和Python关于数据分析的对比
Jul 11 Python
使用python写一个自动浏览文章的脚本实例
Dec 05 Python
python+selenium 脚本实现每天自动登记的思路详解
Mar 11 Python
Python Scrapy框架:通用爬虫之CrawlSpider用法简单示例
Apr 11 Python
利用Python实现字幕挂载(把字幕文件与视频合并)思路详解
Oct 21 Python
利用python为PostgreSQL的表自动添加分区
Jan 18 Python
python标识符命名规范原理解析
Jan 10 #Python
pytorch1.0中torch.nn.Conv2d用法详解
Jan 10 #Python
pytorch 利用lstm做mnist手写数字识别分类的实例
Jan 10 #Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
You might like
PHP的serialize序列化数据以及JSON格式化数据分析
2015/10/10 PHP
分享PHP-pcntl 实现多进程代码
2016/09/30 PHP
PHP操作路由器实现方法示例
2019/04/27 PHP
javaScript - 如何引入js代码
2021/03/09 Javascript
JavaScript 继承使用分析
2011/05/12 Javascript
jQuery源码分析-01总体架构分析
2011/11/14 Javascript
jquery ajax jsonp跨域调用实例代码
2013/12/11 Javascript
不使用jquery实现js打字效果示例分享
2014/01/19 Javascript
纯Javascript实现ping功能的方法
2015/03/20 Javascript
基于JavaScript实现Json数据根据某个字段进行排序
2015/11/24 Javascript
在WordPress中加入Google搜索功能的简单步骤讲解
2016/01/04 Javascript
jQuery技巧之让任何组件都支持类似DOM的事件管理
2016/04/05 Javascript
快速使用Bootstrap搭建传送带
2016/05/06 Javascript
JS 事件绑定、事件监听、事件委托详细介绍
2016/09/28 Javascript
Bootstrap Table使用心得总结
2016/11/29 Javascript
js实现时间轴自动排列效果
2017/03/09 Javascript
angularjs实现首页轮播图效果
2017/04/14 Javascript
socket.io学习教程之基础介绍(一)
2017/04/29 Javascript
Spring boot 和Vue开发中CORS跨域问题解决
2018/09/05 Javascript
vue App.vue中的公共组件改变值触发其他组件或.vue页面监听
2019/05/31 Javascript
解决Layui中layer报错的问题
2019/09/03 Javascript
在Python中利用Into包整洁地进行数据迁移的教程
2015/03/30 Python
Python中用altzone()方法处理时区的教程
2015/05/22 Python
Python函数参数操作详解
2018/08/03 Python
python虚拟环境完美部署教程
2019/08/06 Python
Python基于numpy模块实现回归预测
2020/05/14 Python
一款纯css3实现的圆形旋转分享按钮旋转角度可自己调整
2014/09/02 HTML / CSS
Linux内核的同步机制是什么?主要有哪几种内核锁
2016/07/11 面试题
如何用Java实现列出某个目录下的所有子目录
2015/07/20 面试题
火锅店创业计划书范文
2014/02/02 职场文书
大学生个人实习的自我评价
2014/02/15 职场文书
南湾猴岛导游词
2015/02/09 职场文书
大学生英文求职信范文
2015/03/19 职场文书
运动会加油稿
2015/07/22 职场文书
爱护公物主题班会
2015/08/17 职场文书
使用Redis实现实时排行榜功能
2021/07/02 Redis