MxNet预训练模型到Pytorch模型的转换方式


Posted in Python onMay 25, 2020

预训练模型在不同深度学习框架中的转换是一种常见的任务。今天刚好DPN预训练模型转换问题,顺手将这个过程记录一下。

核心转换函数如下所示:

def convert_from_mxnet(model, checkpoint_prefix, debug=False):
 _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)
 remapped_state = {}
 for state_key in model.state_dict().keys():
  k = state_key.split('.')
  aux = False
  mxnet_key = ''
  if k[0] == 'features':
   if k[1] == 'conv1_1':
    # input block
    mxnet_key += 'conv1_x_1__'
    if k[2] == 'bn':
     mxnet_key += 'relu-sp__bn_'
     aux, key_add = _convert_bn(k[3])
     mxnet_key += key_add
    else:
     assert k[3] == 'weight'
     mxnet_key += 'conv_' + k[3]
   elif k[1] == 'conv5_bn_ac':
    # bn + ac at end of features block
    mxnet_key += 'conv5_x_x__relu-sp__bn_'
    assert k[2] == 'bn'
    aux, key_add = _convert_bn(k[3])
    mxnet_key += key_add
   else:
    # middle blocks
    if model.b and 'c1x1_c' in k[2]:
     bc_block = True # b-variant split c-block special treatment
    else:
     bc_block = False
    ck = k[1].split('_')
    mxnet_key += ck[0] + '_x__' + ck[1] + '_'
    ck = k[2].split('_')
    mxnet_key += ck[0] + '-' + ck[1]
    if ck[1] == 'w' and len(ck) > 2:
     mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'
    mxnet_key += '__'
    if k[3] == 'bn':
     mxnet_key += 'bn_' if bc_block else 'bn__bn_'
     aux, key_add = _convert_bn(k[4])
     mxnet_key += key_add
    else:
     ki = 3 if bc_block else 4
     assert k[ki] == 'weight'
     mxnet_key += 'conv_' + k[ki]
  elif k[0] == 'classifier':
   if 'fc6-1k_weight' in mxnet_weights:
    mxnet_key += 'fc6-1k_'
   else:
    mxnet_key += 'fc6_'
   mxnet_key += k[1]
  else:
   assert False, 'Unexpected token'
 
  if debug:
   print(mxnet_key, '=> ', state_key, end=' ')
 
  mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]
  torch_tensor = torch.from_numpy(mxnet_array.asnumpy())
  if k[0] == 'classifier' and k[1] == 'weight':
   torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))
  remapped_state[state_key] = torch_tensor
 
  if debug:
   print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())
 
 model.load_state_dict(remapped_state)
 
 return model

从中可以看出,其转换步骤如下:

(1)创建pytorch的网络结构模型,设为model

(2)利用mxnet来读取其存储的预训练模型,得到mxnet_weights;

(3)遍历加载后模型mxnet_weights的state_dict().keys

(4)对一些指定的key值,需要进行相应的处理和转换

(5)对修改键名之后的key利用numpy之间的转换来实现加载。

为了实现上述转换,首先pip安装mxnet,现在新版的mxnet安装还是非常方便的。

MxNet预训练模型到Pytorch模型的转换方式

第二步,运行转换程序,实现预训练模型的转换。

MxNet预训练模型到Pytorch模型的转换方式

可以看到在相当的文件夹下已经出现了转换后的模型。

以上这篇MxNet预训练模型到Pytorch模型的转换方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用mysqldb连接数据库操作方法示例详解
Dec 03 Python
python脚本实现xls(xlsx)转成csv
Apr 10 Python
python利用urllib实现爬取京东网站商品图片的爬虫实例
Aug 24 Python
Python中的上下文管理器和with语句的使用
Apr 17 Python
python石头剪刀布小游戏(三局两胜制)
Jan 20 Python
python使用多线程编写tcp客户端程序
Sep 02 Python
Django REST framework 单元测试实例解析
Nov 07 Python
Python3 mmap内存映射文件示例解析
Mar 23 Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 Python
Python把图片转化为pdf代码实例
Jul 28 Python
基于Python的接口自动化unittest测试框架和ddt数据驱动详解
Jan 27 Python
Python实现科学占卜 让视频自动打码
Apr 09 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
Pytorch通过保存为ONNX模型转TensorRT5的实现
May 25 #Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
You might like
PHP过滤★等特殊符号的正则
2014/01/27 PHP
php实现表单多按钮提交action的处理方法
2015/10/24 PHP
PHP Imagick完美实现图片裁切、生成缩略图、添加水印
2016/02/22 PHP
laravel-admin 在列表页添加自定义按钮的例子
2019/09/30 PHP
如何在Mozilla Gecko 用Javascript加载XSL
2007/01/09 Javascript
疯掉了,尽然有js写的操作系统
2007/04/23 Javascript
JS中==与===操作符的比较
2009/03/21 Javascript
Date对象格式化函数代码
2010/07/17 Javascript
自己使用js/jquery写的一个定制对话框控件
2014/05/02 Javascript
JQuery 使用attr方法实现下拉列表选中
2014/10/13 Javascript
js实现同一页面多个不同运动效果的方法
2015/04/10 Javascript
jQuery插件imgPreviewQs实现上传图片预览
2016/01/15 Javascript
jquery点击展示与隐藏更多内容
2016/12/03 Javascript
从零开始学习Node.js系列教程四:多页面实现的数学运算示例
2017/04/13 Javascript
解决JS外部文件中文注释出现乱码问题
2017/07/09 Javascript
Vue如何从1.0迁移到2.0
2017/10/19 Javascript
vue的token刷新处理的方法
2018/07/17 Javascript
vue 二维码长按保存和复制内容操作
2020/09/22 Javascript
TypeScript 运行时类型检查补充工具
2020/09/28 Javascript
小程序实现列表倒计时功能
2021/01/29 Javascript
[01:14]英雄,所敬略同——2018完美盛典宣传视频
2018/12/05 DOTA
Python中的迭代器漫谈
2015/02/03 Python
利用Python和OpenCV库将URL转换为OpenCV格式的方法
2015/03/27 Python
浅析Python 中整型对象存储的位置
2016/05/16 Python
深入浅析Python传值与传址
2018/07/10 Python
python 获取毫秒数,计算调用时长的方法
2019/02/20 Python
python用requests实现http请求代码实例
2019/10/31 Python
dpn网络的pytorch实现方式
2020/01/14 Python
基于Python生成个性二维码过程详解
2020/03/05 Python
美国最大的万圣节服装网站:HalloweenCostumes.com
2017/10/12 全球购物
香港彩色隐形眼镜在线商店:Stunninglens(全球免费送货)
2019/05/10 全球购物
初中生个人学习的自我评价
2013/12/04 职场文书
市场部规章制度
2014/01/24 职场文书
村长反四风问题个人对照检查材料
2014/09/21 职场文书
交通处罚决定书
2015/06/24 职场文书
python 如何在list中找Topk的数值和索引
2021/05/20 Python