Pytorch转onnx、torchscript方式


Posted in Python onMay 25, 2020

前言

本文将介绍如何使用ONNX将PyTorch中训练好的模型(.pt、.pth)型转换为ONNX格式,然后将其加载到Caffe2中。需要安装好onnx和Caffe2。

PyTorch及ONNX环境准备

为了正常运行ONNX,我们需要安装最新的Pytorch,你可以选择源码安装:

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
mkdir build && cd build
sudo cmake .. -DPYTHON_INCLUDE_DIR=/usr/include/python3.6 -DUSE_MPI=OFF
make install
export PYTHONPATH=$PYTHONPATH:/opt/pytorch/build

其中 "/opt/pytorch/build"是前面build pytorch的目。

or conda安装

conda install pytorch torchvision -c pytorch

安装ONNX的库

conda install -c conda-forge onnx

onnx-caffe2 安装

pip3 install onnx-caffe2

Pytorch模型转onnx

在PyTorch中导出模型通过跟踪工作。要导出模型,请调用torch.onnx.export()函数。这将执行模型,记录运算符用于计算输出的轨迹。因为_export运行模型,我们需要提供输入张量x。

这个张量的值并不重要; 它可以是图像或随机张量,只要它是正确的大小。更多详细信息,请查看torch.onnx documentation文档。

# 输入模型
example = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

# 导出模型
torch_out = torch_out = torch.onnx.export(model, # model being run
    example, # model input (or a tuple for multiple inputs)
    "peleeNet.onnx",
 verbose=False, # store the trained parameter weights inside the model file
 training=False,
 do_constant_folding=True,
 input_names=['input'],
 output_names=['output'])

其中torch_out是执行模型后的输出,通常以忽略此输出。转换得到onnx后可以使用OpenCV的 cv::dnn::readNetFromONNX or cv::dnn::readNet进行模型加载推理了。

还可以进一步将onnx模型转换为ncnn进而部署到移动端。这就需要ncnn的onnx2ncnn工具了.

编译ncnn源码,生成 onnx2ncnn。

其中onnx转换模型时有一些冗余,可以使用用工具简化一些onnx模型。

pip3 install onnx-simplifier

简化onnx模型

python3 -m onnxsim pnet.onnx pnet-sim.onnx

转换成ncnn

onnx2ncnn pnet-sim.onnx pnet.param pnet.bin

ncnn 加载模型做推理

Pytorch模型转torch script

pytorch 加入libtorch前端处理,集体步骤为:

Pytorch转onnx、torchscript方式

以mtcnn pnet为例

# convert pytorch model to torch script
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 12, 12).to(device)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(pnet, example)
# Save traced model
traced_script_module.save("pnet_model_final.pth")

C++调用如下所示:

#include <torch/script.h> // One-stop header.
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) 
{
 if (argc != 2) 
 {
 std::cerr << "usage: example-app <path-to-exported-script-module>\n";
 return -1;
 }

 // Deserialize the ScriptModule from a file using torch::jit::load().
 std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);

 assert(module != nullptr);
 std::cout << "ok\n";
}
Python 相关文章推荐
Python中文件I/O高效操作处理的技巧分享
Feb 04 Python
Python实现树莓派WiFi断线自动重连的实例代码
Mar 16 Python
Python中正则表达式详解
May 17 Python
TensorFlow实现iris数据集线性回归
Sep 07 Python
python添加模块搜索路径和包的导入方法
Jan 19 Python
Python split() 函数拆分字符串将字符串转化为列的方法
Jul 16 Python
利用python实现短信和电话提醒功能的例子
Aug 08 Python
django处理select下拉表单实例(从model到前端到post到form)
Mar 13 Python
jupyter notebook 多行输出实例
Apr 09 Python
Selenium关闭INFO:CONSOLE提示的解决
Dec 07 Python
Python机器学习工具scikit-learn的使用笔记
Jan 28 Python
python 批量压缩图片的脚本
Jun 02 Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
基于pandas向csv添加新的行和列
May 25 #Python
Python如何把十进制数转换成ip地址
May 25 #Python
tensorflow模型转ncnn的操作方式
May 25 #Python
MxNet预训练模型到Pytorch模型的转换方式
May 25 #Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
You might like
php基于curl实现随机ip地址抓取内容的方法
2016/10/11 PHP
PHP调用API接口实现天气查询功能的示例
2017/09/21 PHP
tp5.1 实现setInc字段自动加1
2019/10/18 PHP
JavaScript 编程引入命名空间的方法与代码
2007/08/13 Javascript
jquery ajax 检测用户注册时用户名是否存在
2009/11/03 Javascript
Javascript 面向对象(一)(共有方法,私有方法,特权方法)
2012/05/23 Javascript
在Javascript中 声明时用&quot;var&quot;与不用&quot;var&quot;的区别
2013/04/15 Javascript
javascript获取鼠标点击元素对象(示例代码)
2013/12/20 Javascript
jQuery获取checkboxlist的value值的方法
2015/09/27 Javascript
jQuery中hover与mouseover和mouseout的区别分析
2015/12/24 Javascript
JS中闭包的经典用法小结(2则示例)
2016/12/28 Javascript
JavaScript编写一个贪吃蛇游戏
2017/03/09 Javascript
JS 组件系列之BootstrapTable的treegrid功能
2017/06/16 Javascript
jq源码解析之绑在$,jQuery上面的方法(实例讲解)
2017/10/13 jQuery
仿淘宝JSsearch搜索下拉深度用法
2018/01/15 Javascript
在小程序中使用Echart图表的示例代码
2018/08/02 Javascript
微信小程序实现的五星评价功能示例
2019/04/25 Javascript
vue项目中引入Sass实例方法
2019/08/27 Javascript
微信小程序登录时如何获取input框中的内容
2019/12/04 Javascript
[52:26]完美世界DOTA2联赛决赛 FTD vs Phoenix 第一场 11.08
2020/11/11 DOTA
Python 冒泡,选择,插入排序使用实例
2015/02/05 Python
Python科学画图代码分享
2017/11/29 Python
解决Pandas to_json()中文乱码,转化为json数组的问题
2018/05/10 Python
Python爬虫 批量爬取下载抖音视频代码实例
2019/08/16 Python
django 通过url实现简单的权限控制的例子
2019/08/16 Python
python pandas利用fillna方法实现部分自动填充功能
2020/03/16 Python
手把手教你安装Windows版本的Tensorflow
2020/03/26 Python
Ellos丹麦:时尚和服装在线
2016/09/19 全球购物
乐天旅游香港网站:日本饭店预订
2017/11/29 全球购物
美国排名第一的泳池用品直接来源:In The Swim
2019/09/23 全球购物
能否解释一下XSS cookie盗窃是什么意思
2012/06/02 面试题
应届行政管理专业个人自我评价
2013/12/28 职场文书
篝火晚会策划方案
2014/05/16 职场文书
法人委托书范本
2014/09/15 职场文书
2016年“我们的节日·清明节”活动总结
2016/04/01 职场文书
MySQL系列之十 MySQL事务隔离实现并发控制
2021/07/02 MySQL