Pytorch转tflite方式


Posted in Python onMay 25, 2020

目标是想把在服务器上用pytorch训练好的模型转换为可以在移动端运行的tflite模型。

最直接的思路是想把pytorch模型转换为tensorflow的模型,然后转换为tflite。但是这个转换目前没有发现比较靠谱的方法。

经过调研发现最新的tflite已经支持直接从keras模型的转换,所以可以采用keras作为中间转换的桥梁,这样就能充分利用keras高层API的便利性。

转换的基本思想就是用pytorch中的各层网络的权重取出来后直接赋值给keras网络中的对应layer层的权重。

转换为Keras模型后,再通过tf.contrib.lite.TocoConverter把模型直接转为tflite.

下面是一个例子,假设转换的是一个两层的CNN网络。

import tensorflow as tf
from tensorflow import keras
import numpy as np

import torch
from torchvision import models
import torch.nn as nn
# import torch.nn.functional as F
from torch.autograd import Variable

class PytorchNet(nn.Module):
 def __init__(self):
 super(PytorchNet, self).__init__()
 conv1 = nn.Sequential(
  nn.Conv2d(3, 32, 3, 2),
  nn.BatchNorm2d(32),
  nn.ReLU(inplace=True),
  nn.MaxPool2d(2, 2))
 conv2 = nn.Sequential(
  nn.Conv2d(32, 64, 3, 1, groups=1),
  nn.BatchNorm2d(64),
  nn.ReLU(inplace=True),
  nn.MaxPool2d(2, 2))
 self.feature = nn.Sequential(conv1, conv2)
 self.init_weights()

 def forward(self, x):
 return self.feature(x)

 def init_weights(self):
 for m in self.modules():
  if isinstance(m, nn.Conv2d):
  nn.init.kaiming_normal_(
   m.weight.data, mode='fan_out', nonlinearity='relu')
  if m.bias is not None:
   m.bias.data.zero_()
  if isinstance(m, nn.BatchNorm2d):
  m.weight.data.fill_(1)
  m.bias.data.zero_()

def KerasNet(input_shape=(224, 224, 3)):
 image_input = keras.layers.Input(shape=input_shape)
 # conv1
 network = keras.layers.Conv2D(
 32, (3, 3), strides=(2, 2), padding="valid")(image_input)
 network = keras.layers.BatchNormalization(
 trainable=False, fused=False)(network)
 network = keras.layers.Activation("relu")(network)
 network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)

 # conv2
 network = keras.layers.Conv2D(
 64, (3, 3), strides=(1, 1), padding="valid")(network)
 network = keras.layers.BatchNormalization(
 trainable=False, fused=True)(network)
 network = keras.layers.Activation("relu")(network)
 network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)

 model = keras.Model(inputs=image_input, outputs=network)

 return model

class PytorchToKeras(object):
 def __init__(self, pModel, kModel):
 super(PytorchToKeras, self)
 self.__source_layers = []
 self.__target_layers = []
 self.pModel = pModel
 self.kModel = kModel
 tf.keras.backend.set_learning_phase(0)

 def __retrieve_k_layers(self):
 for i, layer in enumerate(self.kModel.layers):
  if len(layer.weights) > 0:
  self.__target_layers.append(i)

 def __retrieve_p_layers(self, input_size):

 input = torch.randn(input_size)
 input = Variable(input.unsqueeze(0))
 hooks = []

 def add_hooks(module):

  def hook(module, input, output):
  if hasattr(module, "weight"):
   # print(module)
   self.__source_layers.append(module)

  if not isinstance(module, nn.ModuleList) and not isinstance(module, nn.Sequential) and module != self.pModel:
  hooks.append(module.register_forward_hook(hook))

 self.pModel.apply(add_hooks)

 self.pModel(input)
 for hook in hooks:
  hook.remove()

 def convert(self, input_size):
 self.__retrieve_k_layers()
 self.__retrieve_p_layers(input_size)

 for i, (source_layer, target_layer) in enumerate(zip(self.__source_layers, self.__target_layers)):
  print(source_layer)
  weight_size = len(source_layer.weight.data.size())
  transpose_dims = []
  for i in range(weight_size):
  transpose_dims.append(weight_size - i - 1)
  if isinstance(source_layer, nn.Conv2d):
  transpose_dims = [2,3,1,0]
  self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy(
  ).transpose(transpose_dims), source_layer.bias.data.numpy()])
  elif isinstance(source_layer, nn.BatchNorm2d):
  self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy(), source_layer.bias.data.numpy(),
        source_layer.running_mean.data.numpy(), source_layer.running_var.data.numpy()])
 def save_model(self, output_file):
 self.kModel.save(output_file)

 def save_weights(self, output_file):
 self.kModel.save_weights(output_file, save_format='h5')

pytorch_model = PytorchNet()
keras_model = KerasNet(input_shape=(224, 224, 3))

torch.save(pytorch_model, 'test.pth')

#Load the pretrained model
pytorch_model = torch.load('test.pth')

# #Time to transfer weights
converter = PytorchToKeras(pytorch_model, keras_model)
converter.convert((3, 224, 224))

# #Save the converted keras model for later use
# converter.save_weights("keras.h5")
converter.save_model("keras_model.h5")

# convert keras model to tflite model
converter = tf.contrib.lite.TocoConverter.from_keras_model_file(
 "keras_model.h5")
tflite_model = converter.convert()
open("convert_model.tflite", "wb").write(tflite_model)

补充知识:tensorflow模型转换成tensorflow lite模型

1.把graph和网络模型打包在一个文件中

bazel build tensorflow/python/tools:freeze_graph && \
 bazel-bin/tensorflow/python/tools/freeze_graph \
 --input_graph=eval_graph_def.pb \
 --input_checkpoint=checkpoint \
 --output_graph=frozen_eval_graph.pb \
 --output_node_names=outputs

For example:

bazel-bin/tensorflow/python/tools/freeze_graph \ 
 --input_graph=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_eval.pbtxt \
 --input_checkpoint=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \
 --output_graph=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \
 --output_node_names=MobilenetV1/Predictions/Reshape_1

2.把第一步中生成的tensorflow pb模型转换为tf lite模型

转换前需要先编译转换工具

bazel build tensorflow/contrib/lite/toco:toco

转换分两种,一种的转换为float的tf lite,另一种可以转换为对模型进行unit8的量化版本的模型。两种方式如下:

非量化的转换:

./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco \ 官网给的这个路径不对       
./bazel-bin/tensorflow/contrib/lite/toco/toco \         
 —input_file=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \  
 —output_file=./mobilenet_v1_1.0_224/tflite_model_test.tflite \  
 --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \       
 --inference_type=FLOAT \           
 --input_shape="1,224, 224,3" \           
 --input_array=input \            
 --output_array=MobilenetV1/Predictions/Reshape_1

量化方式的转换(注意,只有量化训练的模型才能进行量化的tf_lite转换):

./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco \
./bazel-bin/tensorflow/contrib/lite/toco/toco \
 --input_file=frozen_eval_graph.pb \
 --output_file=tflite_model.tflite \
 --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
 --inference_type=QUANTIZED_UINT8 \
 --input_shape="1,224, 224,3" \
 --input_array=input \
 --output_array=outputs \
 --std_value=127.5 --mean_value=127.5

以上这篇Pytorch转tflite方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python多线程实例教程
Sep 06 Python
Python中字符串格式化str.format的详细介绍
Feb 17 Python
Python实现简单石头剪刀布游戏
Jan 20 Python
Python中遍历列表的方法总结
Jun 27 Python
Python列表与元组的异同详解
Jul 02 Python
简单了解Django应用app及分布式路由
Jul 24 Python
Python进程间通信 multiProcessing Queue队列实现详解
Sep 23 Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 Python
Python版中国省市经纬度
Feb 11 Python
Python类和实例的属性机制原理详解
Mar 21 Python
Python类的继承super相关原理解析
Oct 22 Python
Python中的socket网络模块介绍
Jul 23 Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
Win10用vscode打开anaconda环境中的python出错问题的解决
May 25 #Python
keras .h5转移动端的.tflite文件实现方式
May 25 #Python
Python虚拟环境venv用法详解
May 25 #Python
将keras的h5模型转换为tensorflow的pb模型操作
May 25 #Python
You might like
用 PHP5 轻松解析 XML
2006/12/04 PHP
探讨:如何使用PHP实现计算两个日期间隔的年、月、周、日数
2013/06/13 PHP
php更新修改excel中的内容实例代码
2014/02/26 PHP
php页面函数设置超时限制的方法
2014/12/01 PHP
thinkPHP模型初始化实例分析
2015/12/03 PHP
yii2实现根据时间搜索的方法
2016/05/25 PHP
PHP那些琐碎的知识点(整理)
2017/05/20 PHP
一文掌握PHP Xdebug 本地与远程调试(小结)
2019/04/23 PHP
Laravel 登录后清空COOKIE的操作方法
2019/10/14 PHP
ExtJS 2.0 GridPanel基本表格简明教程
2010/05/25 Javascript
DD_belatedPNG,IE6下PNG透明解决方案(国外)
2010/12/06 Javascript
Extjs3.0 checkboxGroup 动态添加item实现思路
2013/08/14 Javascript
javascript控制在光标位置插入文字适合表情的插入
2014/06/09 Javascript
JavaScript学习笔记之内置对象
2015/01/22 Javascript
简介JavaScript中toTimeString()方法的使用
2015/06/12 Javascript
Bootstrap弹出带合法性检查的登录框实例代码【推荐】
2016/06/23 Javascript
js创建对象几种方式的优缺点对比
2016/09/28 Javascript
Vue.js使用v-show和v-if的注意事项
2016/12/13 Javascript
使用 Github Actions 自动部署 Angular 应用到 Github Pages的方法
2020/07/20 Javascript
Python可变参数函数用法实例
2015/07/07 Python
numpy找出array中的最大值,最小值实例
2018/04/03 Python
python修改list中所有元素类型的三种方法
2018/04/09 Python
python实现按长宽比缩放图片
2018/06/07 Python
pandas 选择某几列的方法
2018/07/03 Python
Python Tensor FLow简单使用方法实例详解
2020/01/14 Python
Python enumerate() 函数如何实现索引功能
2020/06/29 Python
html5 canvas实现跟随鼠标旋转的箭头
2016/03/11 HTML / CSS
全面解析HTML5中的标准属性与自定义属性
2016/02/18 HTML / CSS
Prototype中如何为一个元素添加一个方法
2014/12/08 面试题
口腔医学技术应届生求职信
2013/11/09 职场文书
护士检查书
2014/01/17 职场文书
贷款承诺书
2015/01/20 职场文书
如何撰写创业策划书
2019/06/27 职场文书
Go使用协程交替打印字符
2021/04/29 Golang
解决ObjectMapper.convertValue() 遇到的一些问题
2021/06/30 Java/Android
python脚本框架webpy的url映射详解
2021/11/20 Python