pytorch fine-tune 预训练的模型操作


Posted in Python onJune 03, 2021

之一:

torchvision 中包含了很多预训练好的模型,这样就使得 fine-tune 非常容易。本文主要介绍如何 fine-tune torchvision 中预训练好的模型。

安装

pip install torchvision

如何 fine-tune

以 resnet18 为例:

from torchvision import models
from torch import nn
from torch import optim
 
resnet_model = models.resnet18(pretrained=True) 
# pretrained 设置为 True,会自动下载模型 所对应权重,并加载到模型中
# 也可以自己下载 权重,然后 load 到 模型中,源码中有 权重的地址。
 
# 假设 我们的 分类任务只需要 分 100 类,那么我们应该做的是
# 1. 查看 resnet 的源码
# 2. 看最后一层的 名字是啥 (在 resnet 里是 self.fc = nn.Linear(512 * block.expansion, num_classes))
# 3. 在外面替换掉这个层
resnet_model.fc= nn.Linear(in_features=..., out_features=100)
 
# 这样就 哦了,修改后的模型除了输出层的参数是 随机初始化的,其他层都是用预训练的参数初始化的。
 
# 如果只想训练 最后一层的话,应该做的是:
# 1. 将其它层的参数 requires_grad 设置为 False
# 2. 构建一个 optimizer, optimizer 管理的参数只有最后一层的参数
# 3. 然后 backward, step 就可以了
 
# 这一步可以节省大量的时间,因为多数的参数不需要计算梯度
for para in list(resnet_model.parameters())[:-2]:
    para.requires_grad=False 
 
optimizer = optim.SGD(params=[resnet_model.fc.weight, resnet_model.fc.bias], lr=1e-3)
 
...

为什么

这里介绍下 运行resnet_model.fc= nn.Linear(in_features=..., out_features=100)时 框架内发生了什么

这时应该看 nn.Module 源码的 __setattr__ 部分,因为 setattr 时都会调用这个方法:

def __setattr__(self, name, value):
    def remove_from(*dicts):
        for d in dicts:
            if name in d:
                del d[name]

首先映入眼帘就是 remove_from 这个函数,这个函数的目的就是,如果出现了 同名的属性,就将旧的属性移除。 用刚才举的例子就是:

预训练的模型中 有个 名字叫fc 的 Module。

在类定义外,我们 将另一个 Module 重新 赋值给了 fc。

类定义内的 fc 对应的 Module 就会从 模型中 删除。

之二:

前言

这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。

参数初始化

参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。

pytorch fine-tune 预训练的模型操作

所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
        m.weight.data.fill_(1)
        m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。

局部微调

有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)
 
# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调

有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
                     model.parameters())
 
optimizer = torch.optim.SGD([
            {'params': base_params},
            {'params': model.fc.parameters(), 'lr': 1e-3}
            ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

之三:

pytorch finetune模型

文章主要讲述如何在pytorch上读取以往训练的模型参数,在模型的名字已经变更的情况下又如何读取模型的部分参数等。

pytorch 模型的存储与读取

其中在模型的保存过程有存储模型和参数一起的也有单独存储模型参数的

单独存储模型参数

存储时使用:

torch.save(the_model.state_dict(), PATH)

读取时:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

存储模型与参数

存储:

torch.save(the_model, PATH)

读取:

the_model = torch.load(PATH)

模型的参数

fine-tune的过程是读取原有模型的参数,但是由于模型的所要处理的数据集不同,最后的一层class的总数不同,所以需要修改模型的最后一层,这样模型读取的参数,和在大数据集上训练好下载的模型参数在形式上不一样。需要我们自己去写函数读取参数。

pytorch模型参数的形式

模型的参数是以字典的形式存储的。

model_dict = the_model.state_dict(),
for k,v in model_dict.items():
    print(k)

即可看到所有的键值

如果想修改模型的参数,给相应的键值赋值即可

model_dict[k] = new_value

最后更新模型的参数

the_model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是一样的

我们可以通过下列算法进行读取模型

model_dict = model.state_dict() 
pretrained_dict = torch.load(model_path)
 # 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
            k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的

model_dict = model.state_dict() 
pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
    keys.append(k)
i = 0
for k,v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
        print(k, ',', keys[i])
         model_dict[k]=pretrained_dict[keys[i]]
    i = i + 1
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的

自己找对应关系,一个key对应一个key的赋值

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python ZipFile模块详解
Nov 01 Python
python完成FizzBuzzWhizz问题(拉勾网面试题)示例
May 05 Python
python3 selenium 切换窗口的几种方法小结
May 21 Python
使用OpenCV实现仿射变换—平移功能
Aug 29 Python
Django中间件拦截未登录url实例详解
Sep 03 Python
flask框架配置mysql数据库操作详解
Nov 29 Python
使用pyhon绘图比较两个手机屏幕大小(实例代码)
Jan 03 Python
python不同系统中打开方法
Jun 23 Python
解决keras GAN训练是loss不发生变化,accuracy一直为0.5的问题
Jul 02 Python
详解vscode实现远程linux服务器上Python开发
Nov 10 Python
python 实现超级玛丽游戏
Nov 25 Python
如何使用flask将模型部署为服务
May 13 Python
Python实现byte转integer
Jun 03 #Python
Python数据分析之绘图和可视化详解
Python数据分析之pandas读取数据
Jun 02 #Python
Python 如何实现文件自动去重
python状态机transitions库详解
Jun 02 #Python
python爬取某网站原图作为壁纸
Python爬虫之自动爬取某车之家各车销售数据
You might like
php上传图片生成缩略图(GD库)
2016/01/06 PHP
详解Laravel视图间共享数据与视图Composer
2016/08/04 PHP
Ajax和PHP正则表达式验证表单及验证码
2016/09/24 PHP
php简单统计中文个数的方法
2016/09/30 PHP
php5.3/5.4/5.5/5.6/7常见新增特性汇总整理
2020/02/27 PHP
JS面向对象编程之对象使用分析
2010/08/19 Javascript
javascript控制Div层透明属性由浅变深由深变浅逐渐显示
2013/11/12 Javascript
一个检测表单数据的JavaScript实例
2014/10/31 Javascript
javascript的函数作用域
2014/11/12 Javascript
javascript下拉列表菜单的实现方法
2015/11/18 Javascript
全面详细的jQuery常见开发技巧手册
2016/02/21 Javascript
深入理解JS正则表达式---分组
2016/07/18 Javascript
JS实现的表格行上下移动操作示例
2016/08/03 Javascript
实例解析Array和String方法
2016/12/14 Javascript
详解Nodejs 通过 fs.createWriteStream 保存文件
2017/10/10 NodeJs
Angular.js实现获取验证码倒计时60秒按钮的简单方法
2017/10/18 Javascript
4个顶级JavaScript高级文本编辑器
2018/10/10 Javascript
react+ant design实现Table的增、删、改的示例代码
2018/12/27 Javascript
基于JS开发微信网页录音功能的实例代码
2019/04/30 Javascript
Vue+Node实现的商城用户管理功能示例
2019/12/23 Javascript
逐行分析鸿蒙系统的 JavaScript 框架(推荐)
2020/09/17 Javascript
[46:20]DOTA2-DPC中国联赛 正赛 PSG.LGD vs LBZS BO3 第二场 1月22日
2021/03/11 DOTA
Python使用Scrapy爬取妹子图
2015/05/28 Python
python中关于for循环的碎碎念
2017/06/30 Python
python中利用Future对象异步返回结果示例代码
2017/09/07 Python
Python字符串和字典相关操作的实例详解
2017/09/23 Python
Ivory Isle Designs美国/加拿大:婚礼和活动文具公司
2018/08/21 全球购物
戴尔新加坡官网:Dell Singapore
2020/12/13 全球购物
数据库面试要点基本概念
2013/10/31 面试题
医疗纠纷协议书
2014/04/16 职场文书
卫生标语大全
2014/06/21 职场文书
毕业设计致谢词
2015/05/14 职场文书
惊天动地观后感
2015/06/10 职场文书
安全生产奖惩制度
2015/08/06 职场文书
2019年暑期安全广播稿!
2019/07/03 职场文书
vue如何在data中引入图片的正确路径
2022/06/05 Vue.js