MxNet预训练模型到Pytorch模型的转换方式


Posted in Python onMay 25, 2020

预训练模型在不同深度学习框架中的转换是一种常见的任务。今天刚好DPN预训练模型转换问题,顺手将这个过程记录一下。

核心转换函数如下所示:

def convert_from_mxnet(model, checkpoint_prefix, debug=False):
 _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)
 remapped_state = {}
 for state_key in model.state_dict().keys():
  k = state_key.split('.')
  aux = False
  mxnet_key = ''
  if k[0] == 'features':
   if k[1] == 'conv1_1':
    # input block
    mxnet_key += 'conv1_x_1__'
    if k[2] == 'bn':
     mxnet_key += 'relu-sp__bn_'
     aux, key_add = _convert_bn(k[3])
     mxnet_key += key_add
    else:
     assert k[3] == 'weight'
     mxnet_key += 'conv_' + k[3]
   elif k[1] == 'conv5_bn_ac':
    # bn + ac at end of features block
    mxnet_key += 'conv5_x_x__relu-sp__bn_'
    assert k[2] == 'bn'
    aux, key_add = _convert_bn(k[3])
    mxnet_key += key_add
   else:
    # middle blocks
    if model.b and 'c1x1_c' in k[2]:
     bc_block = True # b-variant split c-block special treatment
    else:
     bc_block = False
    ck = k[1].split('_')
    mxnet_key += ck[0] + '_x__' + ck[1] + '_'
    ck = k[2].split('_')
    mxnet_key += ck[0] + '-' + ck[1]
    if ck[1] == 'w' and len(ck) > 2:
     mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'
    mxnet_key += '__'
    if k[3] == 'bn':
     mxnet_key += 'bn_' if bc_block else 'bn__bn_'
     aux, key_add = _convert_bn(k[4])
     mxnet_key += key_add
    else:
     ki = 3 if bc_block else 4
     assert k[ki] == 'weight'
     mxnet_key += 'conv_' + k[ki]
  elif k[0] == 'classifier':
   if 'fc6-1k_weight' in mxnet_weights:
    mxnet_key += 'fc6-1k_'
   else:
    mxnet_key += 'fc6_'
   mxnet_key += k[1]
  else:
   assert False, 'Unexpected token'
 
  if debug:
   print(mxnet_key, '=> ', state_key, end=' ')
 
  mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]
  torch_tensor = torch.from_numpy(mxnet_array.asnumpy())
  if k[0] == 'classifier' and k[1] == 'weight':
   torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))
  remapped_state[state_key] = torch_tensor
 
  if debug:
   print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())
 
 model.load_state_dict(remapped_state)
 
 return model

从中可以看出,其转换步骤如下:

(1)创建pytorch的网络结构模型,设为model

(2)利用mxnet来读取其存储的预训练模型,得到mxnet_weights;

(3)遍历加载后模型mxnet_weights的state_dict().keys

(4)对一些指定的key值,需要进行相应的处理和转换

(5)对修改键名之后的key利用numpy之间的转换来实现加载。

为了实现上述转换,首先pip安装mxnet,现在新版的mxnet安装还是非常方便的。

MxNet预训练模型到Pytorch模型的转换方式

第二步,运行转换程序,实现预训练模型的转换。

MxNet预训练模型到Pytorch模型的转换方式

可以看到在相当的文件夹下已经出现了转换后的模型。

以上这篇MxNet预训练模型到Pytorch模型的转换方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的ctime()方法使用教程
May 22 Python
使用Python读写文本文件及编写简单的文本编辑器
Mar 11 Python
Python代码缩进和测试模块示例详解
May 07 Python
python 把列表转化为字符串的方法
Oct 23 Python
对Python通过pypyodbc访问Access数据库的方法详解
Oct 27 Python
Python并发:多线程与多进程的详解
Jan 24 Python
Python高级特性与几种函数的讲解
Mar 08 Python
详解Python下载图片并保存本地的两种方式
May 15 Python
Python实现CNN的多通道输入实例
Jan 17 Python
Scrapy框架基本命令与settings.py设置
Feb 06 Python
Pandas —— resample()重采样和asfreq()频度转换方式
Feb 26 Python
matplotlib bar()实现百分比堆积柱状图
Feb 24 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
Pytorch通过保存为ONNX模型转TensorRT5的实现
May 25 #Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
You might like
php2html php生成静态页函数
2008/12/08 PHP
php类中private属性继承问题分析
2012/11/01 PHP
对淘宝URL中ID提取的PHP代码
2013/09/01 PHP
PHP中的reflection反射机制测试例子
2014/08/05 PHP
ThinkPHP入口文件设置及相关注意事项分析
2014/12/05 PHP
MacOS 安装 PHP的图片裁剪扩展Tclip
2015/03/25 PHP
php实现JWT(json web token)鉴权实例详解
2019/11/05 PHP
Aster vs KG BO3 第二场2.19
2021/03/10 DOTA
innerHTML,outerHTML,innerTEXT三者之间的区别
2007/01/28 Javascript
js与jquery实时监听输入框值的oninput与onpropertychange方法
2015/02/05 Javascript
Svg.js实例教程及使用手册详解(一)
2016/05/16 Javascript
JS针对浏览器窗口关闭事件的监听方法集锦
2016/06/24 Javascript
一篇文章搞定JavaScript类型转换(面试常见)
2017/01/21 Javascript
利用JS测试目标网站的打开响应速度
2017/12/01 Javascript
微信小程序自定义prompt组件步骤详解
2018/06/12 Javascript
微信小程序开发之路由切换页面重定向问题
2018/09/18 Javascript
使用vue-cli webpack 快速搭建项目的代码
2018/11/21 Javascript
解决微信小程序中转换时间格式IOS不兼容的问题
2019/02/15 Javascript
微信小程序性能优化之checkSession的使用
2019/03/06 Javascript
layui实现数据分页功能(ajax异步)
2019/07/27 Javascript
Vue防止白屏添加首屏动画的实例
2019/10/31 Javascript
python实现根据主机名字获得所有ip地址的方法
2015/06/28 Python
Google开源的Python格式化工具YAPF的安装和使用教程
2016/05/31 Python
全面了解python中的类,对象,方法,属性
2016/09/11 Python
python实现俄罗斯方块游戏
2020/03/25 Python
Python中安装easy_install的方法
2018/11/18 Python
pytorch中的自定义数据处理详解
2020/01/06 Python
澳大利亚便宜的家庭购物网站:CrazySales
2018/02/06 全球购物
Falconeri美国官网:由羊绒和羊毛制成的针织服装
2018/04/08 全球购物
如何在存储过程中使用Loop
2016/01/05 面试题
事业单位个人应聘自荐信
2013/09/21 职场文书
财务管理专业应届毕业生求职信
2013/09/22 职场文书
大学生村官典型材料
2014/01/12 职场文书
关于成绩下滑的自我检讨书
2014/09/20 职场文书
领导干部查摆“四风”问题自我剖析材料思想汇报
2014/10/05 职场文书
2015年世界无车日活动总结
2015/03/23 职场文书