MxNet预训练模型到Pytorch模型的转换方式


Posted in Python onMay 25, 2020

预训练模型在不同深度学习框架中的转换是一种常见的任务。今天刚好DPN预训练模型转换问题,顺手将这个过程记录一下。

核心转换函数如下所示:

def convert_from_mxnet(model, checkpoint_prefix, debug=False):
 _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)
 remapped_state = {}
 for state_key in model.state_dict().keys():
  k = state_key.split('.')
  aux = False
  mxnet_key = ''
  if k[0] == 'features':
   if k[1] == 'conv1_1':
    # input block
    mxnet_key += 'conv1_x_1__'
    if k[2] == 'bn':
     mxnet_key += 'relu-sp__bn_'
     aux, key_add = _convert_bn(k[3])
     mxnet_key += key_add
    else:
     assert k[3] == 'weight'
     mxnet_key += 'conv_' + k[3]
   elif k[1] == 'conv5_bn_ac':
    # bn + ac at end of features block
    mxnet_key += 'conv5_x_x__relu-sp__bn_'
    assert k[2] == 'bn'
    aux, key_add = _convert_bn(k[3])
    mxnet_key += key_add
   else:
    # middle blocks
    if model.b and 'c1x1_c' in k[2]:
     bc_block = True # b-variant split c-block special treatment
    else:
     bc_block = False
    ck = k[1].split('_')
    mxnet_key += ck[0] + '_x__' + ck[1] + '_'
    ck = k[2].split('_')
    mxnet_key += ck[0] + '-' + ck[1]
    if ck[1] == 'w' and len(ck) > 2:
     mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'
    mxnet_key += '__'
    if k[3] == 'bn':
     mxnet_key += 'bn_' if bc_block else 'bn__bn_'
     aux, key_add = _convert_bn(k[4])
     mxnet_key += key_add
    else:
     ki = 3 if bc_block else 4
     assert k[ki] == 'weight'
     mxnet_key += 'conv_' + k[ki]
  elif k[0] == 'classifier':
   if 'fc6-1k_weight' in mxnet_weights:
    mxnet_key += 'fc6-1k_'
   else:
    mxnet_key += 'fc6_'
   mxnet_key += k[1]
  else:
   assert False, 'Unexpected token'
 
  if debug:
   print(mxnet_key, '=> ', state_key, end=' ')
 
  mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]
  torch_tensor = torch.from_numpy(mxnet_array.asnumpy())
  if k[0] == 'classifier' and k[1] == 'weight':
   torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))
  remapped_state[state_key] = torch_tensor
 
  if debug:
   print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())
 
 model.load_state_dict(remapped_state)
 
 return model

从中可以看出,其转换步骤如下:

(1)创建pytorch的网络结构模型,设为model

(2)利用mxnet来读取其存储的预训练模型,得到mxnet_weights;

(3)遍历加载后模型mxnet_weights的state_dict().keys

(4)对一些指定的key值,需要进行相应的处理和转换

(5)对修改键名之后的key利用numpy之间的转换来实现加载。

为了实现上述转换,首先pip安装mxnet,现在新版的mxnet安装还是非常方便的。

MxNet预训练模型到Pytorch模型的转换方式

第二步,运行转换程序,实现预训练模型的转换。

MxNet预训练模型到Pytorch模型的转换方式

可以看到在相当的文件夹下已经出现了转换后的模型。

以上这篇MxNet预训练模型到Pytorch模型的转换方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅析Python多线程下的变量问题
Apr 28 Python
python实现隐马尔科夫模型HMM
Mar 25 Python
在python中使用requests 模拟浏览器发送请求数据的方法
Dec 26 Python
Python3.7 dataclass使用指南小结
Feb 22 Python
django的聚合函数和aggregate、annotate方法使用详解
Jul 23 Python
django迁移数据库错误问题解决
Jul 29 Python
Django文件存储 默认存储系统解析
Aug 02 Python
如何用Python来理一理红楼梦里的那些关系
Aug 14 Python
Python 获取指定文件夹下的目录和文件的实现
Aug 30 Python
Pytorch实现神经网络的分类方式
Jan 08 Python
利用Tensorflow构建和训练自己的CNN来做简单的验证码识别方式
Jan 20 Python
Python用摘要算法生成token及检验token的示例代码
Dec 01 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
Pytorch通过保存为ONNX模型转TensorRT5的实现
May 25 #Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
You might like
S900/ ETON E1-XM 收音机
2021/03/02 无线电
php MsSql server时遇到的中文编码问题
2009/06/11 PHP
PHP extract 将数组拆分成多个变量的函数
2010/06/30 PHP
PHP与javascript实现变量交互的示例代码
2013/07/23 PHP
ThinkPHP框架设计及扩展详解
2014/11/25 PHP
自定义min版smarty模板引擎MinSmarty.class.php文件及用法
2016/05/20 PHP
document.write()及其输出内容的样式、位置控制
2013/08/12 Javascript
JS小功能(offsetLeft实现图片滚动效果)实例代码
2013/11/28 Javascript
jQuery使用ajaxSubmit()提交表单示例
2014/04/04 Javascript
javascript复制粘贴与clipboardData的使用
2014/10/16 Javascript
Node.js开源应用框架HapiJS介绍
2015/01/14 Javascript
JS实现在网页中弹出一个输入框的方法
2015/03/03 Javascript
jQuery实现的多屏图像图层切换效果实例
2015/05/07 Javascript
深入理解JavaScript编程中的原型概念
2015/06/25 Javascript
三种带箭头提示框总结实例
2016/06/14 Javascript
清除js缓存的多种方法总结
2016/12/09 Javascript
jQuery插件HighCharts绘制2D金字塔图效果示例【附demo源码下载】
2017/03/09 Javascript
elemetUi 组件--el-upload实现上传Excel文件的实例
2017/10/27 Javascript
vue-week-picker实现支持按周切换的日历
2019/06/26 Javascript
JS实现提示效果弹出及延迟隐藏的功能
2019/08/26 Javascript
vue 解决provide和inject响应的问题
2020/11/12 Javascript
用vite搭建vue3应用的实现方法
2021/02/22 Vue.js
[57:59]完美世界DOTA2联赛循环赛 Ink Ice vs LBZS BO2第一场 11.05
2020/11/05 DOTA
python Pandas 读取txt表格的实例
2018/04/29 Python
python的pstuil模块使用方法总结
2019/07/26 Python
Django使用unittest模块进行单元测试过程解析
2019/08/02 Python
基于python实现把图片转换成素描
2019/11/13 Python
关于python 的legend图例,参数使用说明
2020/04/17 Python
django rest framework 过滤时间操作
2020/07/12 Python
英国香水店:The Perfume Shop
2017/03/27 全球购物
猫途鹰:全球领先的旅游点评社区
2017/04/07 全球购物
美津浓美国官网:Mizuno美国
2018/08/07 全球购物
Marlies Dekkers内衣美国官方网上商店:高端内衣品牌
2018/11/12 全球购物
JAVA语言如何进行异常处理,关键字:throws,throw,try,catch,finally分别代表什么意义?在try块中可以抛出异常吗?
2013/07/02 面试题
周年庆典邀请函范文
2014/01/23 职场文书
离婚协议书怎样才有法律效力
2014/10/10 职场文书