pytorch 预训练模型读取修改相关参数的填坑问题


Posted in Python onJune 05, 2021

pytorch 预训练模型读取修改相关参数的填坑

修改部分层,仍然调用之前的模型参数。

resnet = resnet50(pretrained=False)
resnet.load_state_dict(torch.load(args.predir))
 
res_conv31 = Bottleneck_dilated(1024, 256,dilated_rate = 2)
print("---------------------",res_conv31)
print("---------------------",resnet.layer3[1])
 
res_conv31.load_state_dict(resnet.layer3[1].state_dict())

网络预训练模型与之前的模型对应不上,名称差个前缀

model_dict = model.state_dict()
# print(model_dict)
pretrained_dict = torch.load("/yzc/reid_testpcb/se_resnet50-ce0d4300.pth")
keys = []
for k, v in pretrained_dict.items():
       keys.append(k)
i = 0
for k, v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
         model_dict[k] = pretrained_dict[keys[i]]
         #print(model_dict[k])
         i = i + 1
model.load_state_dict(model_dict)

最后是修改参数名拿来用的,

from collections import OrderedDict
pretrained_dict = torch.load('premodel')
 
new_state_dict = OrderedDict()
 
# for k, v in mgn_state_dict.items():
#     name = k[7:]  # remove `module.`
#     new_state_dict[name] = v
# self.model = self.model.load_state_dict(new_state_dict)
 
for k, v in pretrained_dict.items():
    name = "model.module."+k   # remove `module.`
    # print(name)
    new_state_dict[name] = v
self.model.load_state_dict(new_state_dict)

pytorch:加载预训练模型中的部分参数,并固定该部分参数(真实有效)

大家在学习pytorch时,可能想利用pytorch进行fine-tune,但是又烦恼于参数的加载问题。下面我将讲诉我的使用心得。

Step1: 加载预训练模型,并去除需要再次训练的层

#注意:需要重新训练的层的名字要和之前的不同。
model=resnet()#自己构建的模型,以resnet为例
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

Step2:固定部分参数

#k是可训练参数的名字,v是包含可训练参数的一个实体
#可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中:
for k,v in model.named_parameters():
    if k!='xxx.weight' and k!='xxx.bias' :
        v.requires_grad=False#固定参数

Step3:训练部分参数

#将要训练的参数放入优化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

Step4:检查部分参数是否固定

debug之后,程序正常运行,最好检查一下网络的参数是否真的被固定了,如何没固定,网络的状态接近于重新训练,可能会导致网络性能不稳定,也没办法得到想要得到的性能提升。

for k,v in model.named_parameters():
   if k!='xxx.weight' and k!='xxx.bias' :
   print(v.requires_grad)#理想状态下,所有值都是False

需要注意的是,操作失误最大的影响是,loss函数几乎不会发生变化,一直处于最开始的状态,这很可能是因为所有参数都被固定了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Ubuntu 14.04+Django 1.7.1+Nginx+uwsgi部署教程
Nov 18 Python
Python实现竖排打印传单手机号码易撕条
Mar 16 Python
Python解析最简单的验证码
Jan 07 Python
python中实现迭代器(iterator)的方法示例
Jan 19 Python
Python中.py文件打包成exe可执行文件详解
Mar 22 Python
Python2实现的LED大数字显示效果示例
Sep 04 Python
快速了解Python相对导入
Jan 12 Python
如何使用Python脚本实现文件拷贝
Nov 20 Python
python异常处理和日志处理方式
Dec 24 Python
借助Paramiko通过Python实现linux远程登陆及sftp的操作
Mar 16 Python
Django 如何使用日期时间选择器规范用户的时间输入示例代码详解
May 22 Python
浅谈pycharm导入pandas包遇到的问题及解决
Jun 01 Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
You might like
支持生僻字且自动识别utf-8编码的php汉字转拼音类
2014/06/27 PHP
PHP使用CURL_MULTI实现多线程采集的例子
2014/07/29 PHP
php中chdir()函数用法实例
2014/11/13 PHP
PHP中使用数组指针函数操作数组示例
2014/11/19 PHP
php生成随机颜色方法汇总
2014/12/03 PHP
Yii2分页的使用及其扩展方法详解
2016/05/23 PHP
PHP常见加密函数用法示例【crypt与md5】
2019/01/27 PHP
使用jQuery的将桌面应用程序引入浏览器
2010/11/19 Javascript
Javascript类定义语法,私有成员、受保护成员、静态成员等介绍
2011/12/08 Javascript
JavaScript 判断用户输入的邮箱及手机格式是否正确
2013/12/08 Javascript
jQuery 无限级菜单的简单实例
2014/02/21 Javascript
JS判断是否在微信浏览器打开的简单实例(推荐)
2016/08/24 Javascript
JavaScript与java语言有什么不同
2016/09/22 Javascript
jquery 抽奖小程序实现代码
2016/10/12 Javascript
Angular2.0实现modal对话框的方法示例
2018/02/18 Javascript
Vue仿支付宝支付功能
2018/05/25 Javascript
vue展示dicom文件医疗系统的实现代码
2018/08/27 Javascript
python实现排序算法
2014/02/14 Python
python对json的相关操作实例详解
2017/01/04 Python
完美解决Python2操作中文名文件乱码的问题
2017/01/04 Python
python图像常规操作
2017/11/11 Python
详解Python nose单元测试框架的安装与使用
2017/12/20 Python
利用Python实现微信找房机器人实例教程
2019/03/10 Python
Pytorch实现将模型的所有参数的梯度清0
2020/06/24 Python
阿根廷旅游网站:almundo阿根廷
2018/02/12 全球购物
德国净水壶和滤芯品牌:波尔德PearlCo(家用净水器)
2020/04/29 全球购物
大学自主招生自荐信
2013/12/16 职场文书
认购协议书范本
2014/04/22 职场文书
医院保洁服务方案
2014/06/11 职场文书
学雷锋日活动总结
2015/02/06 职场文书
单身申明具结书
2015/02/26 职场文书
2016教师廉洁教育心得体会
2016/01/13 职场文书
mysq启动失败问题及场景分析
2021/07/15 MySQL
MySQL 主从复制数据不一致的解决方法
2022/03/18 MySQL
Python 图片添加美颜效果
2022/04/28 Python
CSS浮动引起的高度塌陷问题
2022/08/05 HTML / CSS