pytorch fine-tune 预训练的模型操作


Posted in Python onJune 03, 2021

之一:

torchvision 中包含了很多预训练好的模型,这样就使得 fine-tune 非常容易。本文主要介绍如何 fine-tune torchvision 中预训练好的模型。

安装

pip install torchvision

如何 fine-tune

以 resnet18 为例:

from torchvision import models
from torch import nn
from torch import optim
 
resnet_model = models.resnet18(pretrained=True) 
# pretrained 设置为 True,会自动下载模型 所对应权重,并加载到模型中
# 也可以自己下载 权重,然后 load 到 模型中,源码中有 权重的地址。
 
# 假设 我们的 分类任务只需要 分 100 类,那么我们应该做的是
# 1. 查看 resnet 的源码
# 2. 看最后一层的 名字是啥 (在 resnet 里是 self.fc = nn.Linear(512 * block.expansion, num_classes))
# 3. 在外面替换掉这个层
resnet_model.fc= nn.Linear(in_features=..., out_features=100)
 
# 这样就 哦了,修改后的模型除了输出层的参数是 随机初始化的,其他层都是用预训练的参数初始化的。
 
# 如果只想训练 最后一层的话,应该做的是:
# 1. 将其它层的参数 requires_grad 设置为 False
# 2. 构建一个 optimizer, optimizer 管理的参数只有最后一层的参数
# 3. 然后 backward, step 就可以了
 
# 这一步可以节省大量的时间,因为多数的参数不需要计算梯度
for para in list(resnet_model.parameters())[:-2]:
    para.requires_grad=False 
 
optimizer = optim.SGD(params=[resnet_model.fc.weight, resnet_model.fc.bias], lr=1e-3)
 
...

为什么

这里介绍下 运行resnet_model.fc= nn.Linear(in_features=..., out_features=100)时 框架内发生了什么

这时应该看 nn.Module 源码的 __setattr__ 部分,因为 setattr 时都会调用这个方法:

def __setattr__(self, name, value):
    def remove_from(*dicts):
        for d in dicts:
            if name in d:
                del d[name]

首先映入眼帘就是 remove_from 这个函数,这个函数的目的就是,如果出现了 同名的属性,就将旧的属性移除。 用刚才举的例子就是:

预训练的模型中 有个 名字叫fc 的 Module。

在类定义外,我们 将另一个 Module 重新 赋值给了 fc。

类定义内的 fc 对应的 Module 就会从 模型中 删除。

之二:

前言

这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。

参数初始化

参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。

pytorch fine-tune 预训练的模型操作

所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
        m.weight.data.fill_(1)
        m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。

局部微调

有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)
 
# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调

有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
                     model.parameters())
 
optimizer = torch.optim.SGD([
            {'params': base_params},
            {'params': model.fc.parameters(), 'lr': 1e-3}
            ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

之三:

pytorch finetune模型

文章主要讲述如何在pytorch上读取以往训练的模型参数,在模型的名字已经变更的情况下又如何读取模型的部分参数等。

pytorch 模型的存储与读取

其中在模型的保存过程有存储模型和参数一起的也有单独存储模型参数的

单独存储模型参数

存储时使用:

torch.save(the_model.state_dict(), PATH)

读取时:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

存储模型与参数

存储:

torch.save(the_model, PATH)

读取:

the_model = torch.load(PATH)

模型的参数

fine-tune的过程是读取原有模型的参数,但是由于模型的所要处理的数据集不同,最后的一层class的总数不同,所以需要修改模型的最后一层,这样模型读取的参数,和在大数据集上训练好下载的模型参数在形式上不一样。需要我们自己去写函数读取参数。

pytorch模型参数的形式

模型的参数是以字典的形式存储的。

model_dict = the_model.state_dict(),
for k,v in model_dict.items():
    print(k)

即可看到所有的键值

如果想修改模型的参数,给相应的键值赋值即可

model_dict[k] = new_value

最后更新模型的参数

the_model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是一样的

我们可以通过下列算法进行读取模型

model_dict = model.state_dict() 
pretrained_dict = torch.load(model_path)
 # 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
            k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的

model_dict = model.state_dict() 
pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
    keys.append(k)
i = 0
for k,v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
        print(k, ',', keys[i])
         model_dict[k]=pretrained_dict[keys[i]]
    i = i + 1
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的

自己找对应关系,一个key对应一个key的赋值

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用判断语句和循环的教程
Apr 25 Python
python定时检查某个进程是否已经关闭的方法
May 20 Python
Python的Flask框架中使用Flask-Migrate扩展迁移数据库的教程
Jun 14 Python
python对json的相关操作实例详解
Jan 04 Python
python实现Virginia无密钥解密
Mar 20 Python
pandas条件组合筛选和按范围筛选的示例代码
Aug 26 Python
Django框架 信号调度原理解析
Sep 04 Python
python中property和setter装饰器用法
Dec 19 Python
python小技巧——将变量保存在本地及读取
Nov 13 Python
python 模块导入问题汇总
Feb 01 Python
用Python爬虫破解滑动验证码的案例解析
May 06 Python
Python使用openpyxl模块处理Excel文件
Jun 05 Python
Python实现byte转integer
Jun 03 #Python
Python数据分析之绘图和可视化详解
Python数据分析之pandas读取数据
Jun 02 #Python
Python 如何实现文件自动去重
python状态机transitions库详解
Jun 02 #Python
python爬取某网站原图作为壁纸
Python爬虫之自动爬取某车之家各车销售数据
You might like
上海永华YH-R296(华普R-96)12波段立体声收音机的分析和打理
2021/03/02 无线电
PHP+MySQL删除操作实例
2015/01/21 PHP
PHP记录和读取JSON格式日志文件
2016/07/07 PHP
PHP中的多种加密技术及代码示例解析
2016/10/20 PHP
jquery中ajax学习笔记3
2011/10/16 Javascript
Knockout数组(observable)使用详解示例
2013/11/15 Javascript
JS模拟实现Select效果代码
2015/09/24 Javascript
jquery实现倒计时效果
2015/12/14 Javascript
JS实现简单的二维矩阵乘积运算
2016/01/26 Javascript
jQuery实现带延时功能的水平多级菜单效果【附demo源码下载】
2016/09/21 Javascript
Vue.js绑定HTML class数组语法错误的原因分析
2016/10/19 Javascript
Vue 进阶教程之v-model详解
2017/05/06 Javascript
vue.js+Echarts开发图表放大缩小功能实例
2017/06/09 Javascript
Vue.js列表渲染绑定jQuery插件的正确姿势
2017/06/29 jQuery
基于JavaScript实现无缝滚动效果
2017/07/21 Javascript
nodejs前端模板引擎swig入门详解
2018/05/15 NodeJs
微信小程序网络封装(简单高效)
2018/08/06 Javascript
详解vue中使用axios对同一个接口连续请求导致返回数据混乱的问题
2019/11/06 Javascript
element-ui 实现响应式导航栏的示例代码
2020/05/08 Javascript
Vue scoped及deep使用方法解析
2020/08/01 Javascript
Python与shell的3种交互方式介绍
2015/04/11 Python
python实现查找两个字符串中相同字符并输出的方法
2015/07/11 Python
Python异步编程之协程任务的调度操作实例分析
2020/02/01 Python
解决Django no such table: django_session的问题
2020/04/07 Python
python基于pygame实现飞机大作战小游戏
2020/11/19 Python
HTML5高仿微信聊天、微信聊天表情|对话框|编辑器功能
2018/04/23 HTML / CSS
英国著名音像制品和图书游戏购物网站:Zavvi
2016/08/04 全球购物
Marriott国际:万豪国际酒店查询预订
2017/09/25 全球购物
美国电子元器件分销商:Newark element14
2018/01/13 全球购物
科颜氏香港官方网店:Kiehl’s香港
2021/03/07 全球购物
班级入场式解说词
2014/02/01 职场文书
酒店开业策划方案
2014/06/02 职场文书
华山导游词
2015/02/03 职场文书
2015年档案管理工作总结
2015/04/08 职场文书
超强台风观后感
2015/06/09 职场文书
小学六年级班主任工作经验交流材料
2015/11/02 职场文书