Pytorch通过保存为ONNX模型转TensorRT5的实现


Posted in Python onMay 25, 2020

1 Pytorch以ONNX方式保存模型

def saveONNX(model, filepath):
  '''
  保存ONNX模型
  :param model: 神经网络模型
  :param filepath: 文件保存路径
  '''
  
  # 神经网络输入数据类型
  dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
  torch.onnx.export(model, dummy_input, filepath, verbose=True)

2 利用TensorRT5中ONNX解析器构建Engine

def ONNX_build_engine(onnx_file_path):
  '''
  通过加载onnx文件,构建engine
  :param onnx_file_path: onnx文件路径
  :return: engine
  '''
  # 打印日志
  G_LOGGER = trt.Logger(trt.Logger.WARNING)

  with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
   builder.max_batch_size = 100
   builder.max_workspace_size = 1 << 20

   print('Loading ONNX file from path {}...'.format(onnx_file_path))
   with open(onnx_file_path, 'rb') as model:
    print('Beginning ONNX file parsing')
    parser.parse(model.read())
   print('Completed parsing of ONNX file')

   print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
   engine = builder.build_cuda_engine(network)
   print("Completed creating Engine")

   # 保存计划文件
   # with open(engine_file_path, "wb") as f:
   #  f.write(engine.serialize())
   return engine

3 构建TensorRT运行引擎进行预测

def loadONNX2TensorRT(filepath):
  '''
  通过onnx文件,构建TensorRT运行引擎
  :param filepath: onnx文件路径
  '''
  # 计算开始时间
  Start = time()

  engine = self.ONNX_build_engine(filepath)

  # 读取测试集
  datas = DataLoaders()
  test_loader = datas.testDataLoader()
  img, target = next(iter(test_loader))
  img = img.numpy()
  target = target.numpy()

  img = img.ravel()

  context = engine.create_execution_context()
  output = np.empty((100, 10), dtype=np.float32)

  # 分配内存
  d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
  d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
  bindings = [int(d_input), int(d_output)]

  # pycuda操作缓冲区
  stream = cuda.Stream()
  # 将输入数据放入device
  cuda.memcpy_htod_async(d_input, img, stream)
  # 执行模型
  context.execute_async(100, bindings, stream.handle, None)
  # 将预测结果从从缓冲区取出
  cuda.memcpy_dtoh_async(output, d_output, stream)
  # 线程同步
  stream.synchronize()

  print("Test Case: " + str(target))
  print("Prediction: " + str(np.argmax(output, axis=1)))
  print("tensorrt time:", time() - Start)

  del context
  del engine

补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT

近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。

后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。

是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。

以上这篇Pytorch通过保存为ONNX模型转TensorRT5的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
教你如何将 Sublime 3 打造成 Python/Django IDE开发利器
Jul 04 Python
python输出当前目录下index.html文件路径的方法
Apr 28 Python
Python编程中的文件读写及相关的文件对象方法讲解
Jan 19 Python
Python正则表达式知识汇总
Sep 22 Python
Django实现一对多表模型的跨表查询方法
Dec 18 Python
python读取有密码的zip压缩文件实例
Feb 08 Python
Python快速转换numpy数组中Nan和Inf的方法实例说明
Feb 21 Python
Python 内置函数globals()和locals()对比详解
Dec 23 Python
Python字符编码转码之GBK,UTF8互转
Feb 09 Python
使用pymysql查询数据库,把结果保存为列表并获取指定元素下标实例
May 15 Python
python 实现rolling和apply函数的向下取值操作
Jun 08 Python
Python抓包并解析json爬虫的完整实例代码
Nov 03 Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
You might like
PHP把数字转成人民币大写的函数分享
2014/06/30 PHP
PHP实现163邮箱自动发送邮件
2016/03/29 PHP
PHP+JQuery+Ajax实现分页方法详解
2016/08/06 PHP
PHP有序表查找之二分查找(折半查找)算法示例
2018/02/09 PHP
PHP实现多图上传和单图上传功能
2018/05/17 PHP
Laravel 模型关联基础教程详解
2019/09/17 PHP
JavaScript高级程序设计 扩展--关于动态原型
2010/11/09 Javascript
js加入收藏以及使用Jquery更改透明度
2014/01/26 Javascript
js对象的复制继承实例
2015/01/10 Javascript
JS+CSS实现仿msn风格选项卡效果代码
2015/10/22 Javascript
javascript如何实现暂停功能
2015/11/06 Javascript
JavaScript严格模式详解
2017/01/16 Javascript
微信小程序form表单组件示例代码
2018/07/15 Javascript
jQuery中DOM操作原则实例分析
2019/08/01 jQuery
JS this关键字在ajax中使用出现问题解决方案
2020/07/17 Javascript
vue 解决在微信内置浏览器中调用支付宝支付的情况
2020/11/09 Javascript
JavaScript实现缓动动画
2020/11/25 Javascript
详解python的数字类型变量与其方法
2016/11/20 Python
Python实现matplotlib显示中文的方法详解
2018/02/06 Python
一篇文章彻底搞懂Python中可迭代(Iterable)、迭代器(Iterator)与生成器(Generator)的概念
2019/05/13 Python
python 对任意数据和曲线进行拟合并求出函数表达式的三种解决方案
2020/02/18 Python
Python面向对象魔法方法和单例模块代码实例
2020/03/25 Python
TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法
2020/04/19 Python
Django 5种类型Session使用方法解析
2020/04/29 Python
python+opencv实现车道线检测
2021/02/19 Python
OPPO手机官方商城:中国手机市场出货量第一品牌
2017/10/18 全球购物
网络通讯中,端口有什么含义,端口的取值范围
2012/11/23 面试题
什么是Remote Module
2016/06/10 面试题
校园创业策划书
2014/01/14 职场文书
工厂总经理岗位职责
2014/02/07 职场文书
新疆民族团结演讲稿
2014/08/27 职场文书
2014教师年度工作总结
2014/11/10 职场文书
冲出亚马逊观后感
2015/06/03 职场文书
安全主题班会教案
2015/08/12 职场文书
怎么用Python识别手势数字
2021/06/07 Python
解决MySQL Varchar 类型尾部空格的问题
2022/04/06 MySQL