Pytorch通过保存为ONNX模型转TensorRT5的实现


Posted in Python onMay 25, 2020

1 Pytorch以ONNX方式保存模型

def saveONNX(model, filepath):
  '''
  保存ONNX模型
  :param model: 神经网络模型
  :param filepath: 文件保存路径
  '''
  
  # 神经网络输入数据类型
  dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
  torch.onnx.export(model, dummy_input, filepath, verbose=True)

2 利用TensorRT5中ONNX解析器构建Engine

def ONNX_build_engine(onnx_file_path):
  '''
  通过加载onnx文件,构建engine
  :param onnx_file_path: onnx文件路径
  :return: engine
  '''
  # 打印日志
  G_LOGGER = trt.Logger(trt.Logger.WARNING)

  with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
   builder.max_batch_size = 100
   builder.max_workspace_size = 1 << 20

   print('Loading ONNX file from path {}...'.format(onnx_file_path))
   with open(onnx_file_path, 'rb') as model:
    print('Beginning ONNX file parsing')
    parser.parse(model.read())
   print('Completed parsing of ONNX file')

   print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
   engine = builder.build_cuda_engine(network)
   print("Completed creating Engine")

   # 保存计划文件
   # with open(engine_file_path, "wb") as f:
   #  f.write(engine.serialize())
   return engine

3 构建TensorRT运行引擎进行预测

def loadONNX2TensorRT(filepath):
  '''
  通过onnx文件,构建TensorRT运行引擎
  :param filepath: onnx文件路径
  '''
  # 计算开始时间
  Start = time()

  engine = self.ONNX_build_engine(filepath)

  # 读取测试集
  datas = DataLoaders()
  test_loader = datas.testDataLoader()
  img, target = next(iter(test_loader))
  img = img.numpy()
  target = target.numpy()

  img = img.ravel()

  context = engine.create_execution_context()
  output = np.empty((100, 10), dtype=np.float32)

  # 分配内存
  d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
  d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
  bindings = [int(d_input), int(d_output)]

  # pycuda操作缓冲区
  stream = cuda.Stream()
  # 将输入数据放入device
  cuda.memcpy_htod_async(d_input, img, stream)
  # 执行模型
  context.execute_async(100, bindings, stream.handle, None)
  # 将预测结果从从缓冲区取出
  cuda.memcpy_dtoh_async(output, d_output, stream)
  # 线程同步
  stream.synchronize()

  print("Test Case: " + str(target))
  print("Prediction: " + str(np.argmax(output, axis=1)))
  print("tensorrt time:", time() - Start)

  del context
  del engine

补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT

近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。

后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。

是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。

以上这篇Pytorch通过保存为ONNX模型转TensorRT5的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的迭代和可迭代对象代码示例
Dec 27 Python
Python分支结构(switch)操作简介
Jan 17 Python
python中的随机函数random的用法示例
Jan 27 Python
如何用Python合并lmdb文件
Jul 02 Python
python使用for循环计算0-100的整数的和方法
Feb 01 Python
Python获取好友地区分布及好友性别分布情况代码详解
Jul 10 Python
使用Windows批处理和WMI设置Python的环境变量方法
Aug 14 Python
解决python 找不到module的问题
Feb 12 Python
Python批量处理csv并保存过程解析
May 16 Python
python相对企业语言优势在哪
Jun 12 Python
Python如何对XML 解析
Jun 28 Python
Django解决frame拒绝问题的方法
Dec 18 Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
You might like
PHP中在数据库中保存Checkbox数据(2)
2006/10/09 PHP
php实现jQuery扩展函数
2009/10/30 PHP
PHP file_get_contents设置超时处理方法
2013/09/30 PHP
两款万能的php分页类
2015/11/12 PHP
PHP实现操作redis的封装类完整实例
2015/11/14 PHP
CI框架整合widget(页面格局)的方法
2016/05/17 PHP
magento后台无法登录解决办法的两种方法
2016/12/09 PHP
php+Memcached实现简单留言板功能示例
2017/02/15 PHP
PhpStorm+xdebug+postman调试技巧分享
2020/09/15 PHP
jquery的ajax()函数传值中文乱码解决方法介绍
2012/11/08 Javascript
删除select中所有option选项jquery代码
2013/08/12 Javascript
js实现select跳转功能代码
2014/10/22 Javascript
jQuery实现select下拉框获取当前选中文本、值、索引
2017/05/08 jQuery
详解Vue单元测试case写法
2018/05/24 Javascript
讲解vue-router之什么是动态路由
2018/05/28 Javascript
微信小程序中进行地图导航功能的实现方法
2018/06/29 Javascript
关于vue里页面的缓存详解
2019/11/04 Javascript
[01:05:00]2018国际邀请赛 表演赛 Pain vs OpenAI
2018/08/24 DOTA
使用Python的web.py框架实现类似Django的ORM查询的教程
2015/05/02 Python
Python使用shelve模块实现简单数据存储的方法
2015/05/20 Python
Python实现自定义函数的5种常见形式分析
2018/06/16 Python
Python编程快速上手——Excel表格创建乘法表案例分析
2020/02/28 Python
快速解决jupyter启动卡死的问题
2020/04/10 Python
CSS实现聊天气泡效果
2020/04/26 HTML / CSS
使用HTML5 Canvas API中的clip()方法裁剪区域图像
2016/03/25 HTML / CSS
您的健身减肥和健康饮食专家:vitafy
2017/06/06 全球购物
俄罗斯在线手表和珠宝商店:AllTime
2019/09/28 全球购物
大型会议策划方案
2014/05/17 职场文书
2014年冬季防火方案
2014/05/21 职场文书
个人年底工作总结
2015/03/10 职场文书
幼师辞职信范文大全
2015/05/12 职场文书
解决go在函数退出后子协程的退出问题
2021/04/30 Golang
python tkinter Entry控件的焦点移动操作
2021/05/22 Python
MySQL中存储时间的最佳实践指南
2021/07/01 MySQL
实现AJAX异步调用和局部刷新的基本步骤
2022/03/17 Javascript
MySQL提升大量数据查询效率的优化神器
2022/07/07 MySQL