Pytorch转onnx、torchscript方式


Posted in Python onMay 25, 2020

前言

本文将介绍如何使用ONNX将PyTorch中训练好的模型(.pt、.pth)型转换为ONNX格式,然后将其加载到Caffe2中。需要安装好onnx和Caffe2。

PyTorch及ONNX环境准备

为了正常运行ONNX,我们需要安装最新的Pytorch,你可以选择源码安装:

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
mkdir build && cd build
sudo cmake .. -DPYTHON_INCLUDE_DIR=/usr/include/python3.6 -DUSE_MPI=OFF
make install
export PYTHONPATH=$PYTHONPATH:/opt/pytorch/build

其中 "/opt/pytorch/build"是前面build pytorch的目。

or conda安装

conda install pytorch torchvision -c pytorch

安装ONNX的库

conda install -c conda-forge onnx

onnx-caffe2 安装

pip3 install onnx-caffe2

Pytorch模型转onnx

在PyTorch中导出模型通过跟踪工作。要导出模型,请调用torch.onnx.export()函数。这将执行模型,记录运算符用于计算输出的轨迹。因为_export运行模型,我们需要提供输入张量x。

这个张量的值并不重要; 它可以是图像或随机张量,只要它是正确的大小。更多详细信息,请查看torch.onnx documentation文档。

# 输入模型
example = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

# 导出模型
torch_out = torch_out = torch.onnx.export(model, # model being run
    example, # model input (or a tuple for multiple inputs)
    "peleeNet.onnx",
 verbose=False, # store the trained parameter weights inside the model file
 training=False,
 do_constant_folding=True,
 input_names=['input'],
 output_names=['output'])

其中torch_out是执行模型后的输出,通常以忽略此输出。转换得到onnx后可以使用OpenCV的 cv::dnn::readNetFromONNX or cv::dnn::readNet进行模型加载推理了。

还可以进一步将onnx模型转换为ncnn进而部署到移动端。这就需要ncnn的onnx2ncnn工具了.

编译ncnn源码,生成 onnx2ncnn。

其中onnx转换模型时有一些冗余,可以使用用工具简化一些onnx模型。

pip3 install onnx-simplifier

简化onnx模型

python3 -m onnxsim pnet.onnx pnet-sim.onnx

转换成ncnn

onnx2ncnn pnet-sim.onnx pnet.param pnet.bin

ncnn 加载模型做推理

Pytorch模型转torch script

pytorch 加入libtorch前端处理,集体步骤为:

Pytorch转onnx、torchscript方式

以mtcnn pnet为例

# convert pytorch model to torch script
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 12, 12).to(device)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(pnet, example)
# Save traced model
traced_script_module.save("pnet_model_final.pth")

C++调用如下所示:

#include <torch/script.h> // One-stop header.
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) 
{
 if (argc != 2) 
 {
 std::cerr << "usage: example-app <path-to-exported-script-module>\n";
 return -1;
 }

 // Deserialize the ScriptModule from a file using torch::jit::load().
 std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);

 assert(module != nullptr);
 std::cout << "ok\n";
}
Python 相关文章推荐
使用Python中的greenlet包实现并发编程的入门教程
Apr 16 Python
Python基于identicon库创建类似Github上用的头像功能
Sep 25 Python
python编程之requests在网络请求中添加cookies参数方法详解
Oct 25 Python
Anaconda下安装mysql-python的包实例
Jun 11 Python
Python找出微信上删除你好友的人脚本写法
Nov 01 Python
Django生成PDF文档显示在网页上以及解决PDF中文显示乱码的问题
Jul 04 Python
Pandas 重塑(stack)和轴向旋转(pivot)的实现
Jul 22 Python
python中类的输出或类的实例输出为这种形式的原因
Aug 12 Python
Pycharm如何导入python文件及解决报错问题
May 10 Python
Python3爬虫里关于代理的设置总结
Jul 30 Python
Python pip install之SSL异常处理操作
Sep 03 Python
Python中递归以及递归遍历目录详解
Oct 24 Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
基于pandas向csv添加新的行和列
May 25 #Python
Python如何把十进制数转换成ip地址
May 25 #Python
tensorflow模型转ncnn的操作方式
May 25 #Python
MxNet预训练模型到Pytorch模型的转换方式
May 25 #Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
You might like
大师制作的中短波矿石收音机
2020/04/02 无线电
phpmyadmin 常用选项设置详解版
2010/03/07 PHP
rrmdir php中递归删除目录及目录下的文件
2011/05/15 PHP
php使用sql数据库 获取字段问题介绍
2013/08/12 PHP
使用Firebug对js进行断点调试的图文方法
2011/04/02 Javascript
JQGrid的用法解析(列编辑,添加行,删除行)
2013/11/08 Javascript
jquery1.9 下检测浏览器类型和版本的方法
2013/12/26 Javascript
js实现鼠标感应图片展示的方法
2015/02/27 Javascript
jQuery中extend函数详解
2015/07/13 Javascript
深入浅析AngularJS中的module(模块)
2016/01/04 Javascript
AngularJS入门(用ng-repeat指令实现循环输出
2016/05/05 Javascript
jquery日历插件e-calendar升级版
2016/11/10 Javascript
详解Node.js串行化流程控制
2017/05/04 Javascript
Avalonjs双向数据绑定与监听的实例代码
2017/06/23 Javascript
web前端页面生成exe可执行文件的方法
2018/02/08 Javascript
vue绑定事件后获取绑定事件中的this方法
2018/09/15 Javascript
详解React项目中碰到的IE问题
2019/03/14 Javascript
Vue-input框checkbox强制刷新问题
2019/04/18 Javascript
微信小程序 组件的外部样式externalClasses使用详解
2019/09/06 Javascript
如何通过javaScript去除字符串两端的空白字符
2020/02/06 Javascript
leaflet加载geojson叠加显示功能代码
2020/02/21 Javascript
Vue.js暴露方法给WebView的使用操作
2020/09/07 Javascript
openlayers实现地图测距测面
2020/09/25 Javascript
[04:11]2014DOTA2国际邀请赛 CIS遗憾出局梦想不灭
2014/07/09 DOTA
[01:20]辉夜杯背景故事宣传片《辉夜传说》
2015/12/25 DOTA
python正则表达式去掉数字中的逗号(python正则匹配逗号)
2013/12/25 Python
使用Python从零开始撸一个区块链
2018/03/14 Python
CentOS7下python3.7.0安装教程
2018/07/30 Python
python flask web服务实现更换默认端口和IP的方法
2019/07/26 Python
vscode 配置 python3开发环境的方法
2019/09/19 Python
无谷物狗粮:Pooch & Mutt
2018/05/23 全球购物
质检部岗位职责
2013/11/11 职场文书
五一劳动节演讲稿
2014/09/12 职场文书
教师四风对照检查材料思想汇报
2014/09/17 职场文书
公司总经理岗位职责
2015/04/01 职场文书
MySQL 数据表操作
2022/05/04 MySQL