MxNet预训练模型到Pytorch模型的转换方式


Posted in Python onMay 25, 2020

预训练模型在不同深度学习框架中的转换是一种常见的任务。今天刚好DPN预训练模型转换问题,顺手将这个过程记录一下。

核心转换函数如下所示:

def convert_from_mxnet(model, checkpoint_prefix, debug=False):
 _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)
 remapped_state = {}
 for state_key in model.state_dict().keys():
  k = state_key.split('.')
  aux = False
  mxnet_key = ''
  if k[0] == 'features':
   if k[1] == 'conv1_1':
    # input block
    mxnet_key += 'conv1_x_1__'
    if k[2] == 'bn':
     mxnet_key += 'relu-sp__bn_'
     aux, key_add = _convert_bn(k[3])
     mxnet_key += key_add
    else:
     assert k[3] == 'weight'
     mxnet_key += 'conv_' + k[3]
   elif k[1] == 'conv5_bn_ac':
    # bn + ac at end of features block
    mxnet_key += 'conv5_x_x__relu-sp__bn_'
    assert k[2] == 'bn'
    aux, key_add = _convert_bn(k[3])
    mxnet_key += key_add
   else:
    # middle blocks
    if model.b and 'c1x1_c' in k[2]:
     bc_block = True # b-variant split c-block special treatment
    else:
     bc_block = False
    ck = k[1].split('_')
    mxnet_key += ck[0] + '_x__' + ck[1] + '_'
    ck = k[2].split('_')
    mxnet_key += ck[0] + '-' + ck[1]
    if ck[1] == 'w' and len(ck) > 2:
     mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'
    mxnet_key += '__'
    if k[3] == 'bn':
     mxnet_key += 'bn_' if bc_block else 'bn__bn_'
     aux, key_add = _convert_bn(k[4])
     mxnet_key += key_add
    else:
     ki = 3 if bc_block else 4
     assert k[ki] == 'weight'
     mxnet_key += 'conv_' + k[ki]
  elif k[0] == 'classifier':
   if 'fc6-1k_weight' in mxnet_weights:
    mxnet_key += 'fc6-1k_'
   else:
    mxnet_key += 'fc6_'
   mxnet_key += k[1]
  else:
   assert False, 'Unexpected token'
 
  if debug:
   print(mxnet_key, '=> ', state_key, end=' ')
 
  mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]
  torch_tensor = torch.from_numpy(mxnet_array.asnumpy())
  if k[0] == 'classifier' and k[1] == 'weight':
   torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))
  remapped_state[state_key] = torch_tensor
 
  if debug:
   print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())
 
 model.load_state_dict(remapped_state)
 
 return model

从中可以看出,其转换步骤如下:

(1)创建pytorch的网络结构模型,设为model

(2)利用mxnet来读取其存储的预训练模型,得到mxnet_weights;

(3)遍历加载后模型mxnet_weights的state_dict().keys

(4)对一些指定的key值,需要进行相应的处理和转换

(5)对修改键名之后的key利用numpy之间的转换来实现加载。

为了实现上述转换,首先pip安装mxnet,现在新版的mxnet安装还是非常方便的。

MxNet预训练模型到Pytorch模型的转换方式

第二步,运行转换程序,实现预训练模型的转换。

MxNet预训练模型到Pytorch模型的转换方式

可以看到在相当的文件夹下已经出现了转换后的模型。

以上这篇MxNet预训练模型到Pytorch模型的转换方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Pyinstaller将py打包成exe的实例
Mar 31 Python
Python多继承原理与用法示例
Aug 23 Python
Python学习笔记之变量、自定义函数用法示例
May 28 Python
python的pygal模块绘制反正切函数图像方法
Jul 16 Python
解决Django后台ManyToManyField显示成Object的问题
Aug 09 Python
python中几种自动微分库解析
Aug 29 Python
详解python中各种文件打开模式
Jan 19 Python
对Python中 \r, \n, \r\n的彻底理解
Mar 06 Python
关于Django Models CharField 参数说明
Mar 31 Python
Python requests接口测试实现代码
Sep 08 Python
Python中使用Lambda函数的5种用法
Apr 01 Python
解决IDEA翻译插件Translation报错更新TTK失败不能使用
Apr 24 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
Pytorch通过保存为ONNX模型转TensorRT5的实现
May 25 #Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
You might like
php获取远程文件大小
2015/10/20 PHP
WordPress开发中短代码的实现及相关函数使用技巧
2016/01/05 PHP
jquery 读取页面load get post ajax 四种方式代码写法
2011/04/02 Javascript
Extjs优化(一)删除冗余代码提高运行速度
2013/04/15 Javascript
JavaScript移除数组元素减少长度的方法
2013/09/05 Javascript
跟我学Nodejs(二)--- Node.js事件模块
2014/05/21 NodeJs
node.js使用require()函数加载模块
2014/11/26 Javascript
jqGrid表格应用之新增与删除数据附源码下载
2015/12/02 Javascript
ES6所改良的javascript“缺陷”问题
2016/08/23 Javascript
JS验证图片格式和大小并预览的简单实例
2016/10/11 Javascript
JS表单验证方法实例小结【电话、身份证号、Email、中文、特殊字符、身份证号等】
2017/02/14 Javascript
JS时间控制实现动态效果的实例讲解
2017/07/31 Javascript
如何重置vue打印变量的显示方式
2017/12/06 Javascript
详解如何在项目中使用jest测试react native组件
2018/02/09 Javascript
JS限制输入框输入的实现代码
2018/07/02 Javascript
webuploader实现上传图片到服务器功能
2018/08/16 Javascript
js中位运算的运用实例分析
2018/12/11 Javascript
JS中的算法与数据结构之常见排序(Sort)算法详解
2019/08/16 Javascript
Vue formData实现图片上传
2019/08/20 Javascript
nuxt.js 在middleware(中间件)中实现路由鉴权操作
2020/11/06 Javascript
超详细小程序定位地图模块全系列开发教学
2020/11/24 Javascript
Python连接SQLServer2000的方法详解
2017/04/19 Python
Python实现的排列组合计算操作示例
2017/10/13 Python
python中PS 图像调整算法原理之亮度调整
2019/06/28 Python
解决Jupyter无法导入已安装的 module问题
2020/04/17 Python
如何利用python web框架做文件流下载的实现示例
2020/06/02 Python
阿迪达斯西班牙官方网站:adidas西班牙
2016/07/21 全球购物
香蕉共和国Banana Republic官网:美国GAP旗下偏贵族风格服饰品牌
2016/11/21 全球购物
建筑安全标语
2014/06/07 职场文书
签订劳动合同通知书
2015/04/16 职场文书
有关水浒传的读书笔记
2015/06/25 职场文书
少先队中队工作总结2015
2015/07/23 职场文书
2015年小学师德师风建设工作总结
2015/10/23 职场文书
python基于tkinter制作m3u8视频下载工具
2021/04/24 Python
MySQL一些常用高级SQL语句
2021/07/03 MySQL
SpringBoot使用ip2region获取地理位置信息的方法
2022/06/21 Java/Android