pytorch 预训练模型读取修改相关参数的填坑问题


Posted in Python onJune 05, 2021

pytorch 预训练模型读取修改相关参数的填坑

修改部分层,仍然调用之前的模型参数。

resnet = resnet50(pretrained=False)
resnet.load_state_dict(torch.load(args.predir))
 
res_conv31 = Bottleneck_dilated(1024, 256,dilated_rate = 2)
print("---------------------",res_conv31)
print("---------------------",resnet.layer3[1])
 
res_conv31.load_state_dict(resnet.layer3[1].state_dict())

网络预训练模型与之前的模型对应不上,名称差个前缀

model_dict = model.state_dict()
# print(model_dict)
pretrained_dict = torch.load("/yzc/reid_testpcb/se_resnet50-ce0d4300.pth")
keys = []
for k, v in pretrained_dict.items():
       keys.append(k)
i = 0
for k, v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
         model_dict[k] = pretrained_dict[keys[i]]
         #print(model_dict[k])
         i = i + 1
model.load_state_dict(model_dict)

最后是修改参数名拿来用的,

from collections import OrderedDict
pretrained_dict = torch.load('premodel')
 
new_state_dict = OrderedDict()
 
# for k, v in mgn_state_dict.items():
#     name = k[7:]  # remove `module.`
#     new_state_dict[name] = v
# self.model = self.model.load_state_dict(new_state_dict)
 
for k, v in pretrained_dict.items():
    name = "model.module."+k   # remove `module.`
    # print(name)
    new_state_dict[name] = v
self.model.load_state_dict(new_state_dict)

pytorch:加载预训练模型中的部分参数,并固定该部分参数(真实有效)

大家在学习pytorch时,可能想利用pytorch进行fine-tune,但是又烦恼于参数的加载问题。下面我将讲诉我的使用心得。

Step1: 加载预训练模型,并去除需要再次训练的层

#注意:需要重新训练的层的名字要和之前的不同。
model=resnet()#自己构建的模型,以resnet为例
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

Step2:固定部分参数

#k是可训练参数的名字,v是包含可训练参数的一个实体
#可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中:
for k,v in model.named_parameters():
    if k!='xxx.weight' and k!='xxx.bias' :
        v.requires_grad=False#固定参数

Step3:训练部分参数

#将要训练的参数放入优化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

Step4:检查部分参数是否固定

debug之后,程序正常运行,最好检查一下网络的参数是否真的被固定了,如何没固定,网络的状态接近于重新训练,可能会导致网络性能不稳定,也没办法得到想要得到的性能提升。

for k,v in model.named_parameters():
   if k!='xxx.weight' and k!='xxx.bias' :
   print(v.requires_grad)#理想状态下,所有值都是False

需要注意的是,操作失误最大的影响是,loss函数几乎不会发生变化,一直处于最开始的状态,这很可能是因为所有参数都被固定了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 中的with关键字使用详解
Sep 11 Python
Python3基于sax解析xml操作示例
May 22 Python
python 按不同维度求和,最值,均值的实例
Jun 28 Python
python对list中的每个元素进行某种操作的方法
Jun 29 Python
手把手教你如何安装Pycharm(详细图文教程)
Nov 28 Python
python中使用ctypes调用so传参设置遇到的问题及解决方法
Jun 19 Python
Tensorflow Summary用法学习笔记
Jan 10 Python
python函数定义和调用过程详解
Feb 09 Python
pycharm无法安装第三方库的问题及解决方法以scrapy为例(图解)
May 09 Python
python实现图片,视频人脸识别(dlib版)
Nov 18 Python
python随机打印成绩排名表
Jun 23 Python
Python超详细分步解析随机漫步
Mar 17 Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
You might like
PHP 强制性文件下载功能的函数代码(任意文件格式)
2010/05/26 PHP
PHP获取浏览器信息类和客户端地理位置的2个方法
2014/04/24 PHP
ThinkPHP分页实例
2014/10/15 PHP
通过php添加xml文档内容的方法
2015/01/23 PHP
php unicode编码和字符串互转的方法
2020/08/12 PHP
基于php伪静态的实现方法解析
2020/07/31 PHP
网页里控制图片大小的相关代码
2006/06/25 Javascript
Prototype使用指南之dom.js
2007/01/10 Javascript
javascript EXCEL 操作类代码
2009/07/30 Javascript
Javascript常考语句107条收集
2010/03/09 Javascript
javascript XMLHttpRequest对象全面剖析
2010/04/24 Javascript
Javascript四舍五入Math.round()与Math.pow()使用介绍
2013/12/27 Javascript
屏蔽IE弹出"您查看的网页正在试图关闭窗口,是否关闭此窗口"的方法
2013/12/31 Javascript
jquery使用each方法遍历json格式数据实例
2015/05/18 Javascript
JS拖拽插件实现步骤
2015/08/03 Javascript
JavaScript tab选项卡插件实例代码
2016/02/23 Javascript
js中的关联数组与普通数组详解
2016/07/27 Javascript
jQuery实现微信长按识别二维码功能
2016/08/26 Javascript
jQuery实现简单复制json对象和json对象集合操作示例
2018/07/09 jQuery
小程序文字跑马灯效果
2018/12/28 Javascript
vue-cli 为项目设置别名的方法
2019/10/15 Javascript
VUEX 数据持久化,刷新后重新获取的例子
2019/11/12 Javascript
vue学习笔记之slot插槽基本用法实例分析
2020/02/01 Javascript
[36:14]DOTA2上海特级锦标赛D组小组赛#1 EG VS COL第二局
2016/02/28 DOTA
python端口扫描系统实现方法
2014/11/19 Python
Python3实现购物车功能
2018/04/18 Python
为什么str(float)在Python 3中比Python 2返回更多的数字
2018/10/16 Python
在Tensorflow中实现梯度下降法更新参数值
2020/01/23 Python
Python3交互式shell ipython3安装及使用详解
2020/07/11 Python
PyCharm安装PyQt5及其工具(Qt Designer、PyUIC、PyRcc)的步骤详解
2020/11/02 Python
Arti-shopping中文官网:大型海外商品一站式直邮平台
2020/03/23 全球购物
公关关系专员的自我评价分享
2013/11/20 职场文书
铲车司机岗位职责
2014/03/15 职场文书
全国爱牙日活动总结
2015/02/05 职场文书
浅谈mysql哪些情况会导致索引失效
2021/11/20 MySQL
TV动画《间谍过家家》公开PV
2022/03/20 日漫