pytorch如何冻结某层参数的实现


Posted in Python onJanuary 10, 2020

在迁移学习finetune时我们通常需要冻结前几层的参数不参与训练,在Pytorch中的实现如下:

class Model(nn.Module):
 def __init__(self):
  super(Transfer_model, self).__init__()
  self.linear1 = nn.Linear(20, 50)
  self.linear2 = nn.Linear(50, 20)
  self.linear3 = nn.Linear(20, 2)

 def forward(self, x):
 pass

假如我们想要冻结linear1层,需要做如下操作:

model = Model()
# 这里是一般情况,共享层往往不止一层,所以做一个for循环
for para in model.linear1.parameters():
 para.requires_grad = False
# 假如真的只有一层也可以这样操作:
# model.linear1.weight.requires_grad = False

 最后我们需要将需要优化的参数传入优化器,不需要传入的参数过滤掉,所以要用到filter()函数。

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)

其它的博客中都没有讲解filter()函数的作用,在这里我简单讲一下有助于更好的理解。

filter(function, iterable)

  • function: 判断函数
  • iterable: 可迭代对象

filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回一个迭代器对象,如果要转换为列表,可以使用 list() 来转换。

该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。

filter()函数将requires_grad = True的参数传入优化器进行反向传播,requires_grad = False的则被过滤掉。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中实现字符串类型与字典类型相互转换的方法
Aug 18 Python
python中类的一些方法分析
Sep 25 Python
Python中最常用的操作列表的几种方法归纳
Apr 24 Python
centos6.4下python3.6.1安装教程
Jul 21 Python
python里运用私有属性和方法总结
Jul 08 Python
Django后端接收嵌套Json数据及解析详解
Jul 17 Python
对Django外键关系的描述
Jul 26 Python
python协程gevent案例 爬取斗鱼图片过程解析
Aug 27 Python
用pytorch的nn.Module构造简单全链接层实例
Jan 14 Python
解决Python命令行下退格,删除,方向键乱码(亲测有效)
Jan 16 Python
python json 递归打印所有json子节点信息的例子
Feb 27 Python
Python气泡提示与标签的实现
Apr 01 Python
python标识符命名规范原理解析
Jan 10 #Python
pytorch1.0中torch.nn.Conv2d用法详解
Jan 10 #Python
pytorch 利用lstm做mnist手写数字识别分类的实例
Jan 10 #Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
You might like
session在php5.3中的变化 session_is_registered() is deprecated in
2013/11/12 PHP
PHP反射基础知识回顾
2020/09/10 PHP
prototype Element学习笔记(篇一)
2008/10/26 Javascript
ExtJS 2.0实用简明教程 之ExtJS版的Hello
2009/04/29 Javascript
jQuery LigerUI 插件介绍及使用之ligerDrag和ligerResizable示例代码打包
2011/04/06 Javascript
关于div自适应高度/左右高度自适应一致的js代码
2013/03/22 Javascript
Javascript页面添加到收藏夹的简单方法
2013/08/07 Javascript
javascript连续赋值问题
2015/07/08 Javascript
浅谈JSON.parse()和JSON.stringify()
2015/07/14 Javascript
JavaScript判断手机号运营商是移动、联通、电信还是其他(代码简单)
2015/09/25 Javascript
Bootstrap中CSS的使用方法
2016/02/17 Javascript
关于webuploader插件使用过程遇到的小问题
2016/11/07 Javascript
基于Bootstrap 3 JQuery及RegExp的表单验证功能
2017/02/16 Javascript
vue2.0 与 bootstrap datetimepicker的结合使用实例
2017/05/22 Javascript
JS时间控制实现动态效果的实例讲解
2017/07/31 Javascript
Vue中引入样式文件的方法
2017/08/18 Javascript
JS实现去除数组中重复json的方法示例
2017/12/21 Javascript
vue-router配合ElementUI实现导航的实例
2018/02/11 Javascript
Vue render渲染时间戳转时间,时间转时间戳及渲染进度条效果
2018/07/27 Javascript
让 babel webpack vue 配置文件支持智能提示的方法
2019/06/22 Javascript
JavaScript JSON数据处理全集(小结)
2019/08/15 Javascript
vue请求服务器数据后绑定不上的解决方法
2019/10/30 Javascript
通过实例了解Nodejs模块系统及require机制
2020/07/16 NodeJs
优化Python代码使其加快作用域内的查找
2015/03/30 Python
Python配置文件解析模块ConfigParser使用实例
2015/04/13 Python
Python中二维列表如何获取子区域元素的组成
2017/01/19 Python
PyTorch的torch.cat用法
2020/06/28 Python
Python在字符串中处理html和xml的方法
2020/07/31 Python
Python爬虫使用bs4方法实现数据解析
2020/08/25 Python
亚马逊加拿大网站:Amazon.ca
2020/01/06 全球购物
完美实现CSS垂直居中的11种方法
2021/03/27 HTML / CSS
中国入世承诺
2014/04/01 职场文书
体育个人工作总结
2015/02/09 职场文书
化工厂员工工作总结
2015/10/15 职场文书
实例详解Python的进程,线程和协程
2022/03/13 Python
jdbc中自带MySQL 连接池实践示例
2022/07/23 MySQL