pytorch fine-tune 预训练的模型操作


Posted in Python onJune 03, 2021

之一:

torchvision 中包含了很多预训练好的模型,这样就使得 fine-tune 非常容易。本文主要介绍如何 fine-tune torchvision 中预训练好的模型。

安装

pip install torchvision

如何 fine-tune

以 resnet18 为例:

from torchvision import models
from torch import nn
from torch import optim
 
resnet_model = models.resnet18(pretrained=True) 
# pretrained 设置为 True,会自动下载模型 所对应权重,并加载到模型中
# 也可以自己下载 权重,然后 load 到 模型中,源码中有 权重的地址。
 
# 假设 我们的 分类任务只需要 分 100 类,那么我们应该做的是
# 1. 查看 resnet 的源码
# 2. 看最后一层的 名字是啥 (在 resnet 里是 self.fc = nn.Linear(512 * block.expansion, num_classes))
# 3. 在外面替换掉这个层
resnet_model.fc= nn.Linear(in_features=..., out_features=100)
 
# 这样就 哦了,修改后的模型除了输出层的参数是 随机初始化的,其他层都是用预训练的参数初始化的。
 
# 如果只想训练 最后一层的话,应该做的是:
# 1. 将其它层的参数 requires_grad 设置为 False
# 2. 构建一个 optimizer, optimizer 管理的参数只有最后一层的参数
# 3. 然后 backward, step 就可以了
 
# 这一步可以节省大量的时间,因为多数的参数不需要计算梯度
for para in list(resnet_model.parameters())[:-2]:
    para.requires_grad=False 
 
optimizer = optim.SGD(params=[resnet_model.fc.weight, resnet_model.fc.bias], lr=1e-3)
 
...

为什么

这里介绍下 运行resnet_model.fc= nn.Linear(in_features=..., out_features=100)时 框架内发生了什么

这时应该看 nn.Module 源码的 __setattr__ 部分,因为 setattr 时都会调用这个方法:

def __setattr__(self, name, value):
    def remove_from(*dicts):
        for d in dicts:
            if name in d:
                del d[name]

首先映入眼帘就是 remove_from 这个函数,这个函数的目的就是,如果出现了 同名的属性,就将旧的属性移除。 用刚才举的例子就是:

预训练的模型中 有个 名字叫fc 的 Module。

在类定义外,我们 将另一个 Module 重新 赋值给了 fc。

类定义内的 fc 对应的 Module 就会从 模型中 删除。

之二:

前言

这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。

参数初始化

参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。

pytorch fine-tune 预训练的模型操作

所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
        m.weight.data.fill_(1)
        m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。

局部微调

有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)
 
# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调

有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
                     model.parameters())
 
optimizer = torch.optim.SGD([
            {'params': base_params},
            {'params': model.fc.parameters(), 'lr': 1e-3}
            ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

之三:

pytorch finetune模型

文章主要讲述如何在pytorch上读取以往训练的模型参数,在模型的名字已经变更的情况下又如何读取模型的部分参数等。

pytorch 模型的存储与读取

其中在模型的保存过程有存储模型和参数一起的也有单独存储模型参数的

单独存储模型参数

存储时使用:

torch.save(the_model.state_dict(), PATH)

读取时:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

存储模型与参数

存储:

torch.save(the_model, PATH)

读取:

the_model = torch.load(PATH)

模型的参数

fine-tune的过程是读取原有模型的参数,但是由于模型的所要处理的数据集不同,最后的一层class的总数不同,所以需要修改模型的最后一层,这样模型读取的参数,和在大数据集上训练好下载的模型参数在形式上不一样。需要我们自己去写函数读取参数。

pytorch模型参数的形式

模型的参数是以字典的形式存储的。

model_dict = the_model.state_dict(),
for k,v in model_dict.items():
    print(k)

即可看到所有的键值

如果想修改模型的参数,给相应的键值赋值即可

model_dict[k] = new_value

最后更新模型的参数

the_model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是一样的

我们可以通过下列算法进行读取模型

model_dict = model.state_dict() 
pretrained_dict = torch.load(model_path)
 # 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
            k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的

model_dict = model.state_dict() 
pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
    keys.append(k)
i = 0
for k,v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
        print(k, ',', keys[i])
         model_dict[k]=pretrained_dict[keys[i]]
    i = i + 1
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的

自己找对应关系,一个key对应一个key的赋值

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python单元测试框架unittest使用方法讲解
Apr 13 Python
在Python中使用cookielib和urllib2配合PyQuery抓取网页信息
Apr 25 Python
Python实现新浪博客备份的方法
Apr 27 Python
python用reduce和map把字符串转为数字的方法
Dec 19 Python
Python实现获取磁盘剩余空间的2种方法
Jun 07 Python
如何用python写一个简单的词法分析器
Dec 18 Python
对python列表里的字典元素去重方法详解
Jan 21 Python
对python3.4 字符串转16进制的实例详解
Jun 12 Python
python数据预处理 :数据共线性处理详解
Feb 24 Python
python 安装impala包步骤
Mar 28 Python
解决python便携版无法直接运行py文件的问题
Sep 01 Python
 python中的元类metaclass详情
May 30 Python
Python实现byte转integer
Jun 03 #Python
Python数据分析之绘图和可视化详解
Python数据分析之pandas读取数据
Jun 02 #Python
Python 如何实现文件自动去重
python状态机transitions库详解
Jun 02 #Python
python爬取某网站原图作为壁纸
Python爬虫之自动爬取某车之家各车销售数据
You might like
学习php设计模式 php实现状态模式
2015/12/07 PHP
详谈php静态方法及普通方法的区别
2016/10/04 PHP
Thinkphp5结合layer弹窗定制操作结果页面
2017/07/07 PHP
PHP实现的简单sha1加密功能示例
2017/08/27 PHP
通过jQuery源码学习javascript(一)
2012/12/27 Javascript
ExtJs默认的字体大小改变的几种方法(自己整理)
2013/04/18 Javascript
javascript中字符串拼接详解
2014/09/26 Javascript
javascript实现倒计时跳转页面
2016/01/17 Javascript
JavaScript数据结构与算法之集合(Set)
2016/01/29 Javascript
jquery trigger函数执行两次的解决方法
2016/02/29 Javascript
只要1K 纯JS脚本送你一朵3D红色玫瑰
2016/08/09 Javascript
JavaScript实现拖拽元素对齐到网格(每次移动固定距离)
2016/11/30 Javascript
webpack构建vue项目的详细教程(配置篇)
2017/07/17 Javascript
使用node.js对音视频文件加密的实例代码
2017/08/30 Javascript
vue-router3.0版本中 router.push 不能刷新页面的问题
2018/05/10 Javascript
Layui给数据表格动态添加一行并跳转到添加行所在页的方法
2018/08/20 Javascript
基于React Native 0.52实现轮播图效果
2020/08/25 Javascript
vue组件通信传值操作示例
2019/01/08 Javascript
JS使用iView的Dropdown实现一个右键菜单
2019/05/06 Javascript
13 个npm 快速开发技巧(推荐)
2019/07/04 Javascript
Vue-drag-resize 拖拽缩放插件的使用(简单示例)
2019/12/04 Javascript
python益智游戏计算汉诺塔问题示例
2014/03/05 Python
浅谈Python中的闭包
2015/07/08 Python
Python写入数据到MP3文件中的方法
2015/07/10 Python
Python中规范定义命名空间的一些建议
2016/06/04 Python
浅谈python抛出异常、自定义异常, 传递异常
2016/06/20 Python
python编写弹球游戏的实现代码
2018/03/12 Python
Python实现批量读取图片并存入mongodb数据库的方法示例
2018/04/02 Python
Python for i in range ()用法详解
2020/09/18 Python
Python高阶函数、常用内置函数用法实例分析
2019/12/26 Python
Python写捕鱼达人的游戏实现
2020/03/31 Python
幼儿园园长自我鉴定
2013/10/22 职场文书
售前工程师职业生涯规划
2014/03/02 职场文书
2014年大学生四年规划书范文
2014/04/03 职场文书
乡镇挂职心得体会
2014/09/04 职场文书
销售人员管理制度
2015/08/06 职场文书