Pytorch通过保存为ONNX模型转TensorRT5的实现


Posted in Python onMay 25, 2020

1 Pytorch以ONNX方式保存模型

def saveONNX(model, filepath):
  '''
  保存ONNX模型
  :param model: 神经网络模型
  :param filepath: 文件保存路径
  '''
  
  # 神经网络输入数据类型
  dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
  torch.onnx.export(model, dummy_input, filepath, verbose=True)

2 利用TensorRT5中ONNX解析器构建Engine

def ONNX_build_engine(onnx_file_path):
  '''
  通过加载onnx文件,构建engine
  :param onnx_file_path: onnx文件路径
  :return: engine
  '''
  # 打印日志
  G_LOGGER = trt.Logger(trt.Logger.WARNING)

  with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
   builder.max_batch_size = 100
   builder.max_workspace_size = 1 << 20

   print('Loading ONNX file from path {}...'.format(onnx_file_path))
   with open(onnx_file_path, 'rb') as model:
    print('Beginning ONNX file parsing')
    parser.parse(model.read())
   print('Completed parsing of ONNX file')

   print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
   engine = builder.build_cuda_engine(network)
   print("Completed creating Engine")

   # 保存计划文件
   # with open(engine_file_path, "wb") as f:
   #  f.write(engine.serialize())
   return engine

3 构建TensorRT运行引擎进行预测

def loadONNX2TensorRT(filepath):
  '''
  通过onnx文件,构建TensorRT运行引擎
  :param filepath: onnx文件路径
  '''
  # 计算开始时间
  Start = time()

  engine = self.ONNX_build_engine(filepath)

  # 读取测试集
  datas = DataLoaders()
  test_loader = datas.testDataLoader()
  img, target = next(iter(test_loader))
  img = img.numpy()
  target = target.numpy()

  img = img.ravel()

  context = engine.create_execution_context()
  output = np.empty((100, 10), dtype=np.float32)

  # 分配内存
  d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
  d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
  bindings = [int(d_input), int(d_output)]

  # pycuda操作缓冲区
  stream = cuda.Stream()
  # 将输入数据放入device
  cuda.memcpy_htod_async(d_input, img, stream)
  # 执行模型
  context.execute_async(100, bindings, stream.handle, None)
  # 将预测结果从从缓冲区取出
  cuda.memcpy_dtoh_async(output, d_output, stream)
  # 线程同步
  stream.synchronize()

  print("Test Case: " + str(target))
  print("Prediction: " + str(np.argmax(output, axis=1)))
  print("tensorrt time:", time() - Start)

  del context
  del engine

补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT

近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。

后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。

是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。

以上这篇Pytorch通过保存为ONNX模型转TensorRT5的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 文件读写操作实例详解
Mar 12 Python
Python中__init__.py文件的作用详解
Sep 18 Python
python递归打印某个目录的内容(实例讲解)
Aug 30 Python
Python中顺序表的实现简单代码分享
Jan 09 Python
Pyspider中给爬虫伪造随机请求头的实例
May 07 Python
不知道这5种下划线的含义,你就不算真的会Python!
Oct 09 Python
10分钟教你用Python实现微信自动回复功能
Nov 28 Python
python的concat等多种用法详解
Nov 28 Python
python实现画出e指数函数的图像
Nov 21 Python
Python气泡提示与标签的实现
Apr 01 Python
Python OpenCV读取中文路径图像的方法
Jul 02 Python
Python数据分析入门之教你怎么搭建环境
May 13 Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
You might like
php实现utf-8转unicode函数分享
2015/01/06 PHP
thinkPHP简单实现多个子查询语句的方法
2016/12/05 PHP
php插件Xajax使用方法详解
2017/08/31 PHP
PHP simplexml_import_dom()函数讲解
2019/02/03 PHP
分享几种好用的PHP自定义加密函数(可逆/不可逆)
2020/09/15 PHP
CL vs ForZe BO5 第一场 2.13
2021/03/10 DOTA
js+FSO遍历文件夹下文件并显示
2007/03/07 Javascript
Javascript常用字符串判断函数代码分享
2014/12/08 Javascript
jsMind通过鼠标拖拽的方式调整节点位置
2015/04/13 Javascript
详解JavaScript设计模式开发中的桥接模式使用
2016/05/18 Javascript
jQuery实现产品对比功能附源码下载
2016/08/09 Javascript
jQuery实现的放大镜效果示例
2016/09/13 Javascript
vue获取dom元素注意事项
2017/12/28 Javascript
微信小程序如何获取手机验证码
2018/11/04 Javascript
详解vue2.0 资源文件assets和static的区别
2018/11/27 Javascript
jquery实现手风琴案例
2020/05/04 jQuery
Python中的迭代器漫谈
2015/02/03 Python
python+django快速实现文件上传
2016/10/24 Python
Python正则表达式经典入门教程
2017/05/22 Python
python数据抓取分析的示例代码(python + mongodb)
2017/12/25 Python
Python numpy实现数组合并实例(vstack,hstack)
2018/01/09 Python
详谈python在windows中的文件路径问题
2018/04/28 Python
python通过http下载文件的方法详解
2019/07/26 Python
Keras自定义IOU方式
2020/06/10 Python
python归并排序算法过程实例讲解
2020/11/04 Python
python爬虫中url管理器去重操作实例
2020/11/30 Python
一文带你掌握Pyecharts地理数据可视化的方法
2021/02/06 Python
Skyscanner澳大利亚:全球领先的旅游搜索网站
2018/03/24 全球购物
C语言开发工程师测试题
2016/12/20 面试题
关联、聚合(Aggregation)以及组合(Composition)的区别
2012/02/29 面试题
毕业生个人的自我评价优秀范文
2013/10/03 职场文书
高级工程师岗位职责
2013/12/15 职场文书
新郎婚宴答谢词
2014/01/19 职场文书
会议主持人开场白台词
2015/05/28 职场文书
党员公开承诺书(2016最新版)
2016/03/24 职场文书
小程序实现文字循环滚动动画
2021/06/14 Javascript