Pytorch转onnx、torchscript方式


Posted in Python onMay 25, 2020

前言

本文将介绍如何使用ONNX将PyTorch中训练好的模型(.pt、.pth)型转换为ONNX格式,然后将其加载到Caffe2中。需要安装好onnx和Caffe2。

PyTorch及ONNX环境准备

为了正常运行ONNX,我们需要安装最新的Pytorch,你可以选择源码安装:

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
mkdir build && cd build
sudo cmake .. -DPYTHON_INCLUDE_DIR=/usr/include/python3.6 -DUSE_MPI=OFF
make install
export PYTHONPATH=$PYTHONPATH:/opt/pytorch/build

其中 "/opt/pytorch/build"是前面build pytorch的目。

or conda安装

conda install pytorch torchvision -c pytorch

安装ONNX的库

conda install -c conda-forge onnx

onnx-caffe2 安装

pip3 install onnx-caffe2

Pytorch模型转onnx

在PyTorch中导出模型通过跟踪工作。要导出模型,请调用torch.onnx.export()函数。这将执行模型,记录运算符用于计算输出的轨迹。因为_export运行模型,我们需要提供输入张量x。

这个张量的值并不重要; 它可以是图像或随机张量,只要它是正确的大小。更多详细信息,请查看torch.onnx documentation文档。

# 输入模型
example = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

# 导出模型
torch_out = torch_out = torch.onnx.export(model, # model being run
    example, # model input (or a tuple for multiple inputs)
    "peleeNet.onnx",
 verbose=False, # store the trained parameter weights inside the model file
 training=False,
 do_constant_folding=True,
 input_names=['input'],
 output_names=['output'])

其中torch_out是执行模型后的输出,通常以忽略此输出。转换得到onnx后可以使用OpenCV的 cv::dnn::readNetFromONNX or cv::dnn::readNet进行模型加载推理了。

还可以进一步将onnx模型转换为ncnn进而部署到移动端。这就需要ncnn的onnx2ncnn工具了.

编译ncnn源码,生成 onnx2ncnn。

其中onnx转换模型时有一些冗余,可以使用用工具简化一些onnx模型。

pip3 install onnx-simplifier

简化onnx模型

python3 -m onnxsim pnet.onnx pnet-sim.onnx

转换成ncnn

onnx2ncnn pnet-sim.onnx pnet.param pnet.bin

ncnn 加载模型做推理

Pytorch模型转torch script

pytorch 加入libtorch前端处理,集体步骤为:

Pytorch转onnx、torchscript方式

以mtcnn pnet为例

# convert pytorch model to torch script
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 12, 12).to(device)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(pnet, example)
# Save traced model
traced_script_module.save("pnet_model_final.pth")

C++调用如下所示:

#include <torch/script.h> // One-stop header.
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) 
{
 if (argc != 2) 
 {
 std::cerr << "usage: example-app <path-to-exported-script-module>\n";
 return -1;
 }

 // Deserialize the ScriptModule from a file using torch::jit::load().
 std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);

 assert(module != nullptr);
 std::cout << "ok\n";
}
Python 相关文章推荐
详解Python中for循环的使用方法
May 14 Python
python的keyword模块用法实例分析
Jun 30 Python
Python实现查找匹配项作处理后再替换回去的方法
Jun 10 Python
pandas对指定列进行填充的方法
Apr 11 Python
浅谈python之高阶函数和匿名函数
Mar 21 Python
django页面跳转问题及注意事项
Jul 18 Python
由面试题加深对Django的认识理解
Jul 19 Python
VScode连接远程服务器上的jupyter notebook的实现
Apr 23 Python
tensorflow 大于某个值为1,小于为0的实例
Jun 30 Python
python 实现简易的记事本
Nov 30 Python
tensorflow2.0教程之Keras快速入门
Feb 20 Python
Python基础学习之奇异的GUI对话框
May 27 Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
基于pandas向csv添加新的行和列
May 25 #Python
Python如何把十进制数转换成ip地址
May 25 #Python
tensorflow模型转ncnn的操作方式
May 25 #Python
MxNet预训练模型到Pytorch模型的转换方式
May 25 #Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
You might like
星际争霸, 教主第一视角, ZvT经典龙蛇演义
2020/03/02 星际争霸
浅谈电磁辐射对健康的影响
2021/03/01 无线电
PHP过滤★等特殊符号的正则
2014/01/27 PHP
PHP同时连接多个mysql数据库示例代码
2014/03/17 PHP
php实现面包屑导航例子分享
2015/12/19 PHP
Zend Studio使用技巧两则
2016/04/01 PHP
PHP实现支付宝即时到账功能
2016/12/21 PHP
PHP设计模式之抽象工厂模式实例分析
2019/03/25 PHP
使用javascript实现json数据以csv格式下载
2015/01/09 Javascript
JavaScript Window浏览器对象模型方法与属性汇总
2015/04/20 Javascript
javascript实现连续赋值
2015/08/10 Javascript
jQuery焦点图轮播效果实现方法
2016/12/19 Javascript
js和jquery中获取非行间样式
2017/05/05 jQuery
jQuery AJAX 方法success()后台传来的4种数据详解
2018/08/08 jQuery
微信小程序实现单选功能
2018/10/30 Javascript
使用VueCli3+TypeScript+Vuex一步步构建todoList的方法
2019/07/25 Javascript
基于Vue CSR的微前端实现方案实践
2020/05/27 Javascript
Element Cascader 级联选择器的使用示例
2020/07/27 Javascript
原生js实现表格翻页和跳转
2020/09/29 Javascript
Python 文件操作技巧(File operation) 实例代码分析
2008/08/11 Python
Python2.7下安装Scrapy框架步骤教程
2017/12/22 Python
Python之csv文件从MySQL数据库导入导出的方法
2018/06/21 Python
Python模块、包(Package)概念与用法分析
2019/05/31 Python
python实战串口助手_解决8串口多个发送的问题
2019/06/12 Python
python二进制文件的转译详解
2019/07/03 Python
使用 Python 合并多个格式一致的 Excel 文件(推荐)
2019/12/09 Python
移动Web—CSS为Retina屏幕替换更高质量的图片
2012/12/24 HTML / CSS
canvas画图被放大且模糊的解决方法
2020/08/11 HTML / CSS
孕妇装中的著名品牌:Isabella Oliver(伊莎贝拉·奥利弗)
2016/10/31 全球购物
static全局变量与普通的全局变量有什么区别
2014/05/27 面试题
国际金融专业自荐信
2014/07/05 职场文书
开票员岗位职责
2015/02/12 职场文书
信息技术研修心得体会
2016/01/08 职场文书
投资入股协议书
2016/03/22 职场文书
实用求职信模板范文
2019/05/13 职场文书
导游词之茶卡盐湖
2019/11/26 职场文书