MxNet预训练模型到Pytorch模型的转换方式


Posted in Python onMay 25, 2020

预训练模型在不同深度学习框架中的转换是一种常见的任务。今天刚好DPN预训练模型转换问题,顺手将这个过程记录一下。

核心转换函数如下所示:

def convert_from_mxnet(model, checkpoint_prefix, debug=False):
 _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)
 remapped_state = {}
 for state_key in model.state_dict().keys():
  k = state_key.split('.')
  aux = False
  mxnet_key = ''
  if k[0] == 'features':
   if k[1] == 'conv1_1':
    # input block
    mxnet_key += 'conv1_x_1__'
    if k[2] == 'bn':
     mxnet_key += 'relu-sp__bn_'
     aux, key_add = _convert_bn(k[3])
     mxnet_key += key_add
    else:
     assert k[3] == 'weight'
     mxnet_key += 'conv_' + k[3]
   elif k[1] == 'conv5_bn_ac':
    # bn + ac at end of features block
    mxnet_key += 'conv5_x_x__relu-sp__bn_'
    assert k[2] == 'bn'
    aux, key_add = _convert_bn(k[3])
    mxnet_key += key_add
   else:
    # middle blocks
    if model.b and 'c1x1_c' in k[2]:
     bc_block = True # b-variant split c-block special treatment
    else:
     bc_block = False
    ck = k[1].split('_')
    mxnet_key += ck[0] + '_x__' + ck[1] + '_'
    ck = k[2].split('_')
    mxnet_key += ck[0] + '-' + ck[1]
    if ck[1] == 'w' and len(ck) > 2:
     mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'
    mxnet_key += '__'
    if k[3] == 'bn':
     mxnet_key += 'bn_' if bc_block else 'bn__bn_'
     aux, key_add = _convert_bn(k[4])
     mxnet_key += key_add
    else:
     ki = 3 if bc_block else 4
     assert k[ki] == 'weight'
     mxnet_key += 'conv_' + k[ki]
  elif k[0] == 'classifier':
   if 'fc6-1k_weight' in mxnet_weights:
    mxnet_key += 'fc6-1k_'
   else:
    mxnet_key += 'fc6_'
   mxnet_key += k[1]
  else:
   assert False, 'Unexpected token'
 
  if debug:
   print(mxnet_key, '=> ', state_key, end=' ')
 
  mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]
  torch_tensor = torch.from_numpy(mxnet_array.asnumpy())
  if k[0] == 'classifier' and k[1] == 'weight':
   torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))
  remapped_state[state_key] = torch_tensor
 
  if debug:
   print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())
 
 model.load_state_dict(remapped_state)
 
 return model

从中可以看出,其转换步骤如下:

(1)创建pytorch的网络结构模型,设为model

(2)利用mxnet来读取其存储的预训练模型,得到mxnet_weights;

(3)遍历加载后模型mxnet_weights的state_dict().keys

(4)对一些指定的key值,需要进行相应的处理和转换

(5)对修改键名之后的key利用numpy之间的转换来实现加载。

为了实现上述转换,首先pip安装mxnet,现在新版的mxnet安装还是非常方便的。

MxNet预训练模型到Pytorch模型的转换方式

第二步,运行转换程序,实现预训练模型的转换。

MxNet预训练模型到Pytorch模型的转换方式

可以看到在相当的文件夹下已经出现了转换后的模型。

以上这篇MxNet预训练模型到Pytorch模型的转换方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python利用正则表达式提取字符串
Dec 08 Python
Python 多线程实例详解
Mar 25 Python
Python2实现的LED大数字显示效果示例
Sep 04 Python
django manage.py扩展自定义命令方法
May 27 Python
Python爬虫包BeautifulSoup异常处理(二)
Jun 17 Python
浅谈利用numpy对矩阵进行归一化处理的方法
Jul 11 Python
django缓存配置的几种方法详解
Jul 16 Python
Tensorflow实现多GPU并行方式
Feb 03 Python
Django分组聚合查询实例分享
Apr 29 Python
Python用摘要算法生成token及检验token的示例代码
Dec 01 Python
用python获取txt文件中关键字的数量
Dec 24 Python
Python爬虫获取op.gg英雄联盟英雄对位胜率的源码
Jan 29 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
Pytorch通过保存为ONNX模型转TensorRT5的实现
May 25 #Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
You might like
无线电波是什么?它是怎样传输的?
2021/03/01 无线电
用session做客户验证时的注意事项
2006/10/09 PHP
php读取文件内容的方法汇总
2015/01/24 PHP
Smarty使用自定义资源的方法
2015/08/08 PHP
PHP防盗链的基本思想 防盗链的设置方法
2015/09/25 PHP
PHP 中使用explode()函数切割字符串为数组的示例
2017/05/06 PHP
PHP数据对象映射模式实例分析
2019/03/29 PHP
jQuery调取jSon数据并展示的方法
2015/01/29 Javascript
jQuery中animate动画第二次点击事件没反应
2015/05/07 Javascript
Jquery注册事件实现方法
2015/05/18 Javascript
JavaScript的Number对象的toString()方法
2015/12/18 Javascript
轻松实现js图片预览功能
2016/01/18 Javascript
详解Angular.js的$q.defer()服务异步处理
2016/11/06 Javascript
jquery拼接ajax 的json和字符串拼接的方法
2017/03/11 Javascript
JS表格组件神器bootstrap table使用指南详解
2017/04/12 Javascript
浅谈node的事件机制
2017/10/09 Javascript
mui框架 页面无法滚动的解决方法(推荐)
2018/01/25 Javascript
尝试自己动手用react来写一个分页组件(小结)
2018/02/09 Javascript
vue单页开发父子组件传值思路详解
2018/05/18 Javascript
微信小程序 函数防抖 解决重复点击消耗性能问题实现代码
2019/09/12 Javascript
基于JavaScript实现单例模式
2019/10/30 Javascript
[14:20]刀塔大凶女神互压各路奇葩屌丝
2014/05/16 DOTA
Python编程使用tkinter模块实现计算器软件完整代码示例
2017/11/29 Python
解决Django 在ForeignKey中出现 non-nullable field错误的问题
2019/08/06 Python
Python jieba库用法及实例解析
2019/11/04 Python
Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解
2020/02/11 Python
Python利用socket模块开发简单的端口扫描工具的实现
2021/01/27 Python
HTML5表格_动力节点Java学院整理
2017/07/11 HTML / CSS
html5版canvas自由拼图实例
2014/10/15 HTML / CSS
会计专业自荐信
2013/12/02 职场文书
顶撞老师检讨书
2014/02/07 职场文书
年度评优评先方案
2014/06/03 职场文书
基层党员干部四风问题整改方向和措施
2014/09/25 职场文书
2014年初级职称工作总结
2014/12/08 职场文书
mysql 8.0.24 安装配置方法图文教程
2021/05/12 MySQL
深入浅出的讲解:信号调制到底是如何实现的
2022/02/18 无线电