Pytorch加载部分预训练模型的参数实例


Posted in Python onAugust 18, 2019

前言

自从从深度学习框架caffe转到Pytorch之后,感觉Pytorch的优点妙不可言,各种设计简洁,方便研究网络结构修改,容易上手,比TensorFlow的臃肿好多了。对于深度学习的初学者,Pytorch值得推荐。今天主要主要谈谈Pytorch是如何加载预训练模型的参数以及代码的实现过程。

直接加载预选脸模型

如果我们使用的模型和预训练模型完全一样,那么我们就可以直接加载别人的模型,还有一种情况,我们在训练自己模型的过程中,突然中断了,但只要我们保存了之前的模型的参数也可以使用下面的代码直接加载我们保存的模型继续训练,不用从头开始。

model=DPN(*args, **kwargs)
model.load_state_dict(torch.load("DPN.pth"))

这样的加载方式是基于Pytorch使用的模型存储方法:

torch.save(DPN.state_dict(), "DPN.pth")

加载部分预训练模型参数

其实大多数时候我们根据自己的任物所提出的模型是在一些公开模型的基础上改变而来,其中公开模型的参数我们没有必要在从头开始训练,只要加载其训练好的模型参数即可,这样有助于提高训练的准确率和我们模型的泛化能力。

model = DPN(num_init_features=64, k_R=96, G=32, k_sec=(3,4,20,3), inc_sec=(16,32,24,128), num_classes=1,decoder=args.decoder)
 http = {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'}
 pretrained_dict=model_zoo.load_url(http['url'])
 model_dict = model.state_dict()
 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#filter out unnecessary keys 
 model_dict.update(pretrained_dict)
 model.load_state_dict(model_dict)
 model = torch.nn.DataParallel(model).cuda()

因为需要删除预训练模型中不匹配的的键,也就是层的名字。

以上这篇Pytorch加载部分预训练模型的参数实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现从字符串中找出字符1的位置以及个数的方法
Aug 25 Python
Python3写入文件常用方法实例分析
May 22 Python
简单的python后台管理程序
Apr 13 Python
全面分析Python的优点和缺点
Feb 07 Python
详谈python中冒号与逗号的区别
Apr 18 Python
python3使用SMTP发送HTML格式邮件
Jun 19 Python
python 调用钉钉机器人的方法
Feb 20 Python
Flask框架学习笔记之表单基础介绍与表单提交方式
Aug 12 Python
使用PyInstaller将Pygame库编写的小游戏程序打包为exe文件及出现问题解决方法
Sep 06 Python
pycharm显示远程图片的实现
Nov 04 Python
Python中os模块功能与用法详解
Feb 26 Python
关于的python五子棋的算法
May 02 Python
在pytorch中查看可训练参数的例子
Aug 18 #Python
浅析PyTorch中nn.Module的使用
Aug 18 #Python
关于PyTorch 自动求导机制详解
Aug 18 #Python
pytorch神经网络之卷积层与全连接层参数的设置方法
Aug 18 #Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
You might like
十大催泪虐心动漫,你能坚持看到第几部?
2020/03/04 日漫
PHP判断指定时间段的2个方法
2014/03/14 PHP
PHP常用操作类之通信数据封装类的实现
2017/07/16 PHP
PHP基于phpqrcode类生成二维码的方法示例详解
2020/08/07 PHP
php中yar框架实例用法讲解
2020/12/27 PHP
js 跨域和ajax 跨域问题小结
2009/07/01 Javascript
24款非常有用的 jQuery 插件分享
2011/04/06 Javascript
js验证是否为数字的总结
2013/04/14 Javascript
Jquery 切换不同图片示例代码
2013/12/05 Javascript
AngularJS ng-template寄宿方式用法分析
2016/11/07 Javascript
微信小程序 基础组件与导航组件详细介绍
2017/02/21 Javascript
js中的触发事件对象event.srcElement与event.target详解
2017/03/15 Javascript
Vue2 使用 Echarts 创建图表实例代码
2017/05/18 Javascript
ReactJS实现表单的单选多选和反选的示例
2017/10/13 Javascript
Vue.js的复用组件开发流程完整记录
2018/11/29 Javascript
JS绘图Flot应用图形绘制异常解决方案
2020/10/16 Javascript
详解Python 实现元胞自动机中的生命游戏(Game of life)
2018/01/27 Python
python+pandas生成指定日期和重采样的方法
2018/04/11 Python
python实现微信机器人: 登录微信、消息接收、自动回复功能
2019/04/29 Python
Python BeautifulSoup [解决方法] TypeError: list indices must be integers or slices, not str
2019/08/07 Python
python requests证书问题解决
2019/09/05 Python
Python如何使用Gitlab API实现批量的合并分支
2019/11/27 Python
python环境下安装opencv库的方法
2020/03/05 Python
Python with语句用法原理详解
2020/07/03 Python
英国巧克力贸易公司:Chocolate Trading Company
2017/03/21 全球购物
远东集团网络工程师面试题
2014/10/20 面试题
岗位职责说明书
2014/05/07 职场文书
品牌服务方案
2014/06/03 职场文书
营销学习心得体会
2014/09/12 职场文书
社保代办委托书怎么写
2014/10/06 职场文书
事业单位聘任报告
2015/03/02 职场文书
房屋租赁意向书范本
2015/05/09 职场文书
vue中 this.$set的使用详解
2021/11/17 Vue.js
SQL Server数据库查询出现阻塞之性能调优
2022/04/10 SQL Server
阿里云ECS云服务器快照的概念以及如何使用
2022/04/21 Servers
MySQL选择合适的备份策略和备份工具
2022/06/01 MySQL