Pytorch转tflite方式


Posted in Python onMay 25, 2020

目标是想把在服务器上用pytorch训练好的模型转换为可以在移动端运行的tflite模型。

最直接的思路是想把pytorch模型转换为tensorflow的模型,然后转换为tflite。但是这个转换目前没有发现比较靠谱的方法。

经过调研发现最新的tflite已经支持直接从keras模型的转换,所以可以采用keras作为中间转换的桥梁,这样就能充分利用keras高层API的便利性。

转换的基本思想就是用pytorch中的各层网络的权重取出来后直接赋值给keras网络中的对应layer层的权重。

转换为Keras模型后,再通过tf.contrib.lite.TocoConverter把模型直接转为tflite.

下面是一个例子,假设转换的是一个两层的CNN网络。

import tensorflow as tf
from tensorflow import keras
import numpy as np

import torch
from torchvision import models
import torch.nn as nn
# import torch.nn.functional as F
from torch.autograd import Variable

class PytorchNet(nn.Module):
 def __init__(self):
 super(PytorchNet, self).__init__()
 conv1 = nn.Sequential(
  nn.Conv2d(3, 32, 3, 2),
  nn.BatchNorm2d(32),
  nn.ReLU(inplace=True),
  nn.MaxPool2d(2, 2))
 conv2 = nn.Sequential(
  nn.Conv2d(32, 64, 3, 1, groups=1),
  nn.BatchNorm2d(64),
  nn.ReLU(inplace=True),
  nn.MaxPool2d(2, 2))
 self.feature = nn.Sequential(conv1, conv2)
 self.init_weights()

 def forward(self, x):
 return self.feature(x)

 def init_weights(self):
 for m in self.modules():
  if isinstance(m, nn.Conv2d):
  nn.init.kaiming_normal_(
   m.weight.data, mode='fan_out', nonlinearity='relu')
  if m.bias is not None:
   m.bias.data.zero_()
  if isinstance(m, nn.BatchNorm2d):
  m.weight.data.fill_(1)
  m.bias.data.zero_()

def KerasNet(input_shape=(224, 224, 3)):
 image_input = keras.layers.Input(shape=input_shape)
 # conv1
 network = keras.layers.Conv2D(
 32, (3, 3), strides=(2, 2), padding="valid")(image_input)
 network = keras.layers.BatchNormalization(
 trainable=False, fused=False)(network)
 network = keras.layers.Activation("relu")(network)
 network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)

 # conv2
 network = keras.layers.Conv2D(
 64, (3, 3), strides=(1, 1), padding="valid")(network)
 network = keras.layers.BatchNormalization(
 trainable=False, fused=True)(network)
 network = keras.layers.Activation("relu")(network)
 network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)

 model = keras.Model(inputs=image_input, outputs=network)

 return model

class PytorchToKeras(object):
 def __init__(self, pModel, kModel):
 super(PytorchToKeras, self)
 self.__source_layers = []
 self.__target_layers = []
 self.pModel = pModel
 self.kModel = kModel
 tf.keras.backend.set_learning_phase(0)

 def __retrieve_k_layers(self):
 for i, layer in enumerate(self.kModel.layers):
  if len(layer.weights) > 0:
  self.__target_layers.append(i)

 def __retrieve_p_layers(self, input_size):

 input = torch.randn(input_size)
 input = Variable(input.unsqueeze(0))
 hooks = []

 def add_hooks(module):

  def hook(module, input, output):
  if hasattr(module, "weight"):
   # print(module)
   self.__source_layers.append(module)

  if not isinstance(module, nn.ModuleList) and not isinstance(module, nn.Sequential) and module != self.pModel:
  hooks.append(module.register_forward_hook(hook))

 self.pModel.apply(add_hooks)

 self.pModel(input)
 for hook in hooks:
  hook.remove()

 def convert(self, input_size):
 self.__retrieve_k_layers()
 self.__retrieve_p_layers(input_size)

 for i, (source_layer, target_layer) in enumerate(zip(self.__source_layers, self.__target_layers)):
  print(source_layer)
  weight_size = len(source_layer.weight.data.size())
  transpose_dims = []
  for i in range(weight_size):
  transpose_dims.append(weight_size - i - 1)
  if isinstance(source_layer, nn.Conv2d):
  transpose_dims = [2,3,1,0]
  self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy(
  ).transpose(transpose_dims), source_layer.bias.data.numpy()])
  elif isinstance(source_layer, nn.BatchNorm2d):
  self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy(), source_layer.bias.data.numpy(),
        source_layer.running_mean.data.numpy(), source_layer.running_var.data.numpy()])
 def save_model(self, output_file):
 self.kModel.save(output_file)

 def save_weights(self, output_file):
 self.kModel.save_weights(output_file, save_format='h5')

pytorch_model = PytorchNet()
keras_model = KerasNet(input_shape=(224, 224, 3))

torch.save(pytorch_model, 'test.pth')

#Load the pretrained model
pytorch_model = torch.load('test.pth')

# #Time to transfer weights
converter = PytorchToKeras(pytorch_model, keras_model)
converter.convert((3, 224, 224))

# #Save the converted keras model for later use
# converter.save_weights("keras.h5")
converter.save_model("keras_model.h5")

# convert keras model to tflite model
converter = tf.contrib.lite.TocoConverter.from_keras_model_file(
 "keras_model.h5")
tflite_model = converter.convert()
open("convert_model.tflite", "wb").write(tflite_model)

补充知识:tensorflow模型转换成tensorflow lite模型

1.把graph和网络模型打包在一个文件中

bazel build tensorflow/python/tools:freeze_graph && \
 bazel-bin/tensorflow/python/tools/freeze_graph \
 --input_graph=eval_graph_def.pb \
 --input_checkpoint=checkpoint \
 --output_graph=frozen_eval_graph.pb \
 --output_node_names=outputs

For example:

bazel-bin/tensorflow/python/tools/freeze_graph \ 
 --input_graph=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_eval.pbtxt \
 --input_checkpoint=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \
 --output_graph=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \
 --output_node_names=MobilenetV1/Predictions/Reshape_1

2.把第一步中生成的tensorflow pb模型转换为tf lite模型

转换前需要先编译转换工具

bazel build tensorflow/contrib/lite/toco:toco

转换分两种,一种的转换为float的tf lite,另一种可以转换为对模型进行unit8的量化版本的模型。两种方式如下:

非量化的转换:

./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco \ 官网给的这个路径不对       
./bazel-bin/tensorflow/contrib/lite/toco/toco \         
 —input_file=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \  
 —output_file=./mobilenet_v1_1.0_224/tflite_model_test.tflite \  
 --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \       
 --inference_type=FLOAT \           
 --input_shape="1,224, 224,3" \           
 --input_array=input \            
 --output_array=MobilenetV1/Predictions/Reshape_1

量化方式的转换(注意,只有量化训练的模型才能进行量化的tf_lite转换):

./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco \
./bazel-bin/tensorflow/contrib/lite/toco/toco \
 --input_file=frozen_eval_graph.pb \
 --output_file=tflite_model.tflite \
 --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
 --inference_type=QUANTIZED_UINT8 \
 --input_shape="1,224, 224,3" \
 --input_array=input \
 --output_array=outputs \
 --std_value=127.5 --mean_value=127.5

以上这篇Pytorch转tflite方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python的Django框架的视图中使用Session的方法
Jul 23 Python
python基础之入门必看操作
Jul 26 Python
python使用mysql的两种使用方式
Mar 07 Python
Python实现将通信达.day文件读取为DataFrame
Dec 22 Python
python隐藏终端执行cmd命令的方法
Jun 24 Python
Django中celery执行任务结果的保存方法
Jul 12 Python
Python实现最常见加密方式详解
Jul 13 Python
python自动循环定时开关机(非重启)测试
Aug 26 Python
linux 下selenium chrome使用详解
Apr 02 Python
用python打开摄像头并把图像传回qq邮箱(Pyinstaller打包)
May 17 Python
tensorflow实现从.ckpt文件中读取任意变量
May 26 Python
python实现视频压缩功能
Dec 18 Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
Win10用vscode打开anaconda环境中的python出错问题的解决
May 25 #Python
keras .h5转移动端的.tflite文件实现方式
May 25 #Python
Python虚拟环境venv用法详解
May 25 #Python
将keras的h5模型转换为tensorflow的pb模型操作
May 25 #Python
You might like
滚动图片效果 jquery实现回旋滚动效果
2013/01/08 Javascript
JavaScript通过Date-Mask将日期转换成字符串的方法
2015/06/04 Javascript
js如何实现淡入淡出效果
2020/11/18 Javascript
Javascript基于AJAX回调函数传递参数实例分析
2015/12/15 Javascript
JavaScript编程学习技巧汇总
2016/02/21 Javascript
浅谈JavaScript中的this指针和引用知识
2016/08/05 Javascript
js仿网易表单及时验证功能
2017/03/07 Javascript
利用Vue.js+Node.js+MongoDB实现一个博客系统(附源码)
2017/04/24 Javascript
利用chrome浏览器进行js调试并找出元素绑定的点击事件详解
2021/01/30 Javascript
Vue递归实现树形菜单方法实例
2018/11/06 Javascript
微信小程序实现展示评分结果功能
2019/02/15 Javascript
JavaScript中将值转换为字符串的五种方法总结
2019/06/06 Javascript
如何在JavaScript中谨慎使用代码注释
2019/06/21 Javascript
微信小程序解析富文本过程详解
2019/07/13 Javascript
js图数据结构处理 迪杰斯特拉算法代码实例
2019/09/11 Javascript
Vue中watch、computed、updated三者的区别及用法
2020/07/27 Javascript
vue 路由meta 设置导航隐藏与显示功能的示例代码
2020/09/04 Javascript
python使用str & repr转换字符串
2016/10/13 Python
Python 多线程的实例详解
2017/09/07 Python
python如何使用正则表达式的前向、后向搜索及前向搜索否定模式详解
2017/11/08 Python
django的登录注册系统的示例代码
2018/05/14 Python
Python类中的魔法方法之 __slots__原理解析
2019/08/26 Python
Python 解决相对路径问题:"No such file or directory"
2020/06/05 Python
Omio葡萄牙:全欧洲低价大巴、火车和航班搜索和比价
2019/02/09 全球购物
阿里巴巴Oracle DBA笔试题答案-备份恢复类
2013/11/20 面试题
如何利用find命令查找文件
2016/11/18 面试题
自动化职业生涯规划书范文
2014/01/03 职场文书
办公室主任职责范本
2014/03/07 职场文书
艺术节主持词
2014/04/02 职场文书
班级口号大全
2014/06/09 职场文书
英语自我介绍演讲稿
2014/09/01 职场文书
2014医学院领导干部四风对照检查材料思想汇报
2014/09/16 职场文书
python实现层次聚类的方法
2021/11/01 Python
Nginx利用Logrotate实现日志分割
2022/05/20 Servers
Windows Server 修改远程桌面端口的实现
2022/06/25 Servers
html原生table实现合并单元格以及合并表头的示例代码
2023/05/07 HTML / CSS