pytorch载入预训练模型后,实现训练指定层


Posted in Python onJanuary 06, 2020

1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:

pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)

strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。

2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  if name 满足某些条件:
    value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  print(name)
# 或
print(model.state_dict().keys())

假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
               {'params':model.decoder.parameters()}
               ],
               lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。

在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有params和lr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。

以上这篇pytorch载入预训练模型后,实现训练指定层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

参考:

Python 相关文章推荐
学习python处理python编码问题
Mar 13 Python
python实现神经网络感知器算法
Dec 20 Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 Python
在Python中分别打印列表中的每一个元素方法
Nov 07 Python
Python实现字符型图片验证码识别完整过程详解
May 10 Python
python+tkinter实现学生管理系统
Aug 20 Python
Django 框架模型操作入门教程
Nov 05 Python
详解Anconda环境下载python包的教程(图形界面+命令行+pycharm安装)
Nov 11 Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 Python
详解python变量与数据类型
Aug 25 Python
python爬虫筛选工作实例讲解
Nov 23 Python
利用Python批量识别电子账单数据的方法
Feb 08 Python
python与mysql数据库交互的实现
Jan 06 #Python
win10系统下python3安装及pip换源和使用教程
Jan 06 #Python
基于python实现文件加密功能
Jan 06 #Python
Pytorch 实现冻结指定卷积层的参数
Jan 06 #Python
如何使用python实现模拟鼠标点击
Jan 06 #Python
pytorch 实现查看网络中的参数
Jan 06 #Python
Python3 虚拟开发环境搭建过程(图文详解)
Jan 06 #Python
You might like
php访问数组最后一个元素的函数end()用法
2015/03/18 PHP
PHP基于面向对象封装的分页类示例
2019/03/15 PHP
PHP Trait功能与用法实例分析
2020/06/03 PHP
javascript 一个函数对同一元素的多个事件响应
2009/07/25 Javascript
javascript 多种搜索引擎集成的页面实现代码
2010/01/02 Javascript
基于jQuery实现的Ajax 验证用户名是否存在的实现代码
2011/04/06 Javascript
js实现的标题栏新消息闪烁提示效果
2014/06/06 Javascript
基于socket.io+express实现多房间聊天
2016/03/17 Javascript
js/jquery控制页面动态加载数据 滑动滚动条自动加载事件的方法
2017/02/08 Javascript
JavaScript中清空数组的三种方式
2017/03/22 Javascript
JS实现线性表的顺序表示方法示例【经典数据结构】
2017/04/11 Javascript
koa socket即时通讯的示例代码
2018/09/07 Javascript
微信小程序日历效果
2018/12/29 Javascript
vue路由--网站导航功能详解
2019/03/29 Javascript
vue-router源码之history类的浅析
2019/05/21 Javascript
使用Vue.observable()进行状态管理的实例代码详解
2019/05/26 Javascript
python分割和拼接字符串
2013/11/01 Python
Python 3.x 连接数据库示例(pymysql 方式)
2017/01/19 Python
sublime text 3配置使用python操作方法
2017/06/11 Python
浅谈flask截获所有访问及before/after_request修饰器
2018/01/18 Python
Selenium 模拟浏览器动态加载页面的实现方法
2018/05/16 Python
python实现抖音点赞功能
2019/04/07 Python
python按键按住不放持续响应的实例代码
2019/07/17 Python
Python实现socket非阻塞通讯功能示例
2019/11/06 Python
Python 字节流,字符串,十六进制相互转换实例(binascii,bytes)
2020/05/11 Python
Python-openCV开运算实例
2020/07/05 Python
Python多分支if语句的使用
2020/09/03 Python
HTML5中的postMessage API基本使用教程
2016/05/20 HTML / CSS
英国在线滑雪板和冲浪商店:The Board Basement
2020/01/11 全球购物
测试驱动开发的主要步骤是什么
2014/12/10 面试题
社区工作者先进事迹
2014/01/18 职场文书
计算机科学系职业生涯规划书
2014/03/08 职场文书
企业承诺书怎么写
2014/05/24 职场文书
信访维稳工作汇报
2014/10/27 职场文书
2016公司新年问候语
2015/11/11 职场文书
找规律教学反思
2016/02/23 职场文书