Pytorch通过保存为ONNX模型转TensorRT5的实现


Posted in Python onMay 25, 2020

1 Pytorch以ONNX方式保存模型

def saveONNX(model, filepath):
  '''
  保存ONNX模型
  :param model: 神经网络模型
  :param filepath: 文件保存路径
  '''
  
  # 神经网络输入数据类型
  dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
  torch.onnx.export(model, dummy_input, filepath, verbose=True)

2 利用TensorRT5中ONNX解析器构建Engine

def ONNX_build_engine(onnx_file_path):
  '''
  通过加载onnx文件,构建engine
  :param onnx_file_path: onnx文件路径
  :return: engine
  '''
  # 打印日志
  G_LOGGER = trt.Logger(trt.Logger.WARNING)

  with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
   builder.max_batch_size = 100
   builder.max_workspace_size = 1 << 20

   print('Loading ONNX file from path {}...'.format(onnx_file_path))
   with open(onnx_file_path, 'rb') as model:
    print('Beginning ONNX file parsing')
    parser.parse(model.read())
   print('Completed parsing of ONNX file')

   print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
   engine = builder.build_cuda_engine(network)
   print("Completed creating Engine")

   # 保存计划文件
   # with open(engine_file_path, "wb") as f:
   #  f.write(engine.serialize())
   return engine

3 构建TensorRT运行引擎进行预测

def loadONNX2TensorRT(filepath):
  '''
  通过onnx文件,构建TensorRT运行引擎
  :param filepath: onnx文件路径
  '''
  # 计算开始时间
  Start = time()

  engine = self.ONNX_build_engine(filepath)

  # 读取测试集
  datas = DataLoaders()
  test_loader = datas.testDataLoader()
  img, target = next(iter(test_loader))
  img = img.numpy()
  target = target.numpy()

  img = img.ravel()

  context = engine.create_execution_context()
  output = np.empty((100, 10), dtype=np.float32)

  # 分配内存
  d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
  d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
  bindings = [int(d_input), int(d_output)]

  # pycuda操作缓冲区
  stream = cuda.Stream()
  # 将输入数据放入device
  cuda.memcpy_htod_async(d_input, img, stream)
  # 执行模型
  context.execute_async(100, bindings, stream.handle, None)
  # 将预测结果从从缓冲区取出
  cuda.memcpy_dtoh_async(output, d_output, stream)
  # 线程同步
  stream.synchronize()

  print("Test Case: " + str(target))
  print("Prediction: " + str(np.argmax(output, axis=1)))
  print("tensorrt time:", time() - Start)

  del context
  del engine

补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT

近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。

后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。

是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。

以上这篇Pytorch通过保存为ONNX模型转TensorRT5的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python显示天气预报
Mar 02 Python
Python中的条件判断语句基础学习教程
Feb 07 Python
Python检查和同步本地时间(北京时间)的实现方法
Dec 03 Python
python的内存管理和垃圾回收机制详解
May 18 Python
pycharm 安装JPype的教程
Aug 08 Python
将数据集制作成VOC数据集格式的实例
Feb 17 Python
Python paramiko 模块浅谈与SSH主要功能模拟解析
Feb 29 Python
Python如何把十进制数转换成ip地址
May 25 Python
Python连接HDFS实现文件上传下载及Pandas转换文本文件到CSV操作
Jun 06 Python
Python项目跨域问题解决方案
Jun 22 Python
Python 使用xlwt模块将多行多列数据循环写入excel文档的操作
Nov 10 Python
OpenCV-Python模板匹配人眼的实例
Jun 08 Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
You might like
php通过curl模拟登陆DZ论坛
2015/05/11 PHP
Joomla数据库操作之JFactory::getDBO用法
2016/05/05 PHP
让你的PHP7更快之Hugepage用法分析
2016/05/31 PHP
Avengerls vs Newbee BO3 第一场2.18
2021/03/10 DOTA
JS动画效果代码3
2008/04/03 Javascript
return false,对阻止事件默认动作的一些测试代码
2010/11/17 Javascript
js 实现在离开页面时提醒未保存的信息(减少用户重复操作)
2013/01/16 Javascript
JQuery的AJAX实现文件下载的小例子
2013/05/15 Javascript
jQuery的显示和隐藏方法与css隐藏的样式对比
2013/10/18 Javascript
jQuery实现的Div窗口震动特效
2014/06/09 Javascript
通过node-mysql搭建Windows+Node.js+MySQL环境的教程
2016/03/01 Javascript
angular动态删除ng-repaeat添加的dom节点的方法
2017/07/20 Javascript
JS原生轮播图的简单实现(推荐)
2017/07/22 Javascript
webpack下实现动态引入文件方法
2018/02/22 Javascript
vue后台管理之动态加载路由的方法
2018/08/13 Javascript
Element-UI中关于table表格的那些骚操作(小结)
2019/08/15 Javascript
详解基于原生JS验证表单组件xy-form
2019/08/20 Javascript
微信小程序实现打开并下载服务器上面的pdf文件到手机
2019/09/20 Javascript
详解Vue中的Props与Data细微差别
2020/03/02 Javascript
Element图表初始大小及窗口自适应实现
2020/07/10 Javascript
js实现验证码功能
2020/07/24 Javascript
k8s node节点重新加入master集群的实现
2021/02/22 Javascript
python在多玩图片上下载妹子图的实现代码
2013/08/13 Python
python3操作mysql数据库的方法
2017/06/23 Python
基于Django filter中用contains和icontains的区别(详解)
2017/12/12 Python
Python拼接字符串的7种方法总结
2018/11/01 Python
python用插值法绘制平滑曲线
2021/02/19 Python
Python的logging模块基本用法
2020/12/24 Python
浅谈HTML5 服务器推送事件(Server-sent Events)
2017/08/01 HTML / CSS
欧洲、亚洲、非洲和拉丁美洲的度假套餐:Great Value Vacations
2019/03/30 全球购物
Perfume’s Club中文官网:西班牙美妆在线零售品牌
2020/08/24 全球购物
教师个人成长总结
2015/02/11 职场文书
上课讲话检讨书范文
2015/05/07 职场文书
阿甘正传观后感
2015/06/01 职场文书
卢旺达饭店观后感
2015/06/05 职场文书
Vue3中toRef与toRefs的区别
2022/03/24 Vue.js