MxNet预训练模型到Pytorch模型的转换方式


Posted in Python onMay 25, 2020

预训练模型在不同深度学习框架中的转换是一种常见的任务。今天刚好DPN预训练模型转换问题,顺手将这个过程记录一下。

核心转换函数如下所示:

def convert_from_mxnet(model, checkpoint_prefix, debug=False):
 _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)
 remapped_state = {}
 for state_key in model.state_dict().keys():
  k = state_key.split('.')
  aux = False
  mxnet_key = ''
  if k[0] == 'features':
   if k[1] == 'conv1_1':
    # input block
    mxnet_key += 'conv1_x_1__'
    if k[2] == 'bn':
     mxnet_key += 'relu-sp__bn_'
     aux, key_add = _convert_bn(k[3])
     mxnet_key += key_add
    else:
     assert k[3] == 'weight'
     mxnet_key += 'conv_' + k[3]
   elif k[1] == 'conv5_bn_ac':
    # bn + ac at end of features block
    mxnet_key += 'conv5_x_x__relu-sp__bn_'
    assert k[2] == 'bn'
    aux, key_add = _convert_bn(k[3])
    mxnet_key += key_add
   else:
    # middle blocks
    if model.b and 'c1x1_c' in k[2]:
     bc_block = True # b-variant split c-block special treatment
    else:
     bc_block = False
    ck = k[1].split('_')
    mxnet_key += ck[0] + '_x__' + ck[1] + '_'
    ck = k[2].split('_')
    mxnet_key += ck[0] + '-' + ck[1]
    if ck[1] == 'w' and len(ck) > 2:
     mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'
    mxnet_key += '__'
    if k[3] == 'bn':
     mxnet_key += 'bn_' if bc_block else 'bn__bn_'
     aux, key_add = _convert_bn(k[4])
     mxnet_key += key_add
    else:
     ki = 3 if bc_block else 4
     assert k[ki] == 'weight'
     mxnet_key += 'conv_' + k[ki]
  elif k[0] == 'classifier':
   if 'fc6-1k_weight' in mxnet_weights:
    mxnet_key += 'fc6-1k_'
   else:
    mxnet_key += 'fc6_'
   mxnet_key += k[1]
  else:
   assert False, 'Unexpected token'
 
  if debug:
   print(mxnet_key, '=> ', state_key, end=' ')
 
  mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]
  torch_tensor = torch.from_numpy(mxnet_array.asnumpy())
  if k[0] == 'classifier' and k[1] == 'weight':
   torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))
  remapped_state[state_key] = torch_tensor
 
  if debug:
   print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())
 
 model.load_state_dict(remapped_state)
 
 return model

从中可以看出,其转换步骤如下:

(1)创建pytorch的网络结构模型,设为model

(2)利用mxnet来读取其存储的预训练模型,得到mxnet_weights;

(3)遍历加载后模型mxnet_weights的state_dict().keys

(4)对一些指定的key值,需要进行相应的处理和转换

(5)对修改键名之后的key利用numpy之间的转换来实现加载。

为了实现上述转换,首先pip安装mxnet,现在新版的mxnet安装还是非常方便的。

MxNet预训练模型到Pytorch模型的转换方式

第二步,运行转换程序,实现预训练模型的转换。

MxNet预训练模型到Pytorch模型的转换方式

可以看到在相当的文件夹下已经出现了转换后的模型。

以上这篇MxNet预训练模型到Pytorch模型的转换方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的生成自我描述脚本分享(很有意思的程序)
Jul 18 Python
《Python之禅》中对于Python编程过程中的一些建议
Apr 03 Python
Python实现代码统计工具(终极篇)
Jul 04 Python
Pandas探索之高性能函数eval和query解析
Oct 28 Python
使用Python读取大文件的方法
Feb 11 Python
Python中.join()和os.path.join()两个函数的用法详解
Jun 11 Python
详解Python 定时框架 Apscheduler原理及安装过程
Jun 14 Python
python Django框架实现web端分页呈现数据
Oct 31 Python
Python TCP通信客户端服务端代码实例
Nov 21 Python
Python3的unicode编码转换成中文的问题及解决方案
Dec 10 Python
Python爬虫之Selenium实现键盘事件
Dec 04 Python
python网络爬虫实现发送短信验证码的方法
Feb 25 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
Pytorch通过保存为ONNX模型转TensorRT5的实现
May 25 #Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
You might like
《斗罗大陆》六翼天使武魂最强,为什么老千家不是上三宗?
2020/03/02 国漫
PHP读取ACCESS数据到MYSQL的代码
2011/05/11 PHP
PHP 命令行参数详解及应用
2011/05/18 PHP
PHPCrawl爬虫库实现抓取酷狗歌单的方法示例
2017/12/21 PHP
PHP pthreads v3下同步处理synchronized用法示例
2020/02/21 PHP
学习YUI.Ext第七日-View&JSONView Part Two-一个画室网站的案例
2007/03/10 Javascript
jquery load()在firefox(火狐)下显示不正常的解决方法
2011/04/05 Javascript
用unescape反编码得出汉字示例
2014/04/24 Javascript
JS兼容浏览器的导出Excel(CSV)文件的方法
2014/05/03 Javascript
jquery地址栏链接与a标签链接匹配之特效代码总结
2015/08/24 Javascript
jQuery事件的绑定、触发、及监听方法简单说明
2016/05/10 Javascript
JS获取一个未知DIV高度的方法
2016/08/09 Javascript
Node.js学习之查询字符串解析querystring详解
2017/09/28 Javascript
js实现圆形菜单选择器
2020/12/03 Javascript
[44:21]Ti4 循环赛第四日 附加赛NEWBEE vs LGD
2014/07/13 DOTA
Python判断文本中消息重复次数的方法
2016/04/27 Python
Python只用40行代码编写的计算器实例
2017/05/10 Python
Python如何筛选序列中的元素的方法实现
2019/07/15 Python
修改Pandas的行或列的名字(重命名)
2019/12/18 Python
python实现全排列代码(回溯、深度优先搜索)
2020/02/26 Python
Python3.9 beta2版本发布了,看看这7个新的PEP都是什么
2020/06/10 Python
简单介绍CSS3中Media Query的使用
2015/07/07 HTML / CSS
关于前端上传文件全面基础扫盲贴(入门)
2019/08/01 HTML / CSS
科茨沃尔德家居商店:Scotts of Stow
2018/06/29 全球购物
Ellesse英国官网:意大利高级运动品牌
2019/07/23 全球购物
2019年Java 最常见的 面试题
2016/10/19 面试题
怎么写好自荐信
2013/10/30 职场文书
趣味活动策划方案
2014/02/08 职场文书
员工考核评语大全
2014/04/26 职场文书
项目投资建议书
2014/05/16 职场文书
改革共识倡议书
2014/08/29 职场文书
学会Python数据可视化必须尝试这7个库
2021/06/16 Python
sqlserver连接错误之SQL评估期已过的问题解决
2022/03/23 SQL Server
vue+iview实现手机号分段输入框
2022/03/25 Vue.js
详解使用内网穿透工具Ngrok代理本地服务
2022/03/31 Servers
Python进程间的通信之语法学习
2022/04/11 Python