PyTorch预训练的实现


Posted in Python onSeptember 18, 2019

前言

最近使用PyTorch感觉妙不可言,有种当初使用Keras的快感,而且速度还不慢。各种设计直接简洁,方便研究,比tensorflow的臃肿好多了。今天让我们来谈谈PyTorch的预训练,主要是自己写代码的经验以及论坛PyTorch Forums上的一些回答的总结整理。

直接加载预训练模型

如果我们使用的模型和原模型完全一样,那么我们可以直接加载别人训练好的模型:

my_resnet = MyResNet(*args, **kwargs)
my_resnet.load_state_dict(torch.load("my_resnet.pth"))

当然这样的加载方法是基于PyTorch推荐的存储模型的方法:

torch.save(my_resnet.state_dict(), "my_resnet.pth")

还有第二种加载方法:

my_resnet = torch.load("my_resnet.pth")

加载部分预训练模型

其实大多数时候我们需要根据我们的任务调节我们的模型,所以很难保证模型和公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。

pretrained_dict = model_zoo.load_url(model_urls['resnet152'])
model_dict = model.state_dict()
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)

因为需要剔除原模型中不匹配的键,也就是层的名字,所以我们的新模型改变了的层需要和原模型对应层的名字不一样,比如:resnet最后一层的名字是fc(PyTorch中),那么我们修改过的resnet的最后一层就不能取这个名字,可以叫fc_

微改基础模型预训练

对于改动比较大的模型,我们可能需要自己实现一下再加载别人的预训练参数。但是,对于一些基本模型PyTorch中已经有了,而且我只想进行一些小的改动那么怎么办呢?难道我又去实现一遍吗?当然不是。

我们首先看看怎么进行微改模型。

微改基础模型

PyTorch中的torchvision里已经有很多常用的模型了,可以直接调用:

  • AlexNet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
densenet = models.densenet_161()

但是对于我们的任务而言有些层并不是直接能用,需要我们微微改一下,比如,resnet最后的全连接层是分1000类,而我们只有21类;又比如,resnet第一层卷积接收的通道是3, 我们可能输入图片的通道是4,那么可以通过以下方法修改:

resnet.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
resnet.fc = nn.Linear(2048, 21)

简单预训练

模型已经改完了,接下来我们就进行简单预训练吧。
我们先从torchvision中调用基本模型,加载预训练模型,然后,重点来了,将其中的层直接替换为我们需要的层即可:

resnet = torchvision.models.resnet152(pretrained=True)
# 原本为1000类,改为10类
resnet.fc = torch.nn.Linear(2048, 10)

其中使用了pretrained参数,会直接加载预训练模型,内部实现和前文提到的加载预训练的方法一样。因为是先加载的预训练参数,相当于模型中已经有参数了,所以替换掉最后一层即可。OK!

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现Const详解
Jan 27 Python
python中list列表的高级函数
May 17 Python
详解python里使用正则表达式的全匹配功能
Oct 19 Python
numpy.delete删除一列或多列的方法
Apr 03 Python
如何优雅地改进Django中的模板碎片缓存详解
Jul 04 Python
Python3 列表,数组,矩阵的相互转换的方法示例
Aug 05 Python
Python 利用高德地图api实现经纬度与地址的批量转换
Aug 14 Python
Python3 实现爬取网站下所有URL方式
Jan 16 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 Python
Python 如何实现访问者模式
Jul 28 Python
python如何使用腾讯云发送短信
Sep 17 Python
scrapy-redis分布式爬虫的搭建过程(理论篇)
Sep 29 Python
用python实现英文字母和相应序数转换的方法
Sep 18 #Python
Django模板导入母版继承和自定义返回Html片段过程解析
Sep 18 #Python
Python爬虫图片懒加载技术 selenium和PhantomJS解析
Sep 18 #Python
python rsa实现数据加密和解密、签名加密和验签功能
Sep 18 #Python
决策树剪枝算法的python实现方法详解
Sep 18 #Python
python生成requirements.txt的两种方法
Sep 18 #Python
python2与python3爬虫中get与post对比解析
Sep 18 #Python
You might like
php 获取完整url地址
2008/12/20 PHP
php 将excel导入mysql
2009/11/09 PHP
PHP中的float类型使用说明
2010/07/27 PHP
PHP实现将视频转成MP4并获取视频预览图的方法
2015/03/12 PHP
javascript parseInt 函数分析(转)
2009/03/21 Javascript
javascript函数中的arguments参数
2010/08/01 Javascript
单击按钮显示隐藏子菜单经典案例
2013/01/04 Javascript
JavaScript声明变量时为什么要加var关键字
2014/09/29 Javascript
javascript原生ajax写法分享
2016/04/10 Javascript
深入理解jQuery中的事件冒泡
2016/05/24 Javascript
清空元素html("") innerHTML="" 与 empty()的区别和应用(推荐)
2017/08/14 Javascript
快速处理vue渲染前的显示问题
2018/03/05 Javascript
在 Vue-CLI 中引入 simple-mock实现简易的 API Mock 接口数据模拟
2018/11/28 Javascript
vue.js使用v-model实现表单元素(input) 双向数据绑定功能示例
2019/03/08 Javascript
vue+webpack dev本地调试全局样式引用失效的解决方案
2019/11/12 Javascript
微信小程序利用for循环解决内容变更问题
2020/03/05 Javascript
echarts实现获取datazoom的起始值(包括x轴和y轴)
2020/07/20 Javascript
如何通过vscode运行调试javascript代码
2020/07/24 Javascript
Jquery $.map使用方法实例详解
2020/09/01 jQuery
JavaScript arguments.callee作用及替换方案详解
2020/09/02 Javascript
JavaScript 常见的继承方式汇总
2020/09/17 Javascript
Python运算符重载用法实例
2015/05/28 Python
Django框架多表查询实例分析
2018/07/04 Python
对python借助百度云API对评论进行观点抽取的方法详解
2019/02/21 Python
Python从函数参数类型引出元组实例分析
2019/05/28 Python
在pytorch 中计算精度、回归率、F1 score等指标的实例
2020/01/18 Python
python 调整图片亮度的示例
2020/12/03 Python
HTML5之WebGL 3D概述(上)—WebGL原生开发开启网页3D渲染新时代
2013/01/31 HTML / CSS
阿提哈德航空官方网站:Etihad Airways
2017/01/06 全球购物
.TTL是什么?有什么用处,通常那些工具会用到它?(ping? traceroute? ifconfig? netstat?)
2016/05/09 面试题
会计毕业生自我鉴定
2013/11/04 职场文书
七一党建活动方案
2014/01/28 职场文书
出纳员岗位职责风险
2014/03/06 职场文书
涨价通知怎么写
2015/04/23 职场文书
小学数学教学随笔
2015/08/14 职场文书
2019年大学生职业生涯规划书最新范文
2019/03/25 职场文书