Pytorch加载部分预训练模型的参数实例


Posted in Python onAugust 18, 2019

前言

自从从深度学习框架caffe转到Pytorch之后,感觉Pytorch的优点妙不可言,各种设计简洁,方便研究网络结构修改,容易上手,比TensorFlow的臃肿好多了。对于深度学习的初学者,Pytorch值得推荐。今天主要主要谈谈Pytorch是如何加载预训练模型的参数以及代码的实现过程。

直接加载预选脸模型

如果我们使用的模型和预训练模型完全一样,那么我们就可以直接加载别人的模型,还有一种情况,我们在训练自己模型的过程中,突然中断了,但只要我们保存了之前的模型的参数也可以使用下面的代码直接加载我们保存的模型继续训练,不用从头开始。

model=DPN(*args, **kwargs)
model.load_state_dict(torch.load("DPN.pth"))

这样的加载方式是基于Pytorch使用的模型存储方法:

torch.save(DPN.state_dict(), "DPN.pth")

加载部分预训练模型参数

其实大多数时候我们根据自己的任物所提出的模型是在一些公开模型的基础上改变而来,其中公开模型的参数我们没有必要在从头开始训练,只要加载其训练好的模型参数即可,这样有助于提高训练的准确率和我们模型的泛化能力。

model = DPN(num_init_features=64, k_R=96, G=32, k_sec=(3,4,20,3), inc_sec=(16,32,24,128), num_classes=1,decoder=args.decoder)
 http = {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'}
 pretrained_dict=model_zoo.load_url(http['url'])
 model_dict = model.state_dict()
 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#filter out unnecessary keys 
 model_dict.update(pretrained_dict)
 model.load_state_dict(model_dict)
 model = torch.nn.DataParallel(model).cuda()

因为需要删除预训练模型中不匹配的的键,也就是层的名字。

以上这篇Pytorch加载部分预训练模型的参数实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python模块学习 re 正则表达式
May 19 Python
Python Web框架Flask下网站开发入门实例
Feb 08 Python
python实现简单ftp客户端的方法
Jun 28 Python
python3批量删除豆瓣分组下的好友的实现代码
Jun 07 Python
python如何实现内容写在图片上
Mar 23 Python
pycharm远程开发项目的实现步骤
Jan 20 Python
Python 实现取多维数组第n维的前几位
Nov 26 Python
Python高级特性——详解多维数组切片(Slice)
Nov 26 Python
python3.7通过thrift操作hbase的示例代码
Jan 14 Python
Python开发之身份证验证库id_validator验证身份证号合法性及根据身份证号返回住址年龄等信息
Mar 20 Python
Python 如何测试文件是否存在
Jul 31 Python
python中@property的作用和getter setter的解释
Dec 22 Python
在pytorch中查看可训练参数的例子
Aug 18 #Python
浅析PyTorch中nn.Module的使用
Aug 18 #Python
关于PyTorch 自动求导机制详解
Aug 18 #Python
pytorch神经网络之卷积层与全连接层参数的设置方法
Aug 18 #Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
You might like
十天学会php之第六天
2006/10/09 PHP
Dedecms V3.1 生成HTML速度的优化办法
2007/03/18 PHP
PHP curl 抓取AJAX异步内容示例
2014/09/09 PHP
php实现的支持imagemagick及gd库两种处理的缩略图生成类
2014/09/23 PHP
php通过Chianz.com获取IP地址与地区的方法
2015/01/14 PHP
Laravel 5 框架入门(三)
2015/04/09 PHP
PHP批量获取网页中所有固定种子链接的方法
2016/11/18 PHP
PHP mysqli事务操作常用方法分析
2017/07/22 PHP
ThinkPHP类似AOP思想的参数验证的实现方法
2019/12/18 PHP
用JavaScript显示随机图像或引用
2009/04/21 Javascript
JQuery+JS实现仿百度搜索结果中关键字变色效果
2011/08/02 Javascript
javascript实现页面内关键词高亮显示代码
2014/04/03 Javascript
使用JavaScript+canvas实现图片裁剪
2015/01/30 Javascript
浅析jquery数组删除指定元素的方法:grep()
2016/05/19 Javascript
浅谈JavaScript对象与继承
2016/07/10 Javascript
jQuery UI制作选项卡(tabs)
2016/12/13 Javascript
Vue.js中轻松解决v-for执行出错的三个方案
2017/06/09 Javascript
Vue.js 的移动端组件库mint-ui实现无限滚动加载更多的方法
2017/12/23 Javascript
[06:44]2014DOTA2国际邀请赛-钥匙体育馆开战 开幕式振奋人心
2014/07/19 DOTA
[01:13:51]TNC vs Serenity 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
python启动办公软件进程(word、excel、ppt、以及wps的et、wps、wpp)
2009/04/09 Python
Python挑选文件夹里宽大于300图片的方法
2015/03/05 Python
浅谈Python中range和xrange的区别
2017/12/20 Python
cmd运行python文件时对结果进行保存的方法
2018/05/16 Python
python 求一个列表中所有元素的乘积实例
2019/06/11 Python
在python中实现同行输入/接收多个数据的示例
2019/07/20 Python
使用pyinstaller逆向.pyc文件
2019/12/20 Python
使用PyTorch将文件夹下的图片分为训练集和验证集实例
2020/01/08 Python
css3 border-radius属性详解
2017/07/05 HTML / CSS
使用phonegap创建联系人的实现方法
2017/03/30 HTML / CSS
澳大利亚百货公司:David Jones
2018/02/08 全球购物
暑期培训随笔感言
2014/03/10 职场文书
先进事迹材料范文
2014/12/29 职场文书
思想工作总结范文
2015/08/12 职场文书
mysql如何查询连续记录
2022/05/11 MySQL
LeetCode189轮转数组python示例
2022/08/05 Python