Pytorch通过保存为ONNX模型转TensorRT5的实现


Posted in Python onMay 25, 2020

1 Pytorch以ONNX方式保存模型

def saveONNX(model, filepath):
  '''
  保存ONNX模型
  :param model: 神经网络模型
  :param filepath: 文件保存路径
  '''
  
  # 神经网络输入数据类型
  dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
  torch.onnx.export(model, dummy_input, filepath, verbose=True)

2 利用TensorRT5中ONNX解析器构建Engine

def ONNX_build_engine(onnx_file_path):
  '''
  通过加载onnx文件,构建engine
  :param onnx_file_path: onnx文件路径
  :return: engine
  '''
  # 打印日志
  G_LOGGER = trt.Logger(trt.Logger.WARNING)

  with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
   builder.max_batch_size = 100
   builder.max_workspace_size = 1 << 20

   print('Loading ONNX file from path {}...'.format(onnx_file_path))
   with open(onnx_file_path, 'rb') as model:
    print('Beginning ONNX file parsing')
    parser.parse(model.read())
   print('Completed parsing of ONNX file')

   print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
   engine = builder.build_cuda_engine(network)
   print("Completed creating Engine")

   # 保存计划文件
   # with open(engine_file_path, "wb") as f:
   #  f.write(engine.serialize())
   return engine

3 构建TensorRT运行引擎进行预测

def loadONNX2TensorRT(filepath):
  '''
  通过onnx文件,构建TensorRT运行引擎
  :param filepath: onnx文件路径
  '''
  # 计算开始时间
  Start = time()

  engine = self.ONNX_build_engine(filepath)

  # 读取测试集
  datas = DataLoaders()
  test_loader = datas.testDataLoader()
  img, target = next(iter(test_loader))
  img = img.numpy()
  target = target.numpy()

  img = img.ravel()

  context = engine.create_execution_context()
  output = np.empty((100, 10), dtype=np.float32)

  # 分配内存
  d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
  d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
  bindings = [int(d_input), int(d_output)]

  # pycuda操作缓冲区
  stream = cuda.Stream()
  # 将输入数据放入device
  cuda.memcpy_htod_async(d_input, img, stream)
  # 执行模型
  context.execute_async(100, bindings, stream.handle, None)
  # 将预测结果从从缓冲区取出
  cuda.memcpy_dtoh_async(output, d_output, stream)
  # 线程同步
  stream.synchronize()

  print("Test Case: " + str(target))
  print("Prediction: " + str(np.argmax(output, axis=1)))
  print("tensorrt time:", time() - Start)

  del context
  del engine

补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT

近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。

后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。

是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。

以上这篇Pytorch通过保存为ONNX模型转TensorRT5的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现按行切分文本文件的方法
Apr 18 Python
如何用itertools解决无序排列组合的问题
May 18 Python
Python3实现发送QQ邮件功能(附件)
Dec 23 Python
python实现抽奖小程序
Apr 15 Python
django框架自定义模板标签(template tag)操作示例
Jun 24 Python
pandas.cut具体使用总结
Jun 24 Python
Python 数据可视化pyecharts的使用详解
Jun 26 Python
python读csv文件时指定行为表头或无表头的方法
Jun 26 Python
python已协程方式处理任务实现过程
Dec 27 Python
解决jupyter notebook import error但是命令提示符import正常的问题
Apr 15 Python
Django websocket原理及功能实现代码
Nov 14 Python
python 管理系统实现mysql交互的示例代码
Dec 06 Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
You might like
php数组随机排序实现方法
2015/06/13 PHP
php实现搜索一维数组元素并删除二维数组对应元素的方法
2015/07/06 PHP
thinkphp 中的volist标签在ajax操作中的特殊性(推荐)
2018/01/15 PHP
windows 2008r2+php5.6.28环境搭建详细过程
2019/06/18 PHP
Gambit vs ForZe BO3 第三场 2.13
2021/03/10 DOTA
Javascript的IE和Firefox兼容性汇编
2006/07/01 Javascript
jQuery EasyUI API 中文文档 - MenuButton菜单按钮使用介绍
2011/10/06 Javascript
一个页面元素appendchild追加到另一个页面元素的问题
2013/01/27 Javascript
网页中表单按回车就自动提交的问题的解决方案
2014/11/03 Javascript
让angularjs支持浏览器自动填表
2014/11/10 Javascript
javascript获取网页宽高方法汇总
2015/07/19 Javascript
JavaScript与java语言有什么不同
2016/09/22 Javascript
基于JavaScript实现报警器提示音效果
2017/10/27 Javascript
vue解决使用webpack打包后keep-alive不生效的方法
2018/09/01 Javascript
Angular之jwt令牌身份验证的实现
2020/02/14 Javascript
vue-axios同时请求多个接口 等所有接口全部加载完成再处理操作
2020/11/09 Javascript
[04:19]DOTA2亚洲邀请赛 现场花絮
2015/03/11 DOTA
使用python 获取进程pid号的方法
2014/03/10 Python
Python素数检测实例分析
2015/06/15 Python
Python中的字符串类型基本知识学习教程
2016/02/04 Python
Python异常处理知识点总结
2019/02/18 Python
python验证码图片处理(二值化)
2019/11/01 Python
Django将默认的SQLite更换为MySQL的实现
2019/11/18 Python
基于python实现微信好友数据分析(简单)
2020/02/16 Python
python中如何打包用户自定义模块
2020/09/23 Python
关于HTML5 Placeholder新标签低版本浏览器下不兼容的问题分析及解决办法
2016/01/27 HTML / CSS
美国第一大药店连锁机构:Walgreens(沃尔格林)
2019/10/10 全球购物
香港艺人陈冠希创办的潮流品牌:JUICESTORE
2021/03/04 全球购物
医学专业大学生求职的自我评价
2013/11/27 职场文书
读群众路线心得体会
2014/03/07 职场文书
班主任个人工作反思
2014/04/28 职场文书
写字楼租赁意向书
2014/07/30 职场文书
2014最新毕业证代领委托书
2014/09/26 职场文书
500字作文之关于爸爸
2019/11/14 职场文书
php引用传递
2021/04/01 PHP
python自动化调用百度api解决验证码
2021/04/13 Python