python PyTorch预训练示例


Posted in Python onFebruary 11, 2018

前言

最近使用PyTorch感觉妙不可言,有种当初使用Keras的快感,而且速度还不慢。各种设计直接简洁,方便研究,比tensorflow的臃肿好多了。今天让我们来谈谈PyTorch的预训练,主要是自己写代码的经验以及论坛PyTorch Forums上的一些回答的总结整理。

直接加载预训练模型

如果我们使用的模型和原模型完全一样,那么我们可以直接加载别人训练好的模型:

my_resnet = MyResNet(*args, **kwargs)
my_resnet.load_state_dict(torch.load("my_resnet.pth"))

当然这样的加载方法是基于PyTorch推荐的存储模型的方法:

torch.save(my_resnet.state_dict(), "my_resnet.pth")

还有第二种加载方法:

my_resnet = torch.load("my_resnet.pth")

加载部分预训练模型

其实大多数时候我们需要根据我们的任务调节我们的模型,所以很难保证模型和公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。

pretrained_dict = model_zoo.load_url(model_urls['resnet152'])
model_dict = model.state_dict()
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)

因为需要剔除原模型中不匹配的键,也就是层的名字,所以我们的新模型改变了的层需要和原模型对应层的名字不一样,比如:resnet最后一层的名字是fc(PyTorch中),那么我们修改过的resnet的最后一层就不能取这个名字,可以叫fc_

微改基础模型预训练

对于改动比较大的模型,我们可能需要自己实现一下再加载别人的预训练参数。但是,对于一些基本模型PyTorch中已经有了,而且我只想进行一些小的改动那么怎么办呢?难道我又去实现一遍吗?当然不是。

我们首先看看怎么进行微改模型。

微改基础模型

PyTorch中的torchvision里已经有很多常用的模型了,可以直接调用:

  1. AlexNet
  2. VGG
  3. ResNet
  4. SqueezeNet
  5. DenseNet
import torchvision.models as models

resnet18 = models.resnet18()
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
densenet = models.densenet_161()

但是对于我们的任务而言有些层并不是直接能用,需要我们微微改一下,比如,resnet最后的全连接层是分1000类,而我们只有21类;又比如,resnet第一层卷积接收的通道是3, 我们可能输入图片的通道是4,那么可以通过以下方法修改:

resnet.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
resnet.fc = nn.Linear(2048, 21)

简单预训练

模型已经改完了,接下来我们就进行简单预训练吧。

我们先从torchvision中调用基本模型,加载预训练模型,然后,重点来了,将其中的层直接替换为我们需要的层即可:

resnet = torchvision.models.resnet152(pretrained=True)
# 原本为1000类,改为10类
resnet.fc = torch.nn.Linear(2048, 10)

其中使用了pretrained参数,会直接加载预训练模型,内部实现和前文提到的加载预训练的方法一样。因为是先加载的预训练参数,相当于模型中已经有参数了,所以替换掉最后一层即可。OK!

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中变量交换的例子
Aug 25 Python
Python获取Windows或Linux主机名称通用函数分享
Nov 22 Python
剖析Django中模版标签的解析与参数传递
Jul 21 Python
python与sqlite3实现解密chrome cookie实例代码
Jan 20 Python
TensorFlow入门使用 tf.train.Saver()保存模型
Apr 24 Python
python实现守护进程、守护线程、守护非守护并行
May 05 Python
对Python中for复合语句的使用示例讲解
Nov 01 Python
python接口调用已训练好的caffe模型测试分类方法
Aug 26 Python
如何基于python操作excel并获取内容
Dec 24 Python
Python新手学习函数默认参数设置
Jun 03 Python
Pytorch实验常用代码段汇总
Nov 19 Python
Python机器学习实战之k-近邻算法的实现
Nov 27 Python
TensorFlow中权重的随机初始化的方法
Feb 11 #Python
python的staticmethod与classmethod实现实例代码
Feb 11 #Python
Python语言的变量认识及操作方法
Feb 11 #Python
利用Opencv中Houghline方法实现直线检测
Feb 11 #Python
tensorflow输出权重值和偏差的方法
Feb 10 #Python
详解tensorflow实现迁移学习实例
Feb 10 #Python
Python学习之Django的管理界面代码示例
Feb 10 #Python
You might like
一个MYSQL操作类
2006/11/16 PHP
ThinkPHP 3.2 数据分页代码分享
2014/10/14 PHP
PHP实现扎金花游戏之大小比赛的方法
2015/03/10 PHP
php生成4位数字验证码的实现代码
2015/11/23 PHP
PHP实现图片不变型裁剪及图片按比例裁剪的方法
2016/01/14 PHP
在PHP语言中使用JSON和将json还原成数组的方法
2016/07/19 PHP
通用于ie和firefox的函数 GetCurrentStyle (obj, prop)
2006/12/27 Javascript
基于jQuery实现文本框缩放以及上下移动功能
2014/11/24 Javascript
JavaScript基础函数整理汇总
2015/01/30 Javascript
JS实现DIV容器赋值的方法
2015/12/14 Javascript
理解js对象继承的N种模式
2016/01/25 Javascript
详解nodejs微信公众号开发——2.自动回复
2017/04/10 NodeJs
vue项目base64字符串转图片的实现代码
2018/07/13 Javascript
微信小程序上传文件到阿里OSS教程
2019/05/20 Javascript
javascript使用substring实现的展开与收缩文字功能示例
2019/06/17 Javascript
JS中async/await实现异步调用的方法
2019/08/28 Javascript
js实现从右往左匀速显示图片(无缝轮播)
2020/06/29 Javascript
vuex的使用步骤
2021/01/06 Vue.js
Python实现扫描局域网活动ip(扫描在线电脑)
2015/04/28 Python
Python中Continue语句的用法的举例详解
2015/05/14 Python
python实现域名系统(DNS)正向查询的方法
2016/04/19 Python
python实现单向链表详解
2018/02/08 Python
详解python实现识别手写MNIST数字集的程序
2018/08/03 Python
详解Python中的内建函数,可迭代对象,迭代器
2019/04/29 Python
python实现读取excel文件中所有sheet操作示例
2019/08/09 Python
Python数据可视化:幂律分布实例详解
2019/12/07 Python
美国和加拿大房车出售在线分类广告:RVT.com
2018/04/23 全球购物
专业幼师实习生自我鉴定范文
2013/12/08 职场文书
致铅球运动员加油稿
2014/02/13 职场文书
陈欧广告词
2014/03/14 职场文书
建议书的格式
2014/05/12 职场文书
法院执行局工作总结
2015/08/11 职场文书
导游词之太行山青龙峡
2020/01/14 职场文书
windows下快速安装nginx并配置开机自启动的方法
2021/05/11 Servers
Dubbo+zookeeper搭配分布式服务的过程详解
2022/04/03 Java/Android
开发者首先否认《遗弃》被取消的传言
2022/04/11 其他游戏