Pytorch加载部分预训练模型的参数实例


Posted in Python onAugust 18, 2019

前言

自从从深度学习框架caffe转到Pytorch之后,感觉Pytorch的优点妙不可言,各种设计简洁,方便研究网络结构修改,容易上手,比TensorFlow的臃肿好多了。对于深度学习的初学者,Pytorch值得推荐。今天主要主要谈谈Pytorch是如何加载预训练模型的参数以及代码的实现过程。

直接加载预选脸模型

如果我们使用的模型和预训练模型完全一样,那么我们就可以直接加载别人的模型,还有一种情况,我们在训练自己模型的过程中,突然中断了,但只要我们保存了之前的模型的参数也可以使用下面的代码直接加载我们保存的模型继续训练,不用从头开始。

model=DPN(*args, **kwargs)
model.load_state_dict(torch.load("DPN.pth"))

这样的加载方式是基于Pytorch使用的模型存储方法:

torch.save(DPN.state_dict(), "DPN.pth")

加载部分预训练模型参数

其实大多数时候我们根据自己的任物所提出的模型是在一些公开模型的基础上改变而来,其中公开模型的参数我们没有必要在从头开始训练,只要加载其训练好的模型参数即可,这样有助于提高训练的准确率和我们模型的泛化能力。

model = DPN(num_init_features=64, k_R=96, G=32, k_sec=(3,4,20,3), inc_sec=(16,32,24,128), num_classes=1,decoder=args.decoder)
 http = {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'}
 pretrained_dict=model_zoo.load_url(http['url'])
 model_dict = model.state_dict()
 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#filter out unnecessary keys 
 model_dict.update(pretrained_dict)
 model.load_state_dict(model_dict)
 model = torch.nn.DataParallel(model).cuda()

因为需要删除预训练模型中不匹配的的键,也就是层的名字。

以上这篇Pytorch加载部分预训练模型的参数实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python的即时标记项目练习笔记
Sep 18 Python
详解python中xlrd包的安装与处理Excel表格
Dec 16 Python
python 返回列表中某个值的索引方法
Nov 07 Python
django主动抛出403异常的方法详解
Jan 04 Python
基于python进行抽样分布描述及实践详解
Sep 02 Python
pyhton中__pycache__文件夹的产生与作用详解
Nov 24 Python
在Django下创建项目以及设置settings.py教程
Dec 03 Python
使用Python制作新型冠状病毒实时疫情图
Jan 28 Python
pyqt5 QlistView列表显示的实现示例
Mar 24 Python
Python爬虫中Selenium实现文件上传
Dec 04 Python
如何用Python编写一个电子考勤系统
Feb 08 Python
Python深度学习之Pytorch初步使用
May 20 Python
在pytorch中查看可训练参数的例子
Aug 18 #Python
浅析PyTorch中nn.Module的使用
Aug 18 #Python
关于PyTorch 自动求导机制详解
Aug 18 #Python
pytorch神经网络之卷积层与全连接层参数的设置方法
Aug 18 #Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
You might like
PHP设计模式之调解者模式的深入解析
2013/06/13 PHP
PHP字符串长度计算 - strlen()函数使用介绍
2013/10/15 PHP
Laravel框架数据库CURD操作、连贯操作总结
2014/09/03 PHP
php用正则判断是否为数字的方法
2016/03/25 PHP
php事件驱动化设计详解
2016/11/10 PHP
PHP十六进制颜色随机生成器功能示例
2017/07/24 PHP
PHP文件类型检查及fileinfo模块安装使用详解
2019/05/09 PHP
js中parseInt函数浅谈
2013/07/31 Javascript
jquery放大镜效果超漂亮噢
2013/11/15 Javascript
jQuery实现可用于博客的动态滑动菜单
2015/03/09 Javascript
BootStrap制作导航条实例代码
2016/05/06 Javascript
javascript url几种编码方式详解
2016/06/06 Javascript
AngularJS基础 ng-repeat 指令简单示例
2016/08/03 Javascript
深入理解Angularjs向指令传递数据双向绑定机制
2016/12/31 Javascript
微信小程序 数组中的push与concat的区别
2017/01/05 Javascript
BootStrap table删除指定行的注意事项(笔记整理)
2017/02/05 Javascript
JavaScript方法_动力节点Java学院整理
2017/06/28 Javascript
Vue项目中quill-editor带样式编辑器的使用方法
2017/08/08 Javascript
JS实现的简单标签点击切换功能示例
2017/09/21 Javascript
详解Vue路由History mode模式中页面无法渲染的原因及解决
2017/09/28 Javascript
es6 filter() 数组过滤方法总结
2019/04/03 Javascript
Vue拖拽组件列表实现动态页面配置功能
2019/06/17 Javascript
nodejs+express最简易的连接数据库的方法
2020/12/23 NodeJs
[04:22]DOTA2上海特级锦标赛主赛事第四日TOP10
2016/03/06 DOTA
零基础写python爬虫之爬虫的定义及URL构成
2014/11/04 Python
开始着手第一个Django项目
2015/07/15 Python
Python获取当前公网ip并自动断开宽带连接实例代码
2018/01/12 Python
Python 字符串操作(string替换、删除、截取、复制、连接、比较、查找、包含、大小写转换、分割等)
2018/03/19 Python
Python闭包和装饰器用法实例详解
2019/05/22 Python
keras load model时出现Missing Layer错误的解决方式
2020/06/11 Python
python入门:argparse浅析 nargs='+'作用
2020/07/12 Python
用OpenCV进行年龄和性别检测的实现示例
2021/01/29 Python
抽象方法、抽象类怎样声明
2014/10/25 面试题
Three.js实现雪糕地球的使用示例详解
2022/07/07 Javascript
Linux中一对多配置日志服务器的详细步骤
2022/07/23 Servers
使用JS前端技术实现静态图片局部流动效果
2022/08/05 Javascript