Pytorch通过保存为ONNX模型转TensorRT5的实现


Posted in Python onMay 25, 2020

1 Pytorch以ONNX方式保存模型

def saveONNX(model, filepath):
  '''
  保存ONNX模型
  :param model: 神经网络模型
  :param filepath: 文件保存路径
  '''
  
  # 神经网络输入数据类型
  dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
  torch.onnx.export(model, dummy_input, filepath, verbose=True)

2 利用TensorRT5中ONNX解析器构建Engine

def ONNX_build_engine(onnx_file_path):
  '''
  通过加载onnx文件,构建engine
  :param onnx_file_path: onnx文件路径
  :return: engine
  '''
  # 打印日志
  G_LOGGER = trt.Logger(trt.Logger.WARNING)

  with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
   builder.max_batch_size = 100
   builder.max_workspace_size = 1 << 20

   print('Loading ONNX file from path {}...'.format(onnx_file_path))
   with open(onnx_file_path, 'rb') as model:
    print('Beginning ONNX file parsing')
    parser.parse(model.read())
   print('Completed parsing of ONNX file')

   print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
   engine = builder.build_cuda_engine(network)
   print("Completed creating Engine")

   # 保存计划文件
   # with open(engine_file_path, "wb") as f:
   #  f.write(engine.serialize())
   return engine

3 构建TensorRT运行引擎进行预测

def loadONNX2TensorRT(filepath):
  '''
  通过onnx文件,构建TensorRT运行引擎
  :param filepath: onnx文件路径
  '''
  # 计算开始时间
  Start = time()

  engine = self.ONNX_build_engine(filepath)

  # 读取测试集
  datas = DataLoaders()
  test_loader = datas.testDataLoader()
  img, target = next(iter(test_loader))
  img = img.numpy()
  target = target.numpy()

  img = img.ravel()

  context = engine.create_execution_context()
  output = np.empty((100, 10), dtype=np.float32)

  # 分配内存
  d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
  d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
  bindings = [int(d_input), int(d_output)]

  # pycuda操作缓冲区
  stream = cuda.Stream()
  # 将输入数据放入device
  cuda.memcpy_htod_async(d_input, img, stream)
  # 执行模型
  context.execute_async(100, bindings, stream.handle, None)
  # 将预测结果从从缓冲区取出
  cuda.memcpy_dtoh_async(output, d_output, stream)
  # 线程同步
  stream.synchronize()

  print("Test Case: " + str(target))
  print("Prediction: " + str(np.argmax(output, axis=1)))
  print("tensorrt time:", time() - Start)

  del context
  del engine

补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT

近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。

后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。

是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。

以上这篇Pytorch通过保存为ONNX模型转TensorRT5的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python利用pyHook实现监听用户鼠标与键盘事件
Aug 21 Python
线程和进程的区别及Python代码实例
Feb 04 Python
Python字符串匹配算法KMP实例
Jul 18 Python
Python 私有函数的实例详解
Sep 11 Python
python虚拟环境的安装配置图文教程
Oct 20 Python
Python计算一个给定时间点前一个月和后一个月第一天的方法
May 29 Python
python 使用re.search()筛选后 选取部分结果的方法
Nov 28 Python
python多线程下信号处理程序示例
May 31 Python
Python Selenium参数配置方法解析
Jan 19 Python
Python bytes string相互转换过程解析
Mar 05 Python
使用python-cv2实现视频的分解与合成的示例代码
Oct 26 Python
只用40行Python代码就能写出pdf转word小工具
May 31 Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
You might like
PHP mail 通过Windows的SMTP发送邮件失败的解决方案
2009/05/27 PHP
smarty内置函数foreach用法实例
2015/01/22 PHP
php session_decode函数用法讲解
2019/05/26 PHP
把html页面的部分内容保存成新的html文件的jquery代码
2009/11/12 Javascript
JavaScript中九种常用排序算法
2014/09/02 Javascript
重写document.write实现无阻塞加载js广告(补充)
2014/12/12 Javascript
Jquery1.9.1源码分析系列(六)延时对象应用之jQuery.ready
2015/11/24 Javascript
详解Javascript事件驱动编程
2016/01/03 Javascript
Bootstrap carousel轮转图的使用实例详解
2016/05/17 Javascript
基于node.js依赖express解析post请求四种数据格式
2017/02/13 Javascript
JavaScript数据结构中串的表示与应用实例
2017/04/12 Javascript
Angular 4环境准备与Angular cli创建项目详解
2017/05/27 Javascript
详解Vue.js分发之作用域槽
2017/06/13 Javascript
详解原生js实现offset方法
2017/06/15 Javascript
深入理解vuex2.0 之 modules
2017/11/20 Javascript
vue2.0 computed 计算list循环后累加值的实例
2018/03/07 Javascript
vue.js template模板的使用(仿饿了么布局)
2018/08/13 Javascript
详解js模板引擎art template数组渲染的方法
2018/10/09 Javascript
原生JS实现手动轮播图效果实例代码
2018/11/22 Javascript
OpenLayers3实现测量功能
2020/09/25 Javascript
flask使用session保存登录状态及拦截未登录请求代码
2018/01/19 Python
python学生信息管理系统
2018/03/13 Python
基于python实现名片管理系统
2018/11/30 Python
Python socket实现多对多全双工通信的方法
2019/02/13 Python
python循环嵌套的多种使用方法解析
2019/11/29 Python
pyinstaller打包成无控制台程序时运行出错(与popen冲突的解决方法)
2020/04/15 Python
matplotlib基础绘图命令之bar的使用方法
2020/08/13 Python
CSS3 please 跨浏览器的CSS3产生器
2010/03/14 HTML / CSS
初中三年学生的学习自我评价
2013/11/13 职场文书
应届生自我鉴定
2013/12/11 职场文书
酒店副总岗位职责
2013/12/24 职场文书
舞蹈专业大学生职业规划范文
2014/03/12 职场文书
生日宴会策划方案
2014/06/03 职场文书
2015年科室工作总结
2015/04/10 职场文书
pygame面向对象的飞行小鸟实现(Flappy bird)
2021/04/01 Python
SQL实现LeetCode(178.分数排行)
2021/08/04 MySQL