从Pytorch模型pth文件中读取参数成numpy矩阵的操作


Posted in Python onMarch 04, 2021

目的:

把训练好的pth模型参数提取出来,然后用其他方式部署到边缘设备。

Pytorch给了很方便的读取参数接口:

nn.Module.parameters()

直接看demo:

from torchvision.models.alexnet import alexnet 
model = alexnet(pretrained=True).eval().cuda()
parameters = model.parameters()
for p in parameters:
  numpy_para = p.detach().cpu().numpy()
  print(type(numpy_para))
  print(numpy_para.shape)

上面得到的numpy_para就是numpy参数了~

Note:

model.parameters()是以一个生成器的形式迭代返回每一层的参数。所以用for循环读取到各层的参数,循环次数就表示层数。

而每一层的参数都是torch.nn.parameter.Parameter类型,是Tensor的子类,所以直接用tensor转numpy(即p.detach().cpu().numpy())的方法就可以直接转成numpy矩阵。

方便又好用,爆赞~

补充:pytorch训练好的.pth模型转换为.pt

将python训练好的.pth文件转为.pt

import torch
import torchvision
from unet import UNet
model = UNet(3, 2)#自己定义的网络模型
model.load_state_dict(torch.load("best_weights.pth"))#保存的训练模型
model.eval()#切换到eval()
example = torch.rand(1, 3, 320, 480)#生成一个随机输入维度的输入
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python 实现购物商城,含有用户入口和商家入口的示例
Sep 15 Python
Python栈算法的实现与简单应用示例
Nov 01 Python
python 筛选数据集中列中value长度大于20的数据集方法
Jun 14 Python
Python多继承原理与用法示例
Aug 23 Python
python列表list保留顺序去重的实例
Dec 14 Python
对python:threading.Thread类的使用方法详解
Jan 31 Python
详解python算法之冒泡排序
Mar 05 Python
python实现给微信指定好友定时发送消息
Apr 29 Python
python实现广度优先搜索过程解析
Oct 19 Python
keras打印loss对权重的导数方式
Jun 10 Python
解决keras加入lambda层时shape的问题
Jun 11 Python
python gui开发——制作抖音无水印视频下载工具(附源码)
Feb 07 Python
python 如何用urllib与服务端交互(发送和接收数据)
Mar 04 #Python
python 求两个向量的顺时针夹角操作
Mar 04 #Python
python 制作磁力搜索工具
Mar 04 #Python
python抢购软件/插件/脚本附完整源码
Mar 04 #Python
Python 求向量的余弦值操作
Mar 04 #Python
django使用多个数据库的方法实例
Mar 04 #Python
Python使用paramiko连接远程服务器执行Shell命令的实现
Mar 04 #Python
You might like
PHP-CGI进程CPU 100% 与 file_get_contents 函数的关系分析
2011/08/15 PHP
PHP_SELF,SCRIPT_NAME,REQUEST_URI区别
2014/12/24 PHP
再推荐十款免费的php开发工具
2015/11/09 PHP
php实现微信公众平台发红包功能
2018/06/14 PHP
php代码调试利器firephp安装与使用方法分析
2018/08/21 PHP
PHP抽象类和接口用法实例详解
2019/07/20 PHP
js实现双向链表互联网机顶盒实战应用实现
2011/10/28 Javascript
js正文内容高亮效果的实现方法
2013/06/30 Javascript
点击按钮自动加关注的代码(sina微博/QQ空间/人人网/腾讯微博)
2014/01/02 Javascript
js使用for循环及if语句判断多个一样的name
2014/09/09 Javascript
javascript随机显示背景图片的方法
2015/06/18 Javascript
JS实现兼容性好,带缓冲的动感网页右键菜单效果
2015/09/18 Javascript
JavaScript禁止复制与粘贴的实现代码
2016/05/16 Javascript
jQuery的Cookie封装,与PHP交互的简单实现
2016/10/05 Javascript
JS中事件冒泡和事件捕获介绍
2016/12/13 Javascript
vue-router路由懒加载和权限控制详解
2017/12/13 Javascript
jQuery序列化form表单数据为JSON对象的实现方法
2018/09/20 jQuery
js防抖和节流的深入讲解
2018/12/06 Javascript
JavaScript遍历查找数组中最大值与最小值的方法示例
2019/05/24 Javascript
[00:47]DOTA2荣耀之路6:玩不了啦!
2018/05/30 DOTA
Python3中条件控制、循环与函数的简易教程
2017/11/21 Python
对python 操作solr索引数据的实例详解
2018/12/07 Python
Python实现FTP弱口令扫描器的方法示例
2019/01/31 Python
用Python 执行cmd命令
2020/12/18 Python
基于HTML5 audio元素播放声音jQuery小插件
2011/05/11 HTML / CSS
STUBHUB日本:购买和出售全球活动门票
2018/07/01 全球购物
100%羊绒:NakedCashmere
2020/08/26 全球购物
Java 中访问数据库的步骤?Statement 和PreparedStatement 之间的区别?
2012/06/05 面试题
和平主题的演讲稿
2014/01/12 职场文书
《月亮湾》教学反思
2014/04/14 职场文书
社区学习雷锋活动总结
2014/04/25 职场文书
汉语言文学专业求职信
2014/06/19 职场文书
工作证明英文模板
2014/10/21 职场文书
行为规范主题班会
2015/08/13 职场文书
2019同学聚会主持词
2019/05/06 职场文书
《思路决定出路》读后感3篇
2019/12/11 职场文书