Pytorch通过保存为ONNX模型转TensorRT5的实现


Posted in Python onMay 25, 2020

1 Pytorch以ONNX方式保存模型

def saveONNX(model, filepath):
  '''
  保存ONNX模型
  :param model: 神经网络模型
  :param filepath: 文件保存路径
  '''
  
  # 神经网络输入数据类型
  dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
  torch.onnx.export(model, dummy_input, filepath, verbose=True)

2 利用TensorRT5中ONNX解析器构建Engine

def ONNX_build_engine(onnx_file_path):
  '''
  通过加载onnx文件,构建engine
  :param onnx_file_path: onnx文件路径
  :return: engine
  '''
  # 打印日志
  G_LOGGER = trt.Logger(trt.Logger.WARNING)

  with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
   builder.max_batch_size = 100
   builder.max_workspace_size = 1 << 20

   print('Loading ONNX file from path {}...'.format(onnx_file_path))
   with open(onnx_file_path, 'rb') as model:
    print('Beginning ONNX file parsing')
    parser.parse(model.read())
   print('Completed parsing of ONNX file')

   print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
   engine = builder.build_cuda_engine(network)
   print("Completed creating Engine")

   # 保存计划文件
   # with open(engine_file_path, "wb") as f:
   #  f.write(engine.serialize())
   return engine

3 构建TensorRT运行引擎进行预测

def loadONNX2TensorRT(filepath):
  '''
  通过onnx文件,构建TensorRT运行引擎
  :param filepath: onnx文件路径
  '''
  # 计算开始时间
  Start = time()

  engine = self.ONNX_build_engine(filepath)

  # 读取测试集
  datas = DataLoaders()
  test_loader = datas.testDataLoader()
  img, target = next(iter(test_loader))
  img = img.numpy()
  target = target.numpy()

  img = img.ravel()

  context = engine.create_execution_context()
  output = np.empty((100, 10), dtype=np.float32)

  # 分配内存
  d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
  d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
  bindings = [int(d_input), int(d_output)]

  # pycuda操作缓冲区
  stream = cuda.Stream()
  # 将输入数据放入device
  cuda.memcpy_htod_async(d_input, img, stream)
  # 执行模型
  context.execute_async(100, bindings, stream.handle, None)
  # 将预测结果从从缓冲区取出
  cuda.memcpy_dtoh_async(output, d_output, stream)
  # 线程同步
  stream.synchronize()

  print("Test Case: " + str(target))
  print("Prediction: " + str(np.argmax(output, axis=1)))
  print("tensorrt time:", time() - Start)

  del context
  del engine

补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT

近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。

后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。

是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。

以上这篇Pytorch通过保存为ONNX模型转TensorRT5的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python matplotlib 画图窗口显示到gui或者控制台的实例
May 24 Python
python最小生成树kruskal与prim算法详解
Jan 17 Python
对python读取CT医学图像的实例详解
Jan 24 Python
Python基础之文件读取的讲解
Feb 16 Python
为什么说Python可以实现所有的算法
Oct 04 Python
Python 3 使用Pillow生成漂亮的分形树图片
Dec 24 Python
pytorch 求网络模型参数实例
Dec 30 Python
详解Django配置JWT认证方式
May 09 Python
tensorflow下的图片标准化函数per_image_standardization用法
Jun 30 Python
requests在python中发送请求的实例讲解
Feb 17 Python
用Python简陋模拟n阶魔方
Apr 17 Python
python 解决微分方程的操作(数值解法)
May 26 Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
You might like
计数器详细设计
2006/10/09 PHP
PHP基于MySQLI函数封装的数据库连接工具类【定义与用法】
2017/08/11 PHP
JavaScript中的私有成员
2006/09/18 Javascript
jQuery弹出层始终垂直居中相对于屏幕或当前窗口
2013/04/01 Javascript
javascript中创建对象的几种方法总结
2013/11/01 Javascript
js判断为空Null与字符串为空简写方法
2014/02/24 Javascript
jquery新的绑定事件机制on方法的使用方法
2014/04/15 Javascript
JS比较2个日期间隔的示例代码
2014/04/15 Javascript
javascript框架设计读书笔记之字符串的扩展和修复
2014/12/02 Javascript
js中的事件捕捉模型与冒泡模型实例分析
2015/01/10 Javascript
jquery插件bxslider用法实例分析
2015/04/16 Javascript
全面解析Bootstrap中nav、collapse的使用方法
2016/05/22 Javascript
BootStrap中
2016/12/10 Javascript
JavaScript函数表达式详解及实例
2017/05/05 Javascript
jquery插件canvaspercent.js实现百分比圆饼效果
2017/07/18 jQuery
vue遍历生成的输入框 绑定及修改值示例
2019/10/30 Javascript
浅谈vue中使用编辑器vue-quill-editor踩过的坑
2020/08/03 Javascript
jQuery实现电梯导航模块
2020/12/22 jQuery
从源码角度来回答keep-alive组件的缓存原理
2021/01/18 Javascript
跟老齐学Python之不要红头文件(2)
2014/09/28 Python
Python3实现并发检验代理池地址的方法
2016/09/18 Python
python快速建立超简单的web服务器的实现方法
2018/02/17 Python
Python除法之传统除法、Floor除法及真除法实例详解
2019/05/23 Python
Ubuntu18.04下python版本完美切换的解决方法
2019/06/14 Python
Python 使用folium绘制leaflet地图的实现方法
2019/07/05 Python
python采集百度搜索结果带有特定URL的链接代码实例
2019/08/30 Python
Python Gluon参数和模块命名操作教程
2019/12/18 Python
世界上最大的专业美容用品零售商:Sally Beauty
2017/07/02 全球购物
BSTN意大利:德国街头和运动文化高品质商店
2020/12/22 全球购物
J2EE中常用的名词进行解释
2015/11/09 面试题
入党积极分子思想汇报
2014/01/02 职场文书
法人代表授权委托书范文
2014/09/10 职场文书
2014年高中班主任工作总结
2014/11/08 职场文书
爸爸的三轮车观后感
2015/06/16 职场文书
2016年大学迎新晚会工作总结
2015/10/15 职场文书
2016教师校本培训心得体会
2016/01/08 职场文书