Pytorch通过保存为ONNX模型转TensorRT5的实现


Posted in Python onMay 25, 2020

1 Pytorch以ONNX方式保存模型

def saveONNX(model, filepath):
  '''
  保存ONNX模型
  :param model: 神经网络模型
  :param filepath: 文件保存路径
  '''
  
  # 神经网络输入数据类型
  dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
  torch.onnx.export(model, dummy_input, filepath, verbose=True)

2 利用TensorRT5中ONNX解析器构建Engine

def ONNX_build_engine(onnx_file_path):
  '''
  通过加载onnx文件,构建engine
  :param onnx_file_path: onnx文件路径
  :return: engine
  '''
  # 打印日志
  G_LOGGER = trt.Logger(trt.Logger.WARNING)

  with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
   builder.max_batch_size = 100
   builder.max_workspace_size = 1 << 20

   print('Loading ONNX file from path {}...'.format(onnx_file_path))
   with open(onnx_file_path, 'rb') as model:
    print('Beginning ONNX file parsing')
    parser.parse(model.read())
   print('Completed parsing of ONNX file')

   print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
   engine = builder.build_cuda_engine(network)
   print("Completed creating Engine")

   # 保存计划文件
   # with open(engine_file_path, "wb") as f:
   #  f.write(engine.serialize())
   return engine

3 构建TensorRT运行引擎进行预测

def loadONNX2TensorRT(filepath):
  '''
  通过onnx文件,构建TensorRT运行引擎
  :param filepath: onnx文件路径
  '''
  # 计算开始时间
  Start = time()

  engine = self.ONNX_build_engine(filepath)

  # 读取测试集
  datas = DataLoaders()
  test_loader = datas.testDataLoader()
  img, target = next(iter(test_loader))
  img = img.numpy()
  target = target.numpy()

  img = img.ravel()

  context = engine.create_execution_context()
  output = np.empty((100, 10), dtype=np.float32)

  # 分配内存
  d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
  d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
  bindings = [int(d_input), int(d_output)]

  # pycuda操作缓冲区
  stream = cuda.Stream()
  # 将输入数据放入device
  cuda.memcpy_htod_async(d_input, img, stream)
  # 执行模型
  context.execute_async(100, bindings, stream.handle, None)
  # 将预测结果从从缓冲区取出
  cuda.memcpy_dtoh_async(output, d_output, stream)
  # 线程同步
  stream.synchronize()

  print("Test Case: " + str(target))
  print("Prediction: " + str(np.argmax(output, axis=1)))
  print("tensorrt time:", time() - Start)

  del context
  del engine

补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT

近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。

后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。

是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。

以上这篇Pytorch通过保存为ONNX模型转TensorRT5的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python天气预报采集器实现代码(网页爬虫)
Oct 07 Python
浅谈Django自定义模板标签template_tags的用处
Dec 20 Python
python代码过长的换行方法
Jul 19 Python
让代码变得更易维护的7个Python库
Oct 09 Python
python制作图片缩略图
Apr 30 Python
Python中遍历列表的方法总结
Jun 27 Python
使用python来调用CAN通讯的DLL实现方法
Jul 03 Python
windows下Python安装、使用教程和Notepad++的使用教程
Oct 06 Python
Python中的引用和拷贝实例解析
Nov 14 Python
Python PyInstaller安装和使用教程详解
Jan 08 Python
Python装饰器的应用场景代码总结
Apr 10 Python
pandas求平均数和中位数的方法实例
Aug 04 Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
You might like
rephactor 优秀的PHP的重构工具
2011/06/09 PHP
php遍历类中包含的所有元素的方法
2015/05/12 PHP
PHP 实现base64编码文件上传出现问题详解
2020/09/01 PHP
关于IE、Firefox、Opera页面呈现异同 写脚本很痛苦
2009/08/28 Javascript
写给想学习Javascript的朋友一点学习经验小结
2010/11/23 Javascript
js当一个变量为函数时 应该注意的一点细节小结
2011/12/29 Javascript
js调用webservice中的方法实现思路及代码
2013/02/25 Javascript
利用js制作html table分页示例(js实现分页)
2014/04/25 Javascript
Javascript中获取浏览器类型和操作系统版本等客户端信息常用代码
2016/06/28 Javascript
Angular获取手机验证码实现移动端登录注册功能
2017/05/17 Javascript
bootstrap multiselect 多选功能实现方法
2017/06/05 Javascript
Vue使用vue-cli创建项目
2017/09/01 Javascript
JavaScript通过mouseover()实现图片变大效果的示例
2017/12/20 Javascript
小程序测试后台服务的方法(ngrok)
2019/03/08 Javascript
Vue中通过Vue.extend动态创建实例的方法
2019/08/13 Javascript
浅谈layer的Icon样式以及一些常用的layer窗口使用方法
2019/09/11 Javascript
javascript实现倒计时效果
2020/02/17 Javascript
[42:32]VP vs RNG 2019国际邀请赛淘汰赛 败者组 BO3 第一场 8.21.mp4
2020/07/19 DOTA
python相似模块用例
2016/03/04 Python
Python 使用os.remove删除文件夹时报错的解决方法
2017/01/13 Python
Python实现高斯函数的三维显示方法
2018/12/29 Python
python读取各种文件数据方法解析
2018/12/29 Python
python 实现12bit灰度图像映射到8bit显示的方法
2019/07/08 Python
用Python解数独的方法示例
2019/10/24 Python
简约控的天堂:The Undone
2016/12/21 全球购物
Pamela Love官网:纽约设计师Pamela Love的精美、时尚和穿孔珠宝
2020/10/19 全球购物
党校学习思想汇报
2014/01/06 职场文书
人事部岗位职责范本
2014/03/05 职场文书
广告宣传策划方案
2014/05/21 职场文书
档案保密承诺书
2014/06/03 职场文书
法学自荐信
2014/06/20 职场文书
法英专业大学生职业生涯规划书范文
2014/09/22 职场文书
《画家和牧童》教学反思
2016/02/17 职场文书
Python 内置函数速查表一览
2021/06/02 Python
SQLServer中JSON文档型数据的查询问题解决
2021/06/27 SQL Server
SQL Server使用CROSS APPLY与OUTER APPLY实现连接查询
2022/05/25 SQL Server