MxNet预训练模型到Pytorch模型的转换方式


Posted in Python onMay 25, 2020

预训练模型在不同深度学习框架中的转换是一种常见的任务。今天刚好DPN预训练模型转换问题,顺手将这个过程记录一下。

核心转换函数如下所示:

def convert_from_mxnet(model, checkpoint_prefix, debug=False):
 _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)
 remapped_state = {}
 for state_key in model.state_dict().keys():
  k = state_key.split('.')
  aux = False
  mxnet_key = ''
  if k[0] == 'features':
   if k[1] == 'conv1_1':
    # input block
    mxnet_key += 'conv1_x_1__'
    if k[2] == 'bn':
     mxnet_key += 'relu-sp__bn_'
     aux, key_add = _convert_bn(k[3])
     mxnet_key += key_add
    else:
     assert k[3] == 'weight'
     mxnet_key += 'conv_' + k[3]
   elif k[1] == 'conv5_bn_ac':
    # bn + ac at end of features block
    mxnet_key += 'conv5_x_x__relu-sp__bn_'
    assert k[2] == 'bn'
    aux, key_add = _convert_bn(k[3])
    mxnet_key += key_add
   else:
    # middle blocks
    if model.b and 'c1x1_c' in k[2]:
     bc_block = True # b-variant split c-block special treatment
    else:
     bc_block = False
    ck = k[1].split('_')
    mxnet_key += ck[0] + '_x__' + ck[1] + '_'
    ck = k[2].split('_')
    mxnet_key += ck[0] + '-' + ck[1]
    if ck[1] == 'w' and len(ck) > 2:
     mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'
    mxnet_key += '__'
    if k[3] == 'bn':
     mxnet_key += 'bn_' if bc_block else 'bn__bn_'
     aux, key_add = _convert_bn(k[4])
     mxnet_key += key_add
    else:
     ki = 3 if bc_block else 4
     assert k[ki] == 'weight'
     mxnet_key += 'conv_' + k[ki]
  elif k[0] == 'classifier':
   if 'fc6-1k_weight' in mxnet_weights:
    mxnet_key += 'fc6-1k_'
   else:
    mxnet_key += 'fc6_'
   mxnet_key += k[1]
  else:
   assert False, 'Unexpected token'
 
  if debug:
   print(mxnet_key, '=> ', state_key, end=' ')
 
  mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]
  torch_tensor = torch.from_numpy(mxnet_array.asnumpy())
  if k[0] == 'classifier' and k[1] == 'weight':
   torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))
  remapped_state[state_key] = torch_tensor
 
  if debug:
   print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())
 
 model.load_state_dict(remapped_state)
 
 return model

从中可以看出,其转换步骤如下:

(1)创建pytorch的网络结构模型,设为model

(2)利用mxnet来读取其存储的预训练模型,得到mxnet_weights;

(3)遍历加载后模型mxnet_weights的state_dict().keys

(4)对一些指定的key值,需要进行相应的处理和转换

(5)对修改键名之后的key利用numpy之间的转换来实现加载。

为了实现上述转换,首先pip安装mxnet,现在新版的mxnet安装还是非常方便的。

MxNet预训练模型到Pytorch模型的转换方式

第二步,运行转换程序,实现预训练模型的转换。

MxNet预训练模型到Pytorch模型的转换方式

可以看到在相当的文件夹下已经出现了转换后的模型。

以上这篇MxNet预训练模型到Pytorch模型的转换方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Java Web开发过程中登陆模块的验证码的实现方式总结
May 25 Python
使用Mixin设计模式进行Python编程的方法讲解
Jun 21 Python
如何高效使用Python字典的方法详解
Aug 31 Python
Python实现定时精度可调节的定时器
Apr 15 Python
python3 对list中每个元素进行处理的方法
Jun 29 Python
Windows系统下PhantomJS的安装和基本用法
Oct 21 Python
在Python文件中指定Python解释器的方法
Feb 18 Python
python中pip的使用和修改下载源的方法
Jul 08 Python
Pandas+Matplotlib 箱式图异常值分析示例
Dec 09 Python
Python类中的装饰器在当前类中的声明与调用详解
Apr 15 Python
Python基础之操作MySQL数据库
May 06 Python
Pytorch 中net.train 和 net.eval的使用说明
May 22 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
Pytorch通过保存为ONNX模型转TensorRT5的实现
May 25 #Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
You might like
php+xml结合Ajax实现点赞功能完整实例
2015/01/30 PHP
php延迟静态绑定实例分析
2015/02/08 PHP
老生常谈PHP面向对象之注册表模式
2017/05/26 PHP
JavaScript iframe的相互操作浅析
2009/10/14 Javascript
基于jquery插件实现常见的幻灯片效果
2013/11/01 Javascript
js模仿hover的具体实现代码
2013/12/30 Javascript
javascript的propertyIsEnumerable()方法使用介绍
2014/04/09 Javascript
node.js中的fs.ftruncate方法使用说明
2014/12/15 Javascript
javascript实现图片循环渐显播放的方法
2015/02/24 Javascript
TypeError document.getElementById(...) is null错误原因
2015/05/18 Javascript
学习JavaScript设计模式之策略模式
2016/01/12 Javascript
javascript创建对象、对象继承的实用方式详解
2016/03/08 Javascript
Vue.js第四天学习笔记
2016/12/02 Javascript
vue.js开发环境搭建教程
2017/05/04 Javascript
Angular实现预加载延迟模块的示例
2017/10/12 Javascript
在 Angular6 中使用 HTTP 请求服务端数据的步骤详解
2018/08/06 Javascript
使用node.js实现微信小程序实时聊天功能
2018/08/13 Javascript
扫微信小程序码实现网站登陆实现解析
2019/08/20 Javascript
vue中defineProperty和Proxy的区别详解
2020/11/30 Vue.js
[01:04:29]DOTA2-DPC中国联赛 正赛 Phoenix vs XG BO3 第二场 1月31日
2021/03/11 DOTA
Python中set与frozenset方法和区别详解
2016/05/23 Python
在Python中字典根据多项规则排序的方法
2019/01/21 Python
Django的用户模块与权限系统的示例代码
2019/07/24 Python
Python Django 页面上展示固定的页码数实现代码
2019/08/21 Python
分享30个新鲜的CSS3打造的精美绚丽效果(附演示下载)
2012/12/28 HTML / CSS
详解css position 5种不同的值的用法
2019/07/30 HTML / CSS
基于canvas使用贝塞尔曲线平滑拟合折线段的方法
2018/01/10 HTML / CSS
为什么要使用servlet
2016/01/17 面试题
邮政员工辞职信
2014/01/16 职场文书
小学班主任培训方案
2014/06/04 职场文书
中秋手机店促销方案
2014/06/16 职场文书
七一活动主持词
2015/06/29 职场文书
初中英语教学反思范文
2016/02/15 职场文书
python 实现定时任务的四种方式
2021/04/01 Python
JavaScript设计模式之原型模式详情
2022/06/21 Javascript
JS前端可扩展的低代码UI框架Sunmao使用详解
2022/07/23 Javascript