Pytorch通过保存为ONNX模型转TensorRT5的实现


Posted in Python onMay 25, 2020

1 Pytorch以ONNX方式保存模型

def saveONNX(model, filepath):
  '''
  保存ONNX模型
  :param model: 神经网络模型
  :param filepath: 文件保存路径
  '''
  
  # 神经网络输入数据类型
  dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
  torch.onnx.export(model, dummy_input, filepath, verbose=True)

2 利用TensorRT5中ONNX解析器构建Engine

def ONNX_build_engine(onnx_file_path):
  '''
  通过加载onnx文件,构建engine
  :param onnx_file_path: onnx文件路径
  :return: engine
  '''
  # 打印日志
  G_LOGGER = trt.Logger(trt.Logger.WARNING)

  with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
   builder.max_batch_size = 100
   builder.max_workspace_size = 1 << 20

   print('Loading ONNX file from path {}...'.format(onnx_file_path))
   with open(onnx_file_path, 'rb') as model:
    print('Beginning ONNX file parsing')
    parser.parse(model.read())
   print('Completed parsing of ONNX file')

   print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
   engine = builder.build_cuda_engine(network)
   print("Completed creating Engine")

   # 保存计划文件
   # with open(engine_file_path, "wb") as f:
   #  f.write(engine.serialize())
   return engine

3 构建TensorRT运行引擎进行预测

def loadONNX2TensorRT(filepath):
  '''
  通过onnx文件,构建TensorRT运行引擎
  :param filepath: onnx文件路径
  '''
  # 计算开始时间
  Start = time()

  engine = self.ONNX_build_engine(filepath)

  # 读取测试集
  datas = DataLoaders()
  test_loader = datas.testDataLoader()
  img, target = next(iter(test_loader))
  img = img.numpy()
  target = target.numpy()

  img = img.ravel()

  context = engine.create_execution_context()
  output = np.empty((100, 10), dtype=np.float32)

  # 分配内存
  d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
  d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
  bindings = [int(d_input), int(d_output)]

  # pycuda操作缓冲区
  stream = cuda.Stream()
  # 将输入数据放入device
  cuda.memcpy_htod_async(d_input, img, stream)
  # 执行模型
  context.execute_async(100, bindings, stream.handle, None)
  # 将预测结果从从缓冲区取出
  cuda.memcpy_dtoh_async(output, d_output, stream)
  # 线程同步
  stream.synchronize()

  print("Test Case: " + str(target))
  print("Prediction: " + str(np.argmax(output, axis=1)))
  print("tensorrt time:", time() - Start)

  del context
  del engine

补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT

近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。

后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。

是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。

以上这篇Pytorch通过保存为ONNX模型转TensorRT5的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用urllib模块和pyquery实现阿里巴巴排名查询
Jan 16 Python
Python中定时任务框架APScheduler的快速入门指南
Jul 06 Python
Python简单实现控制电脑的方法
Jan 22 Python
python自动查询12306余票并发送邮箱提醒脚本
May 21 Python
python 解决动态的定义变量名,并给其赋值的方法(大数据处理)
Nov 10 Python
python 并发编程 多路复用IO模型详解
Aug 20 Python
python中通过selenium简单操作及元素定位知识点总结
Sep 10 Python
python tkinter图形界面代码统计工具
Sep 18 Python
flask框架自定义过滤器示例【markdown文件读取和展示功能】
Nov 08 Python
pytorch-RNN进行回归曲线预测方式
Jan 14 Python
python集合的新增元素方法整理
Dec 07 Python
python元组打包和解包过程详解
Aug 02 Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
You might like
javascript中巧用“闭包”实现程序的暂停执行功能
2007/04/04 Javascript
javascript 字符 Escape,encodeURI,encodeURIComponent
2009/07/09 Javascript
JavaScript读取中文cookie时的乱码问题的解决方法
2009/10/14 Javascript
jQuery点击后一组图片左右滑动的实现代码
2012/08/16 Javascript
深入讲解AngularJS中的自定义指令的使用
2015/06/18 Javascript
浅谈Javascript数组的使用
2015/07/29 Javascript
浅谈JavaScript的函数及作用域
2016/12/30 Javascript
Angular2整合其他插件的方法
2018/01/20 Javascript
angular4 获取wifi列表中文显示乱码问题的解决
2018/10/20 Javascript
eslint 的三大通用规则详解
2019/05/16 Javascript
记一次用ts+vuecli4重构项目的实现
2020/05/21 Javascript
Python 抓取动态网页内容方案详解
2014/12/25 Python
Python随机生成信用卡卡号的实现方法
2015/05/14 Python
python多进程共享变量
2016/04/06 Python
Python网络编程中urllib2模块的用法总结
2016/07/12 Python
使用Python将数组的元素导出到变量中(unpacking)
2016/10/27 Python
Python实现将16进制字符串转化为ascii字符的方法分析
2017/07/21 Python
Python文件读写保存操作的示例代码
2018/09/14 Python
Python利用lxml模块爬取豆瓣读书排行榜的方法与分析
2019/04/15 Python
简单了解django缓存方式及配置
2019/07/19 Python
python点击鼠标获取坐标(Graphics)
2019/08/10 Python
python 使用raw socket进行TCP SYN扫描实例
2020/05/05 Python
详解Django中views数据查询使用locals()函数进行优化
2020/08/24 Python
python 解决函数返回return的问题
2020/12/05 Python
css3实现信纸/同学录效果的示例代码
2018/12/11 HTML / CSS
美国著名首饰网站:BaubleBar
2016/08/29 全球购物
意大利中国电子产品购物网站:Geekmall.com
2019/09/30 全球购物
迪士尼西班牙官方网上商店:ShopDisney西班牙
2020/02/02 全球购物
怎么写有吸引力的自荐信
2013/11/17 职场文书
岗位说明书范文
2014/05/07 职场文书
2014年入党积极分子学习三中全会思想汇报
2014/09/13 职场文书
学习雷锋精神活动总结
2015/02/06 职场文书
黑白记忆观后感
2015/06/18 职场文书
致运动员的广播稿
2015/08/19 职场文书
学者《孟子》名人名言
2019/08/09 职场文书
Redis Stream类型的使用详解
2021/11/11 Redis