pytorch 预训练模型读取修改相关参数的填坑问题


Posted in Python onJune 05, 2021

pytorch 预训练模型读取修改相关参数的填坑

修改部分层,仍然调用之前的模型参数。

resnet = resnet50(pretrained=False)
resnet.load_state_dict(torch.load(args.predir))
 
res_conv31 = Bottleneck_dilated(1024, 256,dilated_rate = 2)
print("---------------------",res_conv31)
print("---------------------",resnet.layer3[1])
 
res_conv31.load_state_dict(resnet.layer3[1].state_dict())

网络预训练模型与之前的模型对应不上,名称差个前缀

model_dict = model.state_dict()
# print(model_dict)
pretrained_dict = torch.load("/yzc/reid_testpcb/se_resnet50-ce0d4300.pth")
keys = []
for k, v in pretrained_dict.items():
       keys.append(k)
i = 0
for k, v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
         model_dict[k] = pretrained_dict[keys[i]]
         #print(model_dict[k])
         i = i + 1
model.load_state_dict(model_dict)

最后是修改参数名拿来用的,

from collections import OrderedDict
pretrained_dict = torch.load('premodel')
 
new_state_dict = OrderedDict()
 
# for k, v in mgn_state_dict.items():
#     name = k[7:]  # remove `module.`
#     new_state_dict[name] = v
# self.model = self.model.load_state_dict(new_state_dict)
 
for k, v in pretrained_dict.items():
    name = "model.module."+k   # remove `module.`
    # print(name)
    new_state_dict[name] = v
self.model.load_state_dict(new_state_dict)

pytorch:加载预训练模型中的部分参数,并固定该部分参数(真实有效)

大家在学习pytorch时,可能想利用pytorch进行fine-tune,但是又烦恼于参数的加载问题。下面我将讲诉我的使用心得。

Step1: 加载预训练模型,并去除需要再次训练的层

#注意:需要重新训练的层的名字要和之前的不同。
model=resnet()#自己构建的模型,以resnet为例
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

Step2:固定部分参数

#k是可训练参数的名字,v是包含可训练参数的一个实体
#可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中:
for k,v in model.named_parameters():
    if k!='xxx.weight' and k!='xxx.bias' :
        v.requires_grad=False#固定参数

Step3:训练部分参数

#将要训练的参数放入优化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

Step4:检查部分参数是否固定

debug之后,程序正常运行,最好检查一下网络的参数是否真的被固定了,如何没固定,网络的状态接近于重新训练,可能会导致网络性能不稳定,也没办法得到想要得到的性能提升。

for k,v in model.named_parameters():
   if k!='xxx.weight' and k!='xxx.bias' :
   print(v.requires_grad)#理想状态下,所有值都是False

需要注意的是,操作失误最大的影响是,loss函数几乎不会发生变化,一直处于最开始的状态,这很可能是因为所有参数都被固定了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python多线程编程(二):启动线程的两种方法
Apr 05 Python
Python实用日期时间处理方法汇总
May 09 Python
Python 功能和特点(新手必学)
Dec 30 Python
使用Python3制作TCP端口扫描器
Apr 17 Python
单链表反转python实现代码示例
Feb 08 Python
基于python3 OpenCV3实现静态图片人脸识别
May 25 Python
python取余运算符知识点详解
Jun 27 Python
Python上下文管理器类和上下文管理器装饰器contextmanager用法实例分析
Nov 07 Python
django数据模型(Model)的字段类型解析
Dec 25 Python
对tensorflow中的strides参数使用详解
Jan 04 Python
Python日期格式和字符串格式相互转换的方法
Feb 18 Python
Python 中如何使用 virtualenv 管理虚拟环境
Jan 21 Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
You might like
phpmyadmin里面导入sql语句格式的大量数据的方法
2010/06/05 PHP
ThinkPHP3.1之D方法实例详解
2014/06/20 PHP
Zend Framework教程之动作的基类Zend_Controller_Action详解
2016/03/07 PHP
PHP 数组基本操作方法详解
2016/06/17 PHP
php处理单文件、多文件上传代码分享
2016/08/24 PHP
JavaScipt基本教程之JavaScript语言的基础
2008/01/16 Javascript
在Ajax中使用Flash实现跨域数据读取的实现方法
2010/12/02 Javascript
ANT 压缩(去掉空格/注释)JS文件可提高js运行速度
2013/04/15 Javascript
JS判定是否原生方法
2013/07/22 Javascript
由点击页面其它地方隐藏div所想到的jQuery的delegate
2013/08/29 Javascript
js动态添加事件并可传参数示例代码
2013/10/21 Javascript
简单实现限制uploadify上传个数
2015/11/16 Javascript
vue中渐进过渡效果实现
2016/10/27 Javascript
JS实现的简单表单验证功能示例
2017/10/13 Javascript
webuploader实现上传图片到服务器功能
2018/08/16 Javascript
只有 20 行的 JavaScript 模板引擎实例详解
2020/05/11 Javascript
解决vue单页面多个组件嵌套监听浏览器窗口变化问题
2020/07/30 Javascript
python实现系统状态监测和故障转移实例方法
2013/11/18 Python
Python实现Linux命令xxd -i功能
2016/03/06 Python
tensorflow 中对数组元素的操作方法
2018/07/27 Python
详解python--模拟轮盘抽奖游戏
2019/04/12 Python
Python 实现数组相减示例
2019/12/27 Python
Django 5种类型Session使用方法解析
2020/04/29 Python
css3 旋转按钮 使用CSS3创建一个旋转可变色按钮
2012/12/31 HTML / CSS
CSS3 对过渡(transition)进行调速以及延时
2020/10/21 HTML / CSS
德国旅游网站:weg.de
2018/06/03 全球购物
农业资源与环境专业自荐信范文
2013/12/30 职场文书
初中三好学生事迹材料
2014/01/13 职场文书
深入开展党的群众路线教育实践活动方案
2014/02/04 职场文书
医学生临床实习自我评价
2014/03/07 职场文书
《毛主席在花山》教学反思
2014/04/20 职场文书
公司周年庆典标语
2014/10/07 职场文书
查摆问题整改措施范文
2014/10/11 职场文书
中学生逃课检讨书
2015/02/17 职场文书
JavaScript 语句之常用 for 循环详解
2021/03/29 Javascript
Python中的np.argmin()和np.argmax()函数用法
2021/06/02 Python