Pytorch通过保存为ONNX模型转TensorRT5的实现


Posted in Python onMay 25, 2020

1 Pytorch以ONNX方式保存模型

def saveONNX(model, filepath):
  '''
  保存ONNX模型
  :param model: 神经网络模型
  :param filepath: 文件保存路径
  '''
  
  # 神经网络输入数据类型
  dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
  torch.onnx.export(model, dummy_input, filepath, verbose=True)

2 利用TensorRT5中ONNX解析器构建Engine

def ONNX_build_engine(onnx_file_path):
  '''
  通过加载onnx文件,构建engine
  :param onnx_file_path: onnx文件路径
  :return: engine
  '''
  # 打印日志
  G_LOGGER = trt.Logger(trt.Logger.WARNING)

  with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
   builder.max_batch_size = 100
   builder.max_workspace_size = 1 << 20

   print('Loading ONNX file from path {}...'.format(onnx_file_path))
   with open(onnx_file_path, 'rb') as model:
    print('Beginning ONNX file parsing')
    parser.parse(model.read())
   print('Completed parsing of ONNX file')

   print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
   engine = builder.build_cuda_engine(network)
   print("Completed creating Engine")

   # 保存计划文件
   # with open(engine_file_path, "wb") as f:
   #  f.write(engine.serialize())
   return engine

3 构建TensorRT运行引擎进行预测

def loadONNX2TensorRT(filepath):
  '''
  通过onnx文件,构建TensorRT运行引擎
  :param filepath: onnx文件路径
  '''
  # 计算开始时间
  Start = time()

  engine = self.ONNX_build_engine(filepath)

  # 读取测试集
  datas = DataLoaders()
  test_loader = datas.testDataLoader()
  img, target = next(iter(test_loader))
  img = img.numpy()
  target = target.numpy()

  img = img.ravel()

  context = engine.create_execution_context()
  output = np.empty((100, 10), dtype=np.float32)

  # 分配内存
  d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
  d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
  bindings = [int(d_input), int(d_output)]

  # pycuda操作缓冲区
  stream = cuda.Stream()
  # 将输入数据放入device
  cuda.memcpy_htod_async(d_input, img, stream)
  # 执行模型
  context.execute_async(100, bindings, stream.handle, None)
  # 将预测结果从从缓冲区取出
  cuda.memcpy_dtoh_async(output, d_output, stream)
  # 线程同步
  stream.synchronize()

  print("Test Case: " + str(target))
  print("Prediction: " + str(np.argmax(output, axis=1)))
  print("tensorrt time:", time() - Start)

  del context
  del engine

补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT

近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。

后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。

是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。

以上这篇Pytorch通过保存为ONNX模型转TensorRT5的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈python中截取字符函数strip,lstrip,rstrip
Jul 17 Python
浅谈用VSCode写python的正确姿势
Dec 16 Python
PyQt5实现拖放功能
Apr 25 Python
python实现俄罗斯方块
Jun 26 Python
PyTorch 1.0 正式版已经发布了
Dec 13 Python
分享Python切分字符串的一个不错方法
Dec 14 Python
Python爬虫设置代理IP(图文)
Dec 23 Python
opencv之为图像添加边界的方法示例
Dec 26 Python
python3 xpath和requests应用详解
Mar 06 Python
python中return如何写
Jun 18 Python
如何表示python中的相对路径
Jul 08 Python
Pycharm Available Package无法显示/安装包的问题Error Loading Package List解决
Sep 18 Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
You might like
php轻松实现中英文混排字符串截取
2014/05/28 PHP
Parse正式发布开源PHP SDK
2014/08/11 PHP
php动态变量定义及使用
2015/06/10 PHP
PHP实现QQ登录实例代码
2016/01/14 PHP
php常用图片处理类
2016/03/16 PHP
PHP入门教程之面向对象的特性分析(继承,多态,接口,抽象类,抽象方法等)
2016/09/11 PHP
浅析PHP类的反射来实现依赖注入过程
2018/02/06 PHP
php使用curl获取header检测开启GZip压缩的方法
2018/08/15 PHP
List all the Databases on a SQL Server
2007/06/21 Javascript
editable.js 基于jquery的表格的编辑插件
2011/10/24 Javascript
JQuery给元素绑定click事件多次执行的解决方法
2014/05/29 Javascript
jQuery中attr()方法用法实例
2015/01/05 Javascript
轻松学习jQuery插件EasyUI EasyUI创建树形网络(1)
2015/11/30 Javascript
jQuery命名空间与闭包用法示例
2017/01/12 Javascript
css配合JavaScript实现tab标签切换效果
2018/10/11 Javascript
axios如何取消重复无用的请求详解
2019/12/15 Javascript
vue+iview使用树形控件的具体使用
2020/11/02 Javascript
基于vuex实现购物车功能
2021/01/10 Vue.js
wxpython 最小化到托盘与欢迎图片的实现方法
2014/06/09 Python
17个Python小技巧分享
2015/01/23 Python
Python实现处理管道的方法
2015/06/04 Python
Python实现比较两个文件夹中代码变化的方法
2015/07/10 Python
Python模糊查询本地文件夹去除文件后缀的实例(7行代码)
2017/11/09 Python
python判断所输入的任意一个正整数是否为素数的两种方法
2019/06/27 Python
Python如何使用argparse模块处理命令行参数
2019/12/11 Python
TensorBoard 计算图的可视化实现
2020/02/15 Python
Keras实现DenseNet结构操作
2020/07/06 Python
pip已经安装好第三方库但pycharm中import时还是标红的解决方案
2020/10/09 Python
图解CSS3制作圆环形进度条的实例教程
2016/05/26 HTML / CSS
Html5 FileReader实现即时上传图片功能实例代码
2014/09/01 HTML / CSS
普通院校学生的自荐信
2013/11/27 职场文书
施工资料员的岗位职责
2013/12/22 职场文书
探亲邀请信范文
2014/01/30 职场文书
酒店保安员岗位职责
2014/01/31 职场文书
幼儿园大班区域活动总结
2014/07/09 职场文书
SQL写法--行行比较
2021/08/23 SQL Server