从Pytorch模型pth文件中读取参数成numpy矩阵的操作


Posted in Python onMarch 04, 2021

目的:

把训练好的pth模型参数提取出来,然后用其他方式部署到边缘设备。

Pytorch给了很方便的读取参数接口:

nn.Module.parameters()

直接看demo:

from torchvision.models.alexnet import alexnet 
model = alexnet(pretrained=True).eval().cuda()
parameters = model.parameters()
for p in parameters:
  numpy_para = p.detach().cpu().numpy()
  print(type(numpy_para))
  print(numpy_para.shape)

上面得到的numpy_para就是numpy参数了~

Note:

model.parameters()是以一个生成器的形式迭代返回每一层的参数。所以用for循环读取到各层的参数,循环次数就表示层数。

而每一层的参数都是torch.nn.parameter.Parameter类型,是Tensor的子类,所以直接用tensor转numpy(即p.detach().cpu().numpy())的方法就可以直接转成numpy矩阵。

方便又好用,爆赞~

补充:pytorch训练好的.pth模型转换为.pt

将python训练好的.pth文件转为.pt

import torch
import torchvision
from unet import UNet
model = UNet(3, 2)#自己定义的网络模型
model.load_state_dict(torch.load("best_weights.pth"))#保存的训练模型
model.eval()#切换到eval()
example = torch.rand(1, 3, 320, 480)#生成一个随机输入维度的输入
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python生成验证码图片代码分享
Jan 28 Python
Python二叉搜索树与双向链表转换实现方法
Apr 29 Python
python pandas中对Series数据进行轴向连接的实例
Jun 08 Python
flask入门之表单的实现
Jul 18 Python
pygame游戏之旅 调用按钮实现游戏开始功能
Nov 21 Python
Python操作mongodb数据库的方法详解
Dec 08 Python
Python 中Django安装和使用教程详解
Jul 03 Python
django框架中ajax的使用及避开CSRF 验证的方式详解
Dec 11 Python
Keras官方中文文档:性能评估Metrices详解
Jun 15 Python
Python实现手势识别
Oct 21 Python
python 下载m3u8视频的示例代码
Nov 11 Python
Python 流媒体播放器的实现(基于VLC)
Apr 28 Python
python 如何用urllib与服务端交互(发送和接收数据)
Mar 04 #Python
python 求两个向量的顺时针夹角操作
Mar 04 #Python
python 制作磁力搜索工具
Mar 04 #Python
python抢购软件/插件/脚本附完整源码
Mar 04 #Python
Python 求向量的余弦值操作
Mar 04 #Python
django使用多个数据库的方法实例
Mar 04 #Python
Python使用paramiko连接远程服务器执行Shell命令的实现
Mar 04 #Python
You might like
php变量作用域的深入解析
2013/06/03 PHP
jquery实现的元素的left增加N像素 鼠标移开会慢慢的移动到原来的位置
2010/03/21 Javascript
jQuery实现的Email中的收件人效果(按del键删除)
2011/03/20 Javascript
js中的时间转换—毫秒转换成日期时间的示例代码
2014/01/26 Javascript
jquery移除、绑定、触发元素事件使用示例详解
2014/04/10 Javascript
js实现无限级树形导航列表效果代码
2015/09/23 Javascript
javascript css红色经典选项卡效果实现代码
2016/05/17 Javascript
微信小程序 开发指南详解
2016/09/27 Javascript
js实现table添加行tr、删除行tr、清空行tr的简单实例
2016/10/15 Javascript
360doc网站不登录就无法复制内容的解决方法
2018/01/27 Javascript
一个Vue页面的内存泄露分析详解
2018/06/25 Javascript
Vue 组件注册实例详解
2019/02/23 Javascript
产制造追溯系统之通过微信小程序实现移动端报表平台
2019/06/03 Javascript
JavaScript实现公告栏上下滚动效果
2020/03/13 Javascript
javascript-hashchange事件和历史状态管理实例分析
2020/04/18 Javascript
jQuery+Ajax+js实现请求json格式数据并渲染到html页面操作示例
2020/06/02 jQuery
微信h5静默和非静默授权获取用户openId的方法和步骤
2020/06/08 Javascript
vue实现列表滚动的过渡动画
2020/06/29 Javascript
vue组件是如何解析及渲染的?
2021/01/13 Vue.js
[57:09]DOTA2-DPC中国联赛 正赛 Phoenix vs Dynasty BO3 第一场 1月26日
2021/03/11 DOTA
使用70行Python代码实现一个递归下降解析器的教程
2015/04/17 Python
Python设计模式之工厂模式简单示例
2018/01/09 Python
关于python写入文件自动换行的问题
2018/06/23 Python
基于numpy中数组元素的切片复制方法
2018/11/15 Python
Python3 SSH远程连接服务器的方法示例
2018/12/29 Python
Python中按值来获取指定的键
2019/03/04 Python
使用Pandas对数据进行筛选和排序的实现
2019/07/29 Python
python matplotlib中的subplot函数使用详解
2020/01/19 Python
浅谈pytorch中torch.max和F.softmax函数的维度解释
2020/06/28 Python
html5视频播放_动力节点Java学院整理
2017/07/13 HTML / CSS
全国法院系统开展党的群众路线教育实践活动综述(全文)
2014/10/25 职场文书
2015年团支部年度工作总结
2015/05/27 职场文书
2015大学党建带团建工作总结
2015/07/23 职场文书
nginx共享内存的机制详解
2022/03/21 Servers
分析SQL窗口函数之取值窗口函数
2022/04/21 Oracle
Nginx 常用配置
2022/05/15 Servers