MxNet预训练模型到Pytorch模型的转换方式


Posted in Python onMay 25, 2020

预训练模型在不同深度学习框架中的转换是一种常见的任务。今天刚好DPN预训练模型转换问题,顺手将这个过程记录一下。

核心转换函数如下所示:

def convert_from_mxnet(model, checkpoint_prefix, debug=False):
 _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)
 remapped_state = {}
 for state_key in model.state_dict().keys():
  k = state_key.split('.')
  aux = False
  mxnet_key = ''
  if k[0] == 'features':
   if k[1] == 'conv1_1':
    # input block
    mxnet_key += 'conv1_x_1__'
    if k[2] == 'bn':
     mxnet_key += 'relu-sp__bn_'
     aux, key_add = _convert_bn(k[3])
     mxnet_key += key_add
    else:
     assert k[3] == 'weight'
     mxnet_key += 'conv_' + k[3]
   elif k[1] == 'conv5_bn_ac':
    # bn + ac at end of features block
    mxnet_key += 'conv5_x_x__relu-sp__bn_'
    assert k[2] == 'bn'
    aux, key_add = _convert_bn(k[3])
    mxnet_key += key_add
   else:
    # middle blocks
    if model.b and 'c1x1_c' in k[2]:
     bc_block = True # b-variant split c-block special treatment
    else:
     bc_block = False
    ck = k[1].split('_')
    mxnet_key += ck[0] + '_x__' + ck[1] + '_'
    ck = k[2].split('_')
    mxnet_key += ck[0] + '-' + ck[1]
    if ck[1] == 'w' and len(ck) > 2:
     mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'
    mxnet_key += '__'
    if k[3] == 'bn':
     mxnet_key += 'bn_' if bc_block else 'bn__bn_'
     aux, key_add = _convert_bn(k[4])
     mxnet_key += key_add
    else:
     ki = 3 if bc_block else 4
     assert k[ki] == 'weight'
     mxnet_key += 'conv_' + k[ki]
  elif k[0] == 'classifier':
   if 'fc6-1k_weight' in mxnet_weights:
    mxnet_key += 'fc6-1k_'
   else:
    mxnet_key += 'fc6_'
   mxnet_key += k[1]
  else:
   assert False, 'Unexpected token'
 
  if debug:
   print(mxnet_key, '=> ', state_key, end=' ')
 
  mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]
  torch_tensor = torch.from_numpy(mxnet_array.asnumpy())
  if k[0] == 'classifier' and k[1] == 'weight':
   torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))
  remapped_state[state_key] = torch_tensor
 
  if debug:
   print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())
 
 model.load_state_dict(remapped_state)
 
 return model

从中可以看出,其转换步骤如下:

(1)创建pytorch的网络结构模型,设为model

(2)利用mxnet来读取其存储的预训练模型,得到mxnet_weights;

(3)遍历加载后模型mxnet_weights的state_dict().keys

(4)对一些指定的key值,需要进行相应的处理和转换

(5)对修改键名之后的key利用numpy之间的转换来实现加载。

为了实现上述转换,首先pip安装mxnet,现在新版的mxnet安装还是非常方便的。

MxNet预训练模型到Pytorch模型的转换方式

第二步,运行转换程序,实现预训练模型的转换。

MxNet预训练模型到Pytorch模型的转换方式

可以看到在相当的文件夹下已经出现了转换后的模型。

以上这篇MxNet预训练模型到Pytorch模型的转换方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python爬虫获取图片并下载保存至本地的实例
Jun 01 Python
实用自动化运维Python脚本分享
Jun 04 Python
python实现输入数字的连续加减方法
Jun 22 Python
python从子线程中获得返回值的方法
Jan 30 Python
python使用参数对嵌套字典进行取值的方法
Apr 26 Python
python实现图片转字符小工具
Apr 30 Python
Python 装饰器原理、定义与用法详解
Dec 07 Python
Python基础之变量基本用法与进阶详解
Jan 03 Python
谈一谈数组拼接tf.concat()和np.concatenate()的区别
Feb 07 Python
python如何将两张图片生成为全景图片
Mar 05 Python
Python图像识别+KNN求解数独的实现
Nov 13 Python
python 统计list中各个元素出现的次数的几种方法
Feb 20 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
Pytorch通过保存为ONNX模型转TensorRT5的实现
May 25 #Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
You might like
PHP 上传文件大小限制
2009/07/05 PHP
thinkphp 一个页面使用2次分页的实现方法
2013/07/15 PHP
针对thinkPHP5框架存储过程bug重写的存储过程扩展类完整实例
2018/06/16 PHP
php的lavarel框架中join和orWhere的用法
2020/12/28 PHP
JavaScript与Image加载事件(onload)、加载状态(complete)
2011/02/14 Javascript
JavaScript学习笔记(二) js对象
2011/10/25 Javascript
优化Jquery,提升网页加载速度
2013/11/14 Javascript
页面实时更新时间的JS实例代码
2013/12/18 Javascript
jQuery中:submit选择器用法实例
2015/01/03 Javascript
js查看一个函数的执行时间实例代码
2015/09/12 Javascript
静态页面html中跳转传值的JS处理技巧
2016/06/22 Javascript
Javascript对象字面量的理解
2016/06/22 Javascript
浅谈js的html元素的父节点,子节点
2016/08/06 Javascript
JavaScript中获取时间的函数集
2016/08/16 Javascript
jQuery实现CheckBox全选、全不选功能
2017/01/11 Javascript
微信小程序学习之数据处理详解
2017/07/05 Javascript
angular之ng-template模板加载
2017/11/09 Javascript
vue-cli启动本地服务局域网不能访问的原因分析
2018/01/22 Javascript
JS实现仿微信支付弹窗功能
2018/06/25 Javascript
NodeJs crypto加密制作token的实现代码
2019/11/15 NodeJs
JavaScript原生数组函数实例汇总
2020/10/14 Javascript
[01:14]DOTA2亚洲邀请赛 ShowOpen
2015/02/07 DOTA
Python爬虫DNS解析缓存方法实例分析
2017/06/02 Python
Python优先队列实现方法示例
2017/09/21 Python
python读文件保存到字典,修改字典并写入新文件的实例
2018/04/23 Python
Python使用Selenium爬取淘宝异步加载的数据方法
2018/12/17 Python
JetBrains PyCharm(Community版本)的下载、安装和初步使用图文教程详解
2020/03/19 Python
Django 实现 Websocket 广播、点对点发送消息的代码
2020/06/03 Python
canvas实现滑动验证的实现示例
2020/08/11 HTML / CSS
Sperry官网:帆船鞋创始品牌
2016/09/07 全球购物
Belvilla法国:休闲度假房屋出租
2020/10/03 全球购物
党的群众路线教育实践活动批评与自我批评发言稿
2014/10/16 职场文书
2014党的群众路线教育实践活动学习心得体会
2014/10/31 职场文书
小学生2015教师节演讲稿
2015/03/19 职场文书
国际贸易实训总结
2015/08/03 职场文书
mysql性能优化以及配置连接参数设置
2022/05/06 MySQL