pytorch fine-tune 预训练的模型操作


Posted in Python onJune 03, 2021

之一:

torchvision 中包含了很多预训练好的模型,这样就使得 fine-tune 非常容易。本文主要介绍如何 fine-tune torchvision 中预训练好的模型。

安装

pip install torchvision

如何 fine-tune

以 resnet18 为例:

from torchvision import models
from torch import nn
from torch import optim
 
resnet_model = models.resnet18(pretrained=True) 
# pretrained 设置为 True,会自动下载模型 所对应权重,并加载到模型中
# 也可以自己下载 权重,然后 load 到 模型中,源码中有 权重的地址。
 
# 假设 我们的 分类任务只需要 分 100 类,那么我们应该做的是
# 1. 查看 resnet 的源码
# 2. 看最后一层的 名字是啥 (在 resnet 里是 self.fc = nn.Linear(512 * block.expansion, num_classes))
# 3. 在外面替换掉这个层
resnet_model.fc= nn.Linear(in_features=..., out_features=100)
 
# 这样就 哦了,修改后的模型除了输出层的参数是 随机初始化的,其他层都是用预训练的参数初始化的。
 
# 如果只想训练 最后一层的话,应该做的是:
# 1. 将其它层的参数 requires_grad 设置为 False
# 2. 构建一个 optimizer, optimizer 管理的参数只有最后一层的参数
# 3. 然后 backward, step 就可以了
 
# 这一步可以节省大量的时间,因为多数的参数不需要计算梯度
for para in list(resnet_model.parameters())[:-2]:
    para.requires_grad=False 
 
optimizer = optim.SGD(params=[resnet_model.fc.weight, resnet_model.fc.bias], lr=1e-3)
 
...

为什么

这里介绍下 运行resnet_model.fc= nn.Linear(in_features=..., out_features=100)时 框架内发生了什么

这时应该看 nn.Module 源码的 __setattr__ 部分,因为 setattr 时都会调用这个方法:

def __setattr__(self, name, value):
    def remove_from(*dicts):
        for d in dicts:
            if name in d:
                del d[name]

首先映入眼帘就是 remove_from 这个函数,这个函数的目的就是,如果出现了 同名的属性,就将旧的属性移除。 用刚才举的例子就是:

预训练的模型中 有个 名字叫fc 的 Module。

在类定义外,我们 将另一个 Module 重新 赋值给了 fc。

类定义内的 fc 对应的 Module 就会从 模型中 删除。

之二:

前言

这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。

参数初始化

参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。

pytorch fine-tune 预训练的模型操作

所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
        m.weight.data.fill_(1)
        m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。

局部微调

有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)
 
# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调

有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
                     model.parameters())
 
optimizer = torch.optim.SGD([
            {'params': base_params},
            {'params': model.fc.parameters(), 'lr': 1e-3}
            ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

之三:

pytorch finetune模型

文章主要讲述如何在pytorch上读取以往训练的模型参数,在模型的名字已经变更的情况下又如何读取模型的部分参数等。

pytorch 模型的存储与读取

其中在模型的保存过程有存储模型和参数一起的也有单独存储模型参数的

单独存储模型参数

存储时使用:

torch.save(the_model.state_dict(), PATH)

读取时:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

存储模型与参数

存储:

torch.save(the_model, PATH)

读取:

the_model = torch.load(PATH)

模型的参数

fine-tune的过程是读取原有模型的参数,但是由于模型的所要处理的数据集不同,最后的一层class的总数不同,所以需要修改模型的最后一层,这样模型读取的参数,和在大数据集上训练好下载的模型参数在形式上不一样。需要我们自己去写函数读取参数。

pytorch模型参数的形式

模型的参数是以字典的形式存储的。

model_dict = the_model.state_dict(),
for k,v in model_dict.items():
    print(k)

即可看到所有的键值

如果想修改模型的参数,给相应的键值赋值即可

model_dict[k] = new_value

最后更新模型的参数

the_model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是一样的

我们可以通过下列算法进行读取模型

model_dict = model.state_dict() 
pretrained_dict = torch.load(model_path)
 # 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
            k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的

model_dict = model.state_dict() 
pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
    keys.append(k)
i = 0
for k,v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
        print(k, ',', keys[i])
         model_dict[k]=pretrained_dict[keys[i]]
    i = i + 1
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的

自己找对应关系,一个key对应一个key的赋值

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
把大数据数字口语化(python与js)两种实现
Feb 21 Python
Python的Django框架中的URL配置与松耦合
Jul 15 Python
Python3.6.0+opencv3.3.0人脸检测示例
May 25 Python
使用Python生成200个激活码的实现方法
Nov 22 Python
Python字典底层实现原理详解
Dec 18 Python
Python实现使用dir获取类的方法列表
Dec 24 Python
python+adb命令实现自动刷视频脚本案例
Apr 23 Python
Python urllib库如何添加headers过程解析
Oct 05 Python
python实现在列表中查找某个元素的下标示例
Nov 16 Python
python flask框架快速入门
May 14 Python
Python编程源码报错解决方法总结经验分享
Oct 05 Python
Python内置包对JSON文件数据进行编码和解码
Apr 12 Python
Python实现byte转integer
Jun 03 #Python
Python数据分析之绘图和可视化详解
Python数据分析之pandas读取数据
Jun 02 #Python
Python 如何实现文件自动去重
python状态机transitions库详解
Jun 02 #Python
python爬取某网站原图作为壁纸
Python爬虫之自动爬取某车之家各车销售数据
You might like
把从SQL中取出的数据转化成XMl格式
2006/10/09 PHP
php对gzip文件或者字符串解压实例参考
2008/07/25 PHP
php 分库分表hash算法
2009/11/12 PHP
Laravel中批量赋值Mass-Assignment的真正含义详解
2017/09/29 PHP
ajax+php实现无刷新验证手机号的实例
2017/12/22 PHP
php微信公众号开发之秒杀
2018/10/20 PHP
PHP实现的只保留字符串首尾字符功能示例【隐藏部分字符串】
2019/03/11 PHP
JS加ASP二级域名转向的代码
2007/05/17 Javascript
Jquery实现三层遍历删除功能代码
2013/04/23 Javascript
基于jQuery实现图片的前进与后退功能
2013/04/24 Javascript
jQuery+easyui中的combobox实现下拉框特效
2015/02/27 Javascript
Bootstrap每天必学之表格
2015/11/23 Javascript
解决vue+element 键盘回车事件导致页面刷新的问题
2018/08/25 Javascript
vue百度地图 + 定位的详解
2019/05/13 Javascript
理解Proxy及使用Proxy实现vue数据双向绑定操作
2020/07/18 Javascript
一篇文章看懂JavaScript中的回调
2021/01/05 Javascript
[02:04]完美世界城市挑战赛秋季赛报名开始 谁是solo路人王?
2019/10/10 DOTA
python3写的简单本地文件上传服务器实例
2018/06/04 Python
python2和python3的输入和输出区别介绍
2018/11/20 Python
python 文本单词提取和词频统计的实例
2018/12/22 Python
详解js文件通过python访问数据库方法
2019/03/03 Python
python实现雪花飘落效果实例讲解
2019/06/18 Python
python实现俄罗斯方块游戏(改进版)
2020/03/13 Python
python中类与对象之间的关系详解
2020/12/16 Python
CSS3中各种颜色属性的使用教程
2016/05/17 HTML / CSS
HTML5中的Web Notification桌面通知功能的实现方法
2019/07/29 HTML / CSS
We Fashion荷兰:一家国际时装公司
2018/04/18 全球购物
DOUGLAS荷兰:购买香水和化妆品
2020/10/24 全球购物
高中竞选班长演讲稿
2014/04/24 职场文书
安全环保演讲稿
2014/08/28 职场文书
导游词400字
2015/02/13 职场文书
导游词之四川熊猫基地
2020/01/13 职场文书
nginx 多个location转发任意请求或访问静态资源文件的实现
2021/03/31 Servers
html实现随机点名器的示例代码
2021/04/02 Javascript
CSS 新特性 contain控制页面的重绘与重排问题
2021/04/30 HTML / CSS
Python使用DFA算法过滤内容敏感词
2022/04/22 Python