MxNet预训练模型到Pytorch模型的转换方式


Posted in Python onMay 25, 2020

预训练模型在不同深度学习框架中的转换是一种常见的任务。今天刚好DPN预训练模型转换问题,顺手将这个过程记录一下。

核心转换函数如下所示:

def convert_from_mxnet(model, checkpoint_prefix, debug=False):
 _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)
 remapped_state = {}
 for state_key in model.state_dict().keys():
  k = state_key.split('.')
  aux = False
  mxnet_key = ''
  if k[0] == 'features':
   if k[1] == 'conv1_1':
    # input block
    mxnet_key += 'conv1_x_1__'
    if k[2] == 'bn':
     mxnet_key += 'relu-sp__bn_'
     aux, key_add = _convert_bn(k[3])
     mxnet_key += key_add
    else:
     assert k[3] == 'weight'
     mxnet_key += 'conv_' + k[3]
   elif k[1] == 'conv5_bn_ac':
    # bn + ac at end of features block
    mxnet_key += 'conv5_x_x__relu-sp__bn_'
    assert k[2] == 'bn'
    aux, key_add = _convert_bn(k[3])
    mxnet_key += key_add
   else:
    # middle blocks
    if model.b and 'c1x1_c' in k[2]:
     bc_block = True # b-variant split c-block special treatment
    else:
     bc_block = False
    ck = k[1].split('_')
    mxnet_key += ck[0] + '_x__' + ck[1] + '_'
    ck = k[2].split('_')
    mxnet_key += ck[0] + '-' + ck[1]
    if ck[1] == 'w' and len(ck) > 2:
     mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'
    mxnet_key += '__'
    if k[3] == 'bn':
     mxnet_key += 'bn_' if bc_block else 'bn__bn_'
     aux, key_add = _convert_bn(k[4])
     mxnet_key += key_add
    else:
     ki = 3 if bc_block else 4
     assert k[ki] == 'weight'
     mxnet_key += 'conv_' + k[ki]
  elif k[0] == 'classifier':
   if 'fc6-1k_weight' in mxnet_weights:
    mxnet_key += 'fc6-1k_'
   else:
    mxnet_key += 'fc6_'
   mxnet_key += k[1]
  else:
   assert False, 'Unexpected token'
 
  if debug:
   print(mxnet_key, '=> ', state_key, end=' ')
 
  mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]
  torch_tensor = torch.from_numpy(mxnet_array.asnumpy())
  if k[0] == 'classifier' and k[1] == 'weight':
   torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))
  remapped_state[state_key] = torch_tensor
 
  if debug:
   print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())
 
 model.load_state_dict(remapped_state)
 
 return model

从中可以看出,其转换步骤如下:

(1)创建pytorch的网络结构模型,设为model

(2)利用mxnet来读取其存储的预训练模型,得到mxnet_weights;

(3)遍历加载后模型mxnet_weights的state_dict().keys

(4)对一些指定的key值,需要进行相应的处理和转换

(5)对修改键名之后的key利用numpy之间的转换来实现加载。

为了实现上述转换,首先pip安装mxnet,现在新版的mxnet安装还是非常方便的。

MxNet预训练模型到Pytorch模型的转换方式

第二步,运行转换程序,实现预训练模型的转换。

MxNet预训练模型到Pytorch模型的转换方式

可以看到在相当的文件夹下已经出现了转换后的模型。

以上这篇MxNet预训练模型到Pytorch模型的转换方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现汉诺塔递归算法经典案例
Mar 01 Python
python实现字典(dict)和字符串(string)的相互转换方法
Mar 01 Python
Django的信号机制详解
May 05 Python
Python两个字典键同值相加的几种方法
Mar 05 Python
python 为什么说eval要慎用
Mar 26 Python
对python中的os.getpid()和os.fork()函数详解
Aug 08 Python
Django自定义模板过滤器和标签的实现方法
Aug 21 Python
python 用户交互输入input的4种用法详解
Sep 24 Python
Python通过VGG16模型实现图像风格转换操作详解
Jan 16 Python
Python中的sys.stdout.write实现打印刷新功能
Feb 21 Python
PyQt5-QDateEdit的简单使用操作
Jul 12 Python
Python pyecharts案例超市4年数据可视化分析
Aug 14 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
Pytorch通过保存为ONNX模型转TensorRT5的实现
May 25 #Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
You might like
php封装json通信接口详解及实例
2017/03/07 PHP
在JavaScript中遭遇级联表达式陷阱
2007/03/08 Javascript
加载 Javascript 最佳实践
2011/10/30 Javascript
JS应用正则表达式转换大小写示例
2014/09/18 Javascript
JavaScript 里的类数组对象
2015/04/08 Javascript
jQuery实现简单下拉导航效果
2015/09/07 Javascript
jQuery的ajax和遍历数组json实例代码
2016/08/01 Javascript
微信小程序 欢迎界面开发的实例详解
2016/11/30 Javascript
浅析jQuery操作select控件的取值和设值
2016/12/07 Javascript
Vue2.0 从零开始_环境搭建操作步骤
2017/06/14 Javascript
详解webpack+angular2开发环境搭建
2017/06/28 Javascript
vue实现导航栏效果(选中状态刷新不消失)
2017/12/13 Javascript
Nodejs异步回调之异常处理实例分析
2018/06/22 NodeJs
Vue press 支持图片放大功能的实例代码
2018/11/09 Javascript
Vue 递归多级菜单的实例代码
2019/05/05 Javascript
使用layui前端框架弹出form表单以及提交的示例
2019/10/25 Javascript
jquery实现拖拽添加元素功能
2020/12/01 jQuery
Python爬虫常用库的安装及其环境配置
2018/09/19 Python
Python装饰器限制函数运行时间超时则退出执行
2019/04/09 Python
Python变量访问权限控制详解
2019/06/29 Python
在python 中split()使用多符号分割的例子
2019/07/15 Python
softmax及python实现过程解析
2019/09/30 Python
Python上下文管理器全实例详解
2019/11/12 Python
Django ValuesQuerySet转json方式
2020/03/16 Python
Django DRF路由与扩展功能的实现
2020/06/03 Python
缓解脚、腿和背部疼痛:Z-CoiL鞋
2019/03/12 全球购物
印度排名第一的蛋糕、鲜花和礼品送货:Winni
2019/08/02 全球购物
类的返射机制中的包及核心类
2016/09/12 面试题
Structs界面控制层技术
2013/10/11 面试题
老师给学生的表扬信
2014/01/17 职场文书
难忘的一课教学反思
2014/04/30 职场文书
个人工作主要事迹
2014/05/08 职场文书
扶贫办主任查摆“四风”问题个人对照检查材料思想汇报
2014/10/02 职场文书
教师查摆问题及整改措施
2014/10/11 职场文书
离婚案件答辩状
2015/05/22 职场文书
Golang ort 中的sortInts 方法
2022/04/24 Golang