MxNet预训练模型到Pytorch模型的转换方式


Posted in Python onMay 25, 2020

预训练模型在不同深度学习框架中的转换是一种常见的任务。今天刚好DPN预训练模型转换问题,顺手将这个过程记录一下。

核心转换函数如下所示:

def convert_from_mxnet(model, checkpoint_prefix, debug=False):
 _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)
 remapped_state = {}
 for state_key in model.state_dict().keys():
  k = state_key.split('.')
  aux = False
  mxnet_key = ''
  if k[0] == 'features':
   if k[1] == 'conv1_1':
    # input block
    mxnet_key += 'conv1_x_1__'
    if k[2] == 'bn':
     mxnet_key += 'relu-sp__bn_'
     aux, key_add = _convert_bn(k[3])
     mxnet_key += key_add
    else:
     assert k[3] == 'weight'
     mxnet_key += 'conv_' + k[3]
   elif k[1] == 'conv5_bn_ac':
    # bn + ac at end of features block
    mxnet_key += 'conv5_x_x__relu-sp__bn_'
    assert k[2] == 'bn'
    aux, key_add = _convert_bn(k[3])
    mxnet_key += key_add
   else:
    # middle blocks
    if model.b and 'c1x1_c' in k[2]:
     bc_block = True # b-variant split c-block special treatment
    else:
     bc_block = False
    ck = k[1].split('_')
    mxnet_key += ck[0] + '_x__' + ck[1] + '_'
    ck = k[2].split('_')
    mxnet_key += ck[0] + '-' + ck[1]
    if ck[1] == 'w' and len(ck) > 2:
     mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'
    mxnet_key += '__'
    if k[3] == 'bn':
     mxnet_key += 'bn_' if bc_block else 'bn__bn_'
     aux, key_add = _convert_bn(k[4])
     mxnet_key += key_add
    else:
     ki = 3 if bc_block else 4
     assert k[ki] == 'weight'
     mxnet_key += 'conv_' + k[ki]
  elif k[0] == 'classifier':
   if 'fc6-1k_weight' in mxnet_weights:
    mxnet_key += 'fc6-1k_'
   else:
    mxnet_key += 'fc6_'
   mxnet_key += k[1]
  else:
   assert False, 'Unexpected token'
 
  if debug:
   print(mxnet_key, '=> ', state_key, end=' ')
 
  mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]
  torch_tensor = torch.from_numpy(mxnet_array.asnumpy())
  if k[0] == 'classifier' and k[1] == 'weight':
   torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))
  remapped_state[state_key] = torch_tensor
 
  if debug:
   print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())
 
 model.load_state_dict(remapped_state)
 
 return model

从中可以看出,其转换步骤如下:

(1)创建pytorch的网络结构模型,设为model

(2)利用mxnet来读取其存储的预训练模型,得到mxnet_weights;

(3)遍历加载后模型mxnet_weights的state_dict().keys

(4)对一些指定的key值,需要进行相应的处理和转换

(5)对修改键名之后的key利用numpy之间的转换来实现加载。

为了实现上述转换,首先pip安装mxnet,现在新版的mxnet安装还是非常方便的。

MxNet预训练模型到Pytorch模型的转换方式

第二步,运行转换程序,实现预训练模型的转换。

MxNet预训练模型到Pytorch模型的转换方式

可以看到在相当的文件夹下已经出现了转换后的模型。

以上这篇MxNet预训练模型到Pytorch模型的转换方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python入门篇之正则表达式
Oct 20 Python
用C++封装MySQL的API的教程
May 06 Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 Python
Python OpenCV中的resize()函数的使用
Jun 20 Python
Django组件content-type使用方法详解
Jul 19 Python
python3发送邮件需要经过代理服务器的示例代码
Jul 25 Python
浅谈Python3 numpy.ptp()最大值与最小值的差
Aug 24 Python
python logging.basicConfig不生效的原因及解决
Feb 20 Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 Python
java关于string最常出现的面试题整理
Jan 18 Python
Python实现的扫码工具居然这么好用!
Jun 07 Python
python字典的元素访问实例详解
Jul 21 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
Pytorch通过保存为ONNX模型转TensorRT5的实现
May 25 #Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
You might like
浅析php中常量,变量的作用域和生存周期
2013/08/10 PHP
CI使用Tank Auth转移数据库导致密码用户错误的解决办法
2014/06/12 PHP
ThinkPHP3.1新特性之对页面压缩输出的支持
2014/06/19 PHP
jQuery Pagination Ajax分页插件(分页切换时无刷新与延迟)中文翻译版
2013/01/11 Javascript
js如何设置在iframe框架中指定div不显示
2013/12/04 Javascript
jquery检测input checked 控件是否被选中的方法
2014/03/26 Javascript
如何让你的Lightbox支持滚轮缩放及Base64图片
2014/12/04 Javascript
JavaScript中setUTCFullYear()方法的使用简介
2015/06/12 Javascript
Javascript中的神器——Promise
2017/02/08 Javascript
JS使用面向对象技术实现的tab选项卡效果示例
2017/02/28 Javascript
JavaScript设计模式之策略模式详解
2017/06/09 Javascript
使用Node.js实现简易MVC框架的方法
2017/08/07 Javascript
JS实现图片手风琴效果
2020/04/17 Javascript
jQuery插件Validation表单验证详解
2018/05/26 jQuery
解决JQuery的ajax函数执行失败alert函数弹框一闪而过问题
2019/04/10 jQuery
vue开发chrome插件,实现获取界面数据和保存到数据库功能
2020/12/01 Vue.js
[06:37]2014DOTA2国际邀请赛 昔日王者渴望重回巅峰
2014/07/12 DOTA
[43:58]DOTA2上海特级锦标赛C组败者赛 Newbee VS Archon第二局
2016/02/27 DOTA
[01:16]DOTA2小知识课堂 Ep.03 芒果树无伤肉山
2019/12/05 DOTA
[03:02]2020完美世界城市挑战赛(秋季赛)总决赛回顾
2021/03/11 DOTA
python多重继承实例
2014/10/11 Python
浅谈pycharm出现卡顿的解决方法
2018/12/03 Python
python命令 -u参数用法解析
2019/10/24 Python
浅谈Python类中的self到底是干啥的
2019/11/11 Python
python爬虫学习笔记之Beautifulsoup模块用法详解
2020/04/09 Python
怎么快速自学python
2020/06/22 Python
pytorch 限制GPU使用效率详解(计算效率)
2020/06/27 Python
HTML5中的Web Notification桌面通知功能的实现方法
2019/07/29 HTML / CSS
New Balance澳大利亚官网:运动鞋和健身服装
2019/02/23 全球购物
编写一个类体现构造,公有,私有方法,静态,私有变量
2013/08/10 面试题
最新创业融资计划书
2014/01/19 职场文书
材料专业大学毕业生自荐书
2014/07/02 职场文书
写给领导的感谢信
2015/01/22 职场文书
项目技术负责人岗位职责
2015/04/13 职场文书
css 中多种边框的实现小窍门
2021/04/07 HTML / CSS
vue中this.$http.post()跨域和请求参数丢失的解决
2022/04/08 Vue.js