pytorch fine-tune 预训练的模型操作


Posted in Python onJune 03, 2021

之一:

torchvision 中包含了很多预训练好的模型,这样就使得 fine-tune 非常容易。本文主要介绍如何 fine-tune torchvision 中预训练好的模型。

安装

pip install torchvision

如何 fine-tune

以 resnet18 为例:

from torchvision import models
from torch import nn
from torch import optim
 
resnet_model = models.resnet18(pretrained=True) 
# pretrained 设置为 True,会自动下载模型 所对应权重,并加载到模型中
# 也可以自己下载 权重,然后 load 到 模型中,源码中有 权重的地址。
 
# 假设 我们的 分类任务只需要 分 100 类,那么我们应该做的是
# 1. 查看 resnet 的源码
# 2. 看最后一层的 名字是啥 (在 resnet 里是 self.fc = nn.Linear(512 * block.expansion, num_classes))
# 3. 在外面替换掉这个层
resnet_model.fc= nn.Linear(in_features=..., out_features=100)
 
# 这样就 哦了,修改后的模型除了输出层的参数是 随机初始化的,其他层都是用预训练的参数初始化的。
 
# 如果只想训练 最后一层的话,应该做的是:
# 1. 将其它层的参数 requires_grad 设置为 False
# 2. 构建一个 optimizer, optimizer 管理的参数只有最后一层的参数
# 3. 然后 backward, step 就可以了
 
# 这一步可以节省大量的时间,因为多数的参数不需要计算梯度
for para in list(resnet_model.parameters())[:-2]:
    para.requires_grad=False 
 
optimizer = optim.SGD(params=[resnet_model.fc.weight, resnet_model.fc.bias], lr=1e-3)
 
...

为什么

这里介绍下 运行resnet_model.fc= nn.Linear(in_features=..., out_features=100)时 框架内发生了什么

这时应该看 nn.Module 源码的 __setattr__ 部分,因为 setattr 时都会调用这个方法:

def __setattr__(self, name, value):
    def remove_from(*dicts):
        for d in dicts:
            if name in d:
                del d[name]

首先映入眼帘就是 remove_from 这个函数,这个函数的目的就是,如果出现了 同名的属性,就将旧的属性移除。 用刚才举的例子就是:

预训练的模型中 有个 名字叫fc 的 Module。

在类定义外,我们 将另一个 Module 重新 赋值给了 fc。

类定义内的 fc 对应的 Module 就会从 模型中 删除。

之二:

前言

这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。

参数初始化

参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。

pytorch fine-tune 预训练的模型操作

所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
        m.weight.data.fill_(1)
        m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。

局部微调

有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)
 
# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调

有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
                     model.parameters())
 
optimizer = torch.optim.SGD([
            {'params': base_params},
            {'params': model.fc.parameters(), 'lr': 1e-3}
            ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

之三:

pytorch finetune模型

文章主要讲述如何在pytorch上读取以往训练的模型参数,在模型的名字已经变更的情况下又如何读取模型的部分参数等。

pytorch 模型的存储与读取

其中在模型的保存过程有存储模型和参数一起的也有单独存储模型参数的

单独存储模型参数

存储时使用:

torch.save(the_model.state_dict(), PATH)

读取时:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

存储模型与参数

存储:

torch.save(the_model, PATH)

读取:

the_model = torch.load(PATH)

模型的参数

fine-tune的过程是读取原有模型的参数,但是由于模型的所要处理的数据集不同,最后的一层class的总数不同,所以需要修改模型的最后一层,这样模型读取的参数,和在大数据集上训练好下载的模型参数在形式上不一样。需要我们自己去写函数读取参数。

pytorch模型参数的形式

模型的参数是以字典的形式存储的。

model_dict = the_model.state_dict(),
for k,v in model_dict.items():
    print(k)

即可看到所有的键值

如果想修改模型的参数,给相应的键值赋值即可

model_dict[k] = new_value

最后更新模型的参数

the_model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是一样的

我们可以通过下列算法进行读取模型

model_dict = model.state_dict() 
pretrained_dict = torch.load(model_path)
 # 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
            k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的

model_dict = model.state_dict() 
pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
    keys.append(k)
i = 0
for k,v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
        print(k, ',', keys[i])
         model_dict[k]=pretrained_dict[keys[i]]
    i = i + 1
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的

自己找对应关系,一个key对应一个key的赋值

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python client使用http post 到server端的代码
Feb 10 Python
用Python实现一个简单的线程池
Apr 07 Python
python执行get提交的方法
Apr 29 Python
django使用图片延时加载引起后台404错误
Apr 18 Python
Python使用scipy模块实现一维卷积运算示例
Sep 05 Python
python如何实现单链表的反转
Feb 10 Python
Python sep参数使用方法详解
Feb 12 Python
python实现五子棋程序
Apr 24 Python
python实现b站直播自动发送弹幕功能
Feb 20 Python
pytest进阶教程之fixture函数详解
Mar 29 Python
解决hive中导入text文件遇到的坑
Apr 07 Python
python 如何用map()函数创建多线程任务
Apr 07 Python
Python实现byte转integer
Jun 03 #Python
Python数据分析之绘图和可视化详解
Python数据分析之pandas读取数据
Jun 02 #Python
Python 如何实现文件自动去重
python状态机transitions库详解
Jun 02 #Python
python爬取某网站原图作为壁纸
Python爬虫之自动爬取某车之家各车销售数据
You might like
基于header的一些常用指令详解
2013/06/06 PHP
php判断手机访问还是电脑访问示例分享
2014/01/20 PHP
PHP基于curl post实现发送url及相关中文乱码问题解决方法
2017/11/25 PHP
javascript 清除输入框中的数据
2009/04/13 Javascript
javascript中字符串拼接需注意的问题
2010/07/13 Javascript
node.js中的Socket.IO使用实例
2014/11/04 Javascript
Java遍历集合方法分析(实现原理、算法性能、适用场合)
2016/04/25 Javascript
jQuery Mobile操作HTML5的常用函数总结
2016/05/17 Javascript
【经典源码收藏】jQuery实用代码片段(筛选,搜索,样式,清除默认值,多选等)
2016/06/07 Javascript
jQuery中JSONP的两种实现方式详解
2016/09/26 Javascript
javascript的document中的动态添加标签实现方法
2016/10/24 Javascript
js实现适合新闻类图片的轮播效果
2017/02/05 Javascript
浅谈$_FILES数组为空的原因
2017/02/16 Javascript
微信小程序 点击控件后选中其它反选实例详解
2017/02/21 Javascript
jQuery插件echarts实现的单折线图效果示例【附demo源码下载】
2017/03/04 Javascript
AngularJS+bootstrap实现动态选择商品功能示例
2017/05/17 Javascript
微信小程序使用progress组件实现显示进度功能【附源码下载】
2017/12/12 Javascript
Python中easy_install 和 pip 的安装及使用
2017/06/05 Python
python做量化投资系列之比特币初始配置
2018/01/23 Python
python学生信息管理系统
2018/03/13 Python
Python爬虫 批量爬取下载抖音视频代码实例
2019/08/16 Python
python GUI库图形界面开发之PyQt5窗口控件QWidget详细使用方法
2020/02/26 Python
如何将PySpark导入Python的放实现(2种)
2020/04/26 Python
Python基于Twilio及腾讯云实现国际国内短信接口
2020/06/18 Python
利用Python中的Xpath实现一个在线汇率转换器
2020/09/09 Python
HTML5里的placeholder属性使用实例和美化显示效果的方法
2014/04/23 HTML / CSS
学校卫生检查制度
2014/02/03 职场文书
英语老师推荐信
2014/02/26 职场文书
解除合同协议书
2014/04/17 职场文书
合作协议书范文
2014/08/20 职场文书
党性分析自查总结
2014/10/14 职场文书
杭白菊导游词
2015/02/10 职场文书
2014年终个人总结报告
2015/03/09 职场文书
经费申请报告范文
2015/05/18 职场文书
会议承办单位欢迎词
2015/09/30 职场文书
python 实现两个变量值进行交换的n种操作
2021/06/02 Python