pytorch 预训练模型读取修改相关参数的填坑问题


Posted in Python onJune 05, 2021

pytorch 预训练模型读取修改相关参数的填坑

修改部分层,仍然调用之前的模型参数。

resnet = resnet50(pretrained=False)
resnet.load_state_dict(torch.load(args.predir))
 
res_conv31 = Bottleneck_dilated(1024, 256,dilated_rate = 2)
print("---------------------",res_conv31)
print("---------------------",resnet.layer3[1])
 
res_conv31.load_state_dict(resnet.layer3[1].state_dict())

网络预训练模型与之前的模型对应不上,名称差个前缀

model_dict = model.state_dict()
# print(model_dict)
pretrained_dict = torch.load("/yzc/reid_testpcb/se_resnet50-ce0d4300.pth")
keys = []
for k, v in pretrained_dict.items():
       keys.append(k)
i = 0
for k, v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
         model_dict[k] = pretrained_dict[keys[i]]
         #print(model_dict[k])
         i = i + 1
model.load_state_dict(model_dict)

最后是修改参数名拿来用的,

from collections import OrderedDict
pretrained_dict = torch.load('premodel')
 
new_state_dict = OrderedDict()
 
# for k, v in mgn_state_dict.items():
#     name = k[7:]  # remove `module.`
#     new_state_dict[name] = v
# self.model = self.model.load_state_dict(new_state_dict)
 
for k, v in pretrained_dict.items():
    name = "model.module."+k   # remove `module.`
    # print(name)
    new_state_dict[name] = v
self.model.load_state_dict(new_state_dict)

pytorch:加载预训练模型中的部分参数,并固定该部分参数(真实有效)

大家在学习pytorch时,可能想利用pytorch进行fine-tune,但是又烦恼于参数的加载问题。下面我将讲诉我的使用心得。

Step1: 加载预训练模型,并去除需要再次训练的层

#注意:需要重新训练的层的名字要和之前的不同。
model=resnet()#自己构建的模型,以resnet为例
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

Step2:固定部分参数

#k是可训练参数的名字,v是包含可训练参数的一个实体
#可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中:
for k,v in model.named_parameters():
    if k!='xxx.weight' and k!='xxx.bias' :
        v.requires_grad=False#固定参数

Step3:训练部分参数

#将要训练的参数放入优化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

Step4:检查部分参数是否固定

debug之后,程序正常运行,最好检查一下网络的参数是否真的被固定了,如何没固定,网络的状态接近于重新训练,可能会导致网络性能不稳定,也没办法得到想要得到的性能提升。

for k,v in model.named_parameters():
   if k!='xxx.weight' and k!='xxx.bias' :
   print(v.requires_grad)#理想状态下,所有值都是False

需要注意的是,操作失误最大的影响是,loss函数几乎不会发生变化,一直处于最开始的状态,这很可能是因为所有参数都被固定了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用xlrd模块操作Excel数据导入的方法
May 26 Python
用Python的Flask框架结合MySQL写一个内存监控程序
Nov 07 Python
Python实现判断字符串中包含某个字符的判断函数示例
Jan 08 Python
python实现壁纸批量下载代码实例
Jan 25 Python
python+pandas生成指定日期和重采样的方法
Apr 11 Python
win10下python3.5.2和tensorflow安装环境搭建教程
Sep 19 Python
对python中的iter()函数与next()函数详解
Oct 18 Python
Python numpy.array()生成相同元素数组的示例
Nov 12 Python
一文秒懂python读写csv xml json文件各种骚操作
Jul 04 Python
选择Python写网络爬虫的优势和理由
Jul 07 Python
调用其他python脚本文件里面的类和方法过程解析
Nov 15 Python
如何用Python徒手写线性回归
Jan 25 Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
You might like
十天学会php之第八天
2006/10/09 PHP
PHP If Else(elsefi) 语句
2013/04/07 PHP
js的闭包的一个示例说明
2008/11/18 Javascript
JavaScript 动态添加表格行 使用模板、标记
2009/10/24 Javascript
将CKfinder整合进CKEditor3.0的新方法
2010/01/10 Javascript
jQuery 添加/移除CSS类实现代码
2010/02/11 Javascript
读jQuery之一(对象的组成)
2011/06/11 Javascript
IE、FF、Chrome浏览器中的JS差异介绍
2013/08/13 Javascript
javascript操作表格排序实例分析
2015/05/06 Javascript
JS/Jquery判断对象为空的方法
2015/06/11 Javascript
javascript运动详解
2015/07/06 Javascript
基于Jquery代码实现支持PC端手机端幻灯片代码
2015/11/17 Javascript
jQuery实现二级下拉菜单效果
2016/01/05 Javascript
jQuery 3.0 的变化及使用方法
2016/02/01 Javascript
AngularJS 2.0入门权威指南
2016/10/08 Javascript
addeventlistener监听scroll跟touch(实例讲解)
2017/08/04 Javascript
Vue Router的懒加载路径的解决方法
2018/06/21 Javascript
详解JavaScript中的坐标和距离
2019/05/27 Javascript
layUI使用layer.open,在content打开数据表格,获取值并返回的方法
2019/09/26 Javascript
使用pkg打包ThinkJS项目的方法步骤
2019/12/30 Javascript
python安装numpy&安装matplotlib& scipy的教程
2017/11/02 Python
python线程池(threadpool)模块使用笔记详解
2017/11/17 Python
Sanic框架配置操作分析
2018/07/17 Python
python pytest进阶之xunit fixture详解
2019/06/27 Python
通过 Python 和 OpenCV 实现目标数量监控
2020/01/05 Python
python爬取招聘要求等信息实例
2020/11/20 Python
Css3实现无缝滚动防抖
2020/09/14 HTML / CSS
HTML5 文件域+FileReader 分段读取文件并上传到服务器
2017/10/23 HTML / CSS
Onzie官网:美国时尚瑜伽品牌
2019/08/21 全球购物
英国在线玫瑰专家:InterRose
2019/12/01 全球购物
介绍一下你对SOA的认识
2016/04/24 面试题
实习自荐信
2013/10/13 职场文书
小孩百日宴答谢词
2014/01/15 职场文书
幼儿园校园小喇叭广播稿
2014/10/17 职场文书
工作感想范文
2015/08/07 职场文书
用Python生成会跳舞的美女
2022/01/18 Python