MxNet预训练模型到Pytorch模型的转换方式


Posted in Python onMay 25, 2020

预训练模型在不同深度学习框架中的转换是一种常见的任务。今天刚好DPN预训练模型转换问题,顺手将这个过程记录一下。

核心转换函数如下所示:

def convert_from_mxnet(model, checkpoint_prefix, debug=False):
 _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)
 remapped_state = {}
 for state_key in model.state_dict().keys():
  k = state_key.split('.')
  aux = False
  mxnet_key = ''
  if k[0] == 'features':
   if k[1] == 'conv1_1':
    # input block
    mxnet_key += 'conv1_x_1__'
    if k[2] == 'bn':
     mxnet_key += 'relu-sp__bn_'
     aux, key_add = _convert_bn(k[3])
     mxnet_key += key_add
    else:
     assert k[3] == 'weight'
     mxnet_key += 'conv_' + k[3]
   elif k[1] == 'conv5_bn_ac':
    # bn + ac at end of features block
    mxnet_key += 'conv5_x_x__relu-sp__bn_'
    assert k[2] == 'bn'
    aux, key_add = _convert_bn(k[3])
    mxnet_key += key_add
   else:
    # middle blocks
    if model.b and 'c1x1_c' in k[2]:
     bc_block = True # b-variant split c-block special treatment
    else:
     bc_block = False
    ck = k[1].split('_')
    mxnet_key += ck[0] + '_x__' + ck[1] + '_'
    ck = k[2].split('_')
    mxnet_key += ck[0] + '-' + ck[1]
    if ck[1] == 'w' and len(ck) > 2:
     mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'
    mxnet_key += '__'
    if k[3] == 'bn':
     mxnet_key += 'bn_' if bc_block else 'bn__bn_'
     aux, key_add = _convert_bn(k[4])
     mxnet_key += key_add
    else:
     ki = 3 if bc_block else 4
     assert k[ki] == 'weight'
     mxnet_key += 'conv_' + k[ki]
  elif k[0] == 'classifier':
   if 'fc6-1k_weight' in mxnet_weights:
    mxnet_key += 'fc6-1k_'
   else:
    mxnet_key += 'fc6_'
   mxnet_key += k[1]
  else:
   assert False, 'Unexpected token'
 
  if debug:
   print(mxnet_key, '=> ', state_key, end=' ')
 
  mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]
  torch_tensor = torch.from_numpy(mxnet_array.asnumpy())
  if k[0] == 'classifier' and k[1] == 'weight':
   torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))
  remapped_state[state_key] = torch_tensor
 
  if debug:
   print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())
 
 model.load_state_dict(remapped_state)
 
 return model

从中可以看出,其转换步骤如下:

(1)创建pytorch的网络结构模型,设为model

(2)利用mxnet来读取其存储的预训练模型,得到mxnet_weights;

(3)遍历加载后模型mxnet_weights的state_dict().keys

(4)对一些指定的key值,需要进行相应的处理和转换

(5)对修改键名之后的key利用numpy之间的转换来实现加载。

为了实现上述转换,首先pip安装mxnet,现在新版的mxnet安装还是非常方便的。

MxNet预训练模型到Pytorch模型的转换方式

第二步,运行转换程序,实现预训练模型的转换。

MxNet预训练模型到Pytorch模型的转换方式

可以看到在相当的文件夹下已经出现了转换后的模型。

以上这篇MxNet预训练模型到Pytorch模型的转换方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现定时播放mp3
Mar 29 Python
Python读取word文本操作详解
Jan 22 Python
python爬虫自动创建文件夹的功能
Aug 01 Python
Python使用pickle模块储存对象操作示例
Aug 15 Python
DES加密解密算法之python实现版(图文并茂)
Dec 06 Python
python2.7实现邮件发送功能
Dec 12 Python
Python Matplotlib库安装与基本作图示例
Jan 09 Python
Python列表切片操作实例总结
Feb 19 Python
Python微信操控itchat的方法
May 31 Python
Python实现的对一个数进行因式分解操作示例
Jun 27 Python
Python jieba库用法及实例解析
Nov 04 Python
Python进行特征提取的示例代码
Oct 15 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
Pytorch通过保存为ONNX模型转TensorRT5的实现
May 25 #Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
You might like
php异常处理使用示例
2014/02/25 PHP
php中$_GET与$_POST过滤sql注入的方法
2014/11/03 PHP
Laravel框架中集成MongoDB和使用详解
2019/10/17 PHP
TP5框架使用QueryList采集框架爬小说操作示例
2020/03/26 PHP
Jquery 插件学习实例1 插件制作说明与tableUI优化
2010/04/02 Javascript
基于Jquery制作的幻灯片图集效果打包下载
2011/02/12 Javascript
三级下拉菜单的js实现代码
2011/05/23 Javascript
jquery教程ajax请求json数据示例
2014/01/13 Javascript
jQuery操作元素css样式的三种方法
2014/06/04 Javascript
浅析Javascript中“==”与“===”的区别
2014/12/23 Javascript
jQuery+css3实现文字跟随鼠标的上下抖动
2015/07/31 Javascript
javascript同步服务器时间和同步倒计时小技巧
2015/09/24 Javascript
Bootstrap每天必学之缩略图与警示窗
2015/11/29 Javascript
JQuery的常用选择器、过滤器、方法全面介绍
2016/05/25 Javascript
前端学习笔记style,currentStyle,getComputedStyle的用法与区别
2016/05/28 Javascript
jQuery操作dom实现弹出页面遮罩层(web端和移动端阻止遮罩层的滑动)
2016/08/25 Javascript
jQuery Ajax使用FormData对象上传文件的方法
2016/09/07 Javascript
bootstrap配合Masonry插件实现瀑布式布局
2017/01/18 Javascript
Ionic 2 实现列表滑动删除按钮的方法
2017/01/22 Javascript
微信小程序实现时间预约功能
2018/11/27 Javascript
JS尾递归的实现方法及代码优化技巧
2019/01/19 Javascript
微信小程序防止多次点击跳转(函数节流)
2019/09/19 Javascript
[45:15]Optic vs VP 2018国际邀请赛淘汰赛BO3 第一场 8.24
2018/08/25 DOTA
Python迭代和迭代器详解
2016/11/10 Python
详解python配置虚拟环境
2019/04/08 Python
基于python3的socket聊天编程
2020/02/17 Python
Python3中小括号()、中括号[]、花括号{}的区别详解
2020/11/15 Python
Python3+Flask安装使用教程详解
2021/02/16 Python
Python字节单位转换(将字节转换为K M G T)
2021/03/02 Python
精彩的大学生自我评价
2013/11/17 职场文书
2014学雷锋活动总结
2014/03/09 职场文书
学雷锋志愿服务月活动总结
2014/03/09 职场文书
党员批评与自我批评思想汇报
2014/10/08 职场文书
北京故宫的导游词
2015/01/31 职场文书
2015年八一建军节慰问信
2015/03/23 职场文书
Python办公自动化之教你用Python批量识别发票并录入到Excel表格中
2021/06/26 Python