Pytorch通过保存为ONNX模型转TensorRT5的实现


Posted in Python onMay 25, 2020

1 Pytorch以ONNX方式保存模型

def saveONNX(model, filepath):
  '''
  保存ONNX模型
  :param model: 神经网络模型
  :param filepath: 文件保存路径
  '''
  
  # 神经网络输入数据类型
  dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
  torch.onnx.export(model, dummy_input, filepath, verbose=True)

2 利用TensorRT5中ONNX解析器构建Engine

def ONNX_build_engine(onnx_file_path):
  '''
  通过加载onnx文件,构建engine
  :param onnx_file_path: onnx文件路径
  :return: engine
  '''
  # 打印日志
  G_LOGGER = trt.Logger(trt.Logger.WARNING)

  with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
   builder.max_batch_size = 100
   builder.max_workspace_size = 1 << 20

   print('Loading ONNX file from path {}...'.format(onnx_file_path))
   with open(onnx_file_path, 'rb') as model:
    print('Beginning ONNX file parsing')
    parser.parse(model.read())
   print('Completed parsing of ONNX file')

   print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
   engine = builder.build_cuda_engine(network)
   print("Completed creating Engine")

   # 保存计划文件
   # with open(engine_file_path, "wb") as f:
   #  f.write(engine.serialize())
   return engine

3 构建TensorRT运行引擎进行预测

def loadONNX2TensorRT(filepath):
  '''
  通过onnx文件,构建TensorRT运行引擎
  :param filepath: onnx文件路径
  '''
  # 计算开始时间
  Start = time()

  engine = self.ONNX_build_engine(filepath)

  # 读取测试集
  datas = DataLoaders()
  test_loader = datas.testDataLoader()
  img, target = next(iter(test_loader))
  img = img.numpy()
  target = target.numpy()

  img = img.ravel()

  context = engine.create_execution_context()
  output = np.empty((100, 10), dtype=np.float32)

  # 分配内存
  d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
  d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
  bindings = [int(d_input), int(d_output)]

  # pycuda操作缓冲区
  stream = cuda.Stream()
  # 将输入数据放入device
  cuda.memcpy_htod_async(d_input, img, stream)
  # 执行模型
  context.execute_async(100, bindings, stream.handle, None)
  # 将预测结果从从缓冲区取出
  cuda.memcpy_dtoh_async(output, d_output, stream)
  # 线程同步
  stream.synchronize()

  print("Test Case: " + str(target))
  print("Prediction: " + str(np.argmax(output, axis=1)))
  print("tensorrt time:", time() - Start)

  del context
  del engine

补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT

近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。

后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。

是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。

以上这篇Pytorch通过保存为ONNX模型转TensorRT5的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python之wxPython菜单使用详解
Sep 28 Python
Python脚本实现格式化css文件
Apr 08 Python
python基于phantomjs实现导入图片
May 13 Python
python matplotlib坐标轴设置的方法
Dec 05 Python
Python处理CSV与List的转换方法
Apr 19 Python
python+unittest+requests实现接口自动化的方法
Nov 29 Python
Django中ajax发送post请求 报403错误CSRF验证失败解决方案
Aug 13 Python
Python基本语法之运算符功能与用法详解
Oct 22 Python
pytorch随机采样操作SubsetRandomSampler()
Jul 07 Python
python3跳出一个循环的实例操作
Aug 18 Python
Python lambda表达式原理及用法解析
Aug 18 Python
Pytorch如何切换 cpu和gpu的使用详解
Mar 01 Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
You might like
php 归并排序 数组交集
2011/05/10 PHP
php中file_get_content 和curl以及fopen 效率分析
2014/09/19 PHP
微信自定义菜单的处理开发示例
2015/04/16 PHP
Windows下php+mysql5.7配置教程
2017/05/16 PHP
Laravel 5.5官方推荐的Nginx配置学习教程
2017/10/06 PHP
JavaScript 高级语法介绍
2009/06/15 Javascript
javascript 客户端验证上传图片的大小(兼容IE和火狐)
2009/08/15 Javascript
jQuery 点击图片跳转上一张或下一张功能的实现代码
2010/03/12 Javascript
基于jquery的内容循环滚动小模块(仿新浪微博未登录首页滚动微博显示)
2011/03/28 Javascript
基于jquery的跟随屏幕滚动代码
2012/07/24 Javascript
使用typeof判断function是否存在于上下文
2014/08/14 Javascript
IE下通过a实现location.href 获取referer的值
2014/09/04 Javascript
完美兼容多浏览器的js判断图片路径代码汇总
2015/04/17 Javascript
十个免费的web前端开发工具详细整理
2017/09/18 Javascript
JavaScript正则表达式函数总结(常用)
2018/02/22 Javascript
Vue Router去掉url中默认的锚点#
2018/08/01 Javascript
微信小程序之swiper滑动面板用法示例
2018/12/04 Javascript
Nodejs中获取当前函数被调用的行数及文件名详解
2018/12/12 NodeJs
图文详解vue框架安装步骤
2019/02/12 Javascript
JS实现移动端在线签协议功能
2019/08/22 Javascript
在vue中使用vant TreeSelect分类选择组件操作
2020/11/02 Javascript
[01:51]2014DOTA2西雅图邀请赛 MVP 外卡赛black场间采访
2014/07/09 DOTA
python类型强制转换long to int的代码
2013/02/10 Python
基于python(urlparse)模板的使用方法总结
2017/10/13 Python
python使用Apriori算法进行关联性解析
2017/12/21 Python
python选取特定列 pandas iloc,loc,icol的使用详解(列切片及行切片)
2019/08/06 Python
Python中类似于jquery的pyquery库用法分析
2019/12/02 Python
Python3标准库之dbm UNIX键-值数据库问题
2020/03/24 Python
用HTML5 实现橡皮擦的涂抹效果的教程
2015/05/11 HTML / CSS
财务担保书范文
2014/04/02 职场文书
《长江之歌》教学反思
2014/04/17 职场文书
讲座新闻稿
2015/07/18 职场文书
公司环境卫生管理制度
2015/08/05 职场文书
安全生产标语口号
2015/12/26 职场文书
MySQL七大JOIN的具体使用
2022/02/28 MySQL
Python中的 enumerate和zip详情
2022/05/30 Python