Pytorch转onnx、torchscript方式


Posted in Python onMay 25, 2020

前言

本文将介绍如何使用ONNX将PyTorch中训练好的模型(.pt、.pth)型转换为ONNX格式,然后将其加载到Caffe2中。需要安装好onnx和Caffe2。

PyTorch及ONNX环境准备

为了正常运行ONNX,我们需要安装最新的Pytorch,你可以选择源码安装:

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
mkdir build && cd build
sudo cmake .. -DPYTHON_INCLUDE_DIR=/usr/include/python3.6 -DUSE_MPI=OFF
make install
export PYTHONPATH=$PYTHONPATH:/opt/pytorch/build

其中 "/opt/pytorch/build"是前面build pytorch的目。

or conda安装

conda install pytorch torchvision -c pytorch

安装ONNX的库

conda install -c conda-forge onnx

onnx-caffe2 安装

pip3 install onnx-caffe2

Pytorch模型转onnx

在PyTorch中导出模型通过跟踪工作。要导出模型,请调用torch.onnx.export()函数。这将执行模型,记录运算符用于计算输出的轨迹。因为_export运行模型,我们需要提供输入张量x。

这个张量的值并不重要; 它可以是图像或随机张量,只要它是正确的大小。更多详细信息,请查看torch.onnx documentation文档。

# 输入模型
example = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

# 导出模型
torch_out = torch_out = torch.onnx.export(model, # model being run
    example, # model input (or a tuple for multiple inputs)
    "peleeNet.onnx",
 verbose=False, # store the trained parameter weights inside the model file
 training=False,
 do_constant_folding=True,
 input_names=['input'],
 output_names=['output'])

其中torch_out是执行模型后的输出,通常以忽略此输出。转换得到onnx后可以使用OpenCV的 cv::dnn::readNetFromONNX or cv::dnn::readNet进行模型加载推理了。

还可以进一步将onnx模型转换为ncnn进而部署到移动端。这就需要ncnn的onnx2ncnn工具了.

编译ncnn源码,生成 onnx2ncnn。

其中onnx转换模型时有一些冗余,可以使用用工具简化一些onnx模型。

pip3 install onnx-simplifier

简化onnx模型

python3 -m onnxsim pnet.onnx pnet-sim.onnx

转换成ncnn

onnx2ncnn pnet-sim.onnx pnet.param pnet.bin

ncnn 加载模型做推理

Pytorch模型转torch script

pytorch 加入libtorch前端处理,集体步骤为:

Pytorch转onnx、torchscript方式

以mtcnn pnet为例

# convert pytorch model to torch script
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 12, 12).to(device)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(pnet, example)
# Save traced model
traced_script_module.save("pnet_model_final.pth")

C++调用如下所示:

#include <torch/script.h> // One-stop header.
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) 
{
 if (argc != 2) 
 {
 std::cerr << "usage: example-app <path-to-exported-script-module>\n";
 return -1;
 }

 // Deserialize the ScriptModule from a file using torch::jit::load().
 std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);

 assert(module != nullptr);
 std::cout << "ok\n";
}
Python 相关文章推荐
python单链表实现代码实例
Nov 21 Python
Python每天必学之bytes字节
Jan 28 Python
使用Python通过win32 COM打开Excel并添加Sheet的方法
May 02 Python
python下PyGame的下载与安装过程及遇到问题
Aug 04 Python
python读写Excel表格的实例代码(简单实用)
Dec 19 Python
pytorch梯度剪裁方式
Feb 04 Python
Python decimal模块使用方法详解
Jun 08 Python
Django Auth用户认证组件实现代码
Oct 13 Python
python如何构建mock接口服务
Jan 28 Python
一文搞懂如何实现Go 超时控制
Mar 30 Python
Python与C++中梯度方向直方图的实现
Mar 17 Python
Python可视化神器pyecharts绘制水球图
Jul 07 Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
基于pandas向csv添加新的行和列
May 25 #Python
Python如何把十进制数转换成ip地址
May 25 #Python
tensorflow模型转ncnn的操作方式
May 25 #Python
MxNet预训练模型到Pytorch模型的转换方式
May 25 #Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
You might like
图书管理程序(二)
2006/10/09 PHP
简单谈谈PHP中的Reload操作
2016/12/12 PHP
PHP实现简单注册登录系统
2020/12/28 PHP
js 发个判断字符串是否为符合标准的函数
2009/04/27 Javascript
浅析javascript闭包 实例分析
2010/12/25 Javascript
基于jquery的3d效果实现代码
2011/03/23 Javascript
JavaScript之Getters和Setters 平台支持等详细介绍
2012/12/07 Javascript
jquery中交替点击事件toggle方法的使用示例
2013/12/08 Javascript
js实现同一页面多个运动效果的方法
2015/04/10 Javascript
贴近用户体验的Jquery日期、时间选择插件
2015/08/19 Javascript
浅析JavaScript访问对象属性和方法及区别
2015/11/16 Javascript
老生常谈 js中this的指向
2016/06/30 Javascript
jQuery向父辈遍历的简单方法
2016/09/18 Javascript
JS给Array添加是否包含字符串的简单方法
2016/10/29 Javascript
基于jQuery实现图片推拉门动画效果的两种方法
2017/08/26 jQuery
Vue下滚动到页面底部无限加载数据的示例代码
2018/04/22 Javascript
webpack实现一个行内样式px转vw的loader示例
2018/09/13 Javascript
NodeJs 文件系统操作模块fs使用方法详解
2018/11/26 NodeJs
vue基础之事件简写、事件对象、冒泡、默认行为、键盘事件实例分析
2019/03/11 Javascript
[01:08:09]DOTA2上海特级锦标赛主赛事日 - 1 胜者组第一轮#1Liquid VS Alliance第二局
2016/03/02 DOTA
python 全文检索引擎详解
2017/04/25 Python
Python基于辗转相除法求解最大公约数的方法示例
2018/04/04 Python
python交易记录链的实现过程详解
2019/07/03 Python
python pygame实现挡板弹球游戏
2019/11/25 Python
现代生活方式的家具和装饰:Dot & Bo
2018/12/26 全球购物
C语言开发工程师测试题
2016/12/20 面试题
如何在存储过程中使用Loop
2016/01/05 面试题
毕业自我评价
2014/02/05 职场文书
主管会计岗位职责
2014/03/13 职场文书
2015年个人现实表现材料
2014/12/10 职场文书
MBA推荐信怎么写
2015/03/25 职场文书
国家助学贷款承诺书
2015/04/30 职场文书
党内外群众意见范文
2015/06/02 职场文书
小学英语教师2015年度个人工作总结
2015/10/14 职场文书
Java中PriorityQueue实现最小堆和最大堆的用法
2021/06/27 Java/Android
简单且有用的Python数据分析和机器学习代码
2021/07/02 Python