Pytorch转tflite方式


Posted in Python onMay 25, 2020

目标是想把在服务器上用pytorch训练好的模型转换为可以在移动端运行的tflite模型。

最直接的思路是想把pytorch模型转换为tensorflow的模型,然后转换为tflite。但是这个转换目前没有发现比较靠谱的方法。

经过调研发现最新的tflite已经支持直接从keras模型的转换,所以可以采用keras作为中间转换的桥梁,这样就能充分利用keras高层API的便利性。

转换的基本思想就是用pytorch中的各层网络的权重取出来后直接赋值给keras网络中的对应layer层的权重。

转换为Keras模型后,再通过tf.contrib.lite.TocoConverter把模型直接转为tflite.

下面是一个例子,假设转换的是一个两层的CNN网络。

import tensorflow as tf
from tensorflow import keras
import numpy as np

import torch
from torchvision import models
import torch.nn as nn
# import torch.nn.functional as F
from torch.autograd import Variable

class PytorchNet(nn.Module):
 def __init__(self):
 super(PytorchNet, self).__init__()
 conv1 = nn.Sequential(
  nn.Conv2d(3, 32, 3, 2),
  nn.BatchNorm2d(32),
  nn.ReLU(inplace=True),
  nn.MaxPool2d(2, 2))
 conv2 = nn.Sequential(
  nn.Conv2d(32, 64, 3, 1, groups=1),
  nn.BatchNorm2d(64),
  nn.ReLU(inplace=True),
  nn.MaxPool2d(2, 2))
 self.feature = nn.Sequential(conv1, conv2)
 self.init_weights()

 def forward(self, x):
 return self.feature(x)

 def init_weights(self):
 for m in self.modules():
  if isinstance(m, nn.Conv2d):
  nn.init.kaiming_normal_(
   m.weight.data, mode='fan_out', nonlinearity='relu')
  if m.bias is not None:
   m.bias.data.zero_()
  if isinstance(m, nn.BatchNorm2d):
  m.weight.data.fill_(1)
  m.bias.data.zero_()

def KerasNet(input_shape=(224, 224, 3)):
 image_input = keras.layers.Input(shape=input_shape)
 # conv1
 network = keras.layers.Conv2D(
 32, (3, 3), strides=(2, 2), padding="valid")(image_input)
 network = keras.layers.BatchNormalization(
 trainable=False, fused=False)(network)
 network = keras.layers.Activation("relu")(network)
 network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)

 # conv2
 network = keras.layers.Conv2D(
 64, (3, 3), strides=(1, 1), padding="valid")(network)
 network = keras.layers.BatchNormalization(
 trainable=False, fused=True)(network)
 network = keras.layers.Activation("relu")(network)
 network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)

 model = keras.Model(inputs=image_input, outputs=network)

 return model

class PytorchToKeras(object):
 def __init__(self, pModel, kModel):
 super(PytorchToKeras, self)
 self.__source_layers = []
 self.__target_layers = []
 self.pModel = pModel
 self.kModel = kModel
 tf.keras.backend.set_learning_phase(0)

 def __retrieve_k_layers(self):
 for i, layer in enumerate(self.kModel.layers):
  if len(layer.weights) > 0:
  self.__target_layers.append(i)

 def __retrieve_p_layers(self, input_size):

 input = torch.randn(input_size)
 input = Variable(input.unsqueeze(0))
 hooks = []

 def add_hooks(module):

  def hook(module, input, output):
  if hasattr(module, "weight"):
   # print(module)
   self.__source_layers.append(module)

  if not isinstance(module, nn.ModuleList) and not isinstance(module, nn.Sequential) and module != self.pModel:
  hooks.append(module.register_forward_hook(hook))

 self.pModel.apply(add_hooks)

 self.pModel(input)
 for hook in hooks:
  hook.remove()

 def convert(self, input_size):
 self.__retrieve_k_layers()
 self.__retrieve_p_layers(input_size)

 for i, (source_layer, target_layer) in enumerate(zip(self.__source_layers, self.__target_layers)):
  print(source_layer)
  weight_size = len(source_layer.weight.data.size())
  transpose_dims = []
  for i in range(weight_size):
  transpose_dims.append(weight_size - i - 1)
  if isinstance(source_layer, nn.Conv2d):
  transpose_dims = [2,3,1,0]
  self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy(
  ).transpose(transpose_dims), source_layer.bias.data.numpy()])
  elif isinstance(source_layer, nn.BatchNorm2d):
  self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy(), source_layer.bias.data.numpy(),
        source_layer.running_mean.data.numpy(), source_layer.running_var.data.numpy()])
 def save_model(self, output_file):
 self.kModel.save(output_file)

 def save_weights(self, output_file):
 self.kModel.save_weights(output_file, save_format='h5')

pytorch_model = PytorchNet()
keras_model = KerasNet(input_shape=(224, 224, 3))

torch.save(pytorch_model, 'test.pth')

#Load the pretrained model
pytorch_model = torch.load('test.pth')

# #Time to transfer weights
converter = PytorchToKeras(pytorch_model, keras_model)
converter.convert((3, 224, 224))

# #Save the converted keras model for later use
# converter.save_weights("keras.h5")
converter.save_model("keras_model.h5")

# convert keras model to tflite model
converter = tf.contrib.lite.TocoConverter.from_keras_model_file(
 "keras_model.h5")
tflite_model = converter.convert()
open("convert_model.tflite", "wb").write(tflite_model)

补充知识:tensorflow模型转换成tensorflow lite模型

1.把graph和网络模型打包在一个文件中

bazel build tensorflow/python/tools:freeze_graph && \
 bazel-bin/tensorflow/python/tools/freeze_graph \
 --input_graph=eval_graph_def.pb \
 --input_checkpoint=checkpoint \
 --output_graph=frozen_eval_graph.pb \
 --output_node_names=outputs

For example:

bazel-bin/tensorflow/python/tools/freeze_graph \ 
 --input_graph=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_eval.pbtxt \
 --input_checkpoint=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \
 --output_graph=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \
 --output_node_names=MobilenetV1/Predictions/Reshape_1

2.把第一步中生成的tensorflow pb模型转换为tf lite模型

转换前需要先编译转换工具

bazel build tensorflow/contrib/lite/toco:toco

转换分两种,一种的转换为float的tf lite,另一种可以转换为对模型进行unit8的量化版本的模型。两种方式如下:

非量化的转换:

./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco \ 官网给的这个路径不对       
./bazel-bin/tensorflow/contrib/lite/toco/toco \         
 —input_file=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \  
 —output_file=./mobilenet_v1_1.0_224/tflite_model_test.tflite \  
 --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \       
 --inference_type=FLOAT \           
 --input_shape="1,224, 224,3" \           
 --input_array=input \            
 --output_array=MobilenetV1/Predictions/Reshape_1

量化方式的转换(注意,只有量化训练的模型才能进行量化的tf_lite转换):

./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco \
./bazel-bin/tensorflow/contrib/lite/toco/toco \
 --input_file=frozen_eval_graph.pb \
 --output_file=tflite_model.tflite \
 --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
 --inference_type=QUANTIZED_UINT8 \
 --input_shape="1,224, 224,3" \
 --input_array=input \
 --output_array=outputs \
 --std_value=127.5 --mean_value=127.5

以上这篇Pytorch转tflite方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用BeautifulSoup库解析HTML基本使用教程
Mar 31 Python
详谈python read readline readlines的区别
Sep 22 Python
pandas.DataFrame 根据条件新建列并赋值的方法
Apr 08 Python
从django的中间件直接返回请求的方法
May 30 Python
python 实现将字典dict、列表list中的中文正常显示方法
Jul 06 Python
Python开发之Nginx+uWSGI+virtualenv多项目部署教程
May 13 Python
Python 如何优雅的将数字转化为时间格式的方法
Sep 26 Python
python可视化实现KNN算法
Oct 16 Python
Python字符串hashlib加密模块使用案例
Mar 10 Python
python Matplotlib数据可视化(1):简单入门
Sep 30 Python
pytorch 计算Parameter和FLOP的操作
Mar 04 Python
python创建字典及相关管理操作
Apr 13 Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
Win10用vscode打开anaconda环境中的python出错问题的解决
May 25 #Python
keras .h5转移动端的.tflite文件实现方式
May 25 #Python
Python虚拟环境venv用法详解
May 25 #Python
将keras的h5模型转换为tensorflow的pb模型操作
May 25 #Python
You might like
PHP获取时间排除周六、周日的两个方法
2014/06/30 PHP
PHP按指定键值对二维数组进行排序的方法
2015/12/22 PHP
php使用QueryList轻松采集js动态渲染页面方法
2018/09/11 PHP
基于Web标准的UI组件 — 树状菜单(2)
2006/09/18 Javascript
window.opener用法和用途实例介绍
2013/08/19 Javascript
js 限制input只能输入数字、字母和汉字等等
2013/12/18 Javascript
jQuery(js)获取文字宽度(显示长度)示例代码
2013/12/31 Javascript
使用js实现一个可编辑的select下拉列表
2014/02/20 Javascript
JavaScript实现文字跟随鼠标特效
2015/08/06 Javascript
Angular发布1.5正式版,专注于向Angular 2的过渡
2016/02/18 Javascript
Highcharts+NodeJS搭建数据可视化平台示例
2017/01/01 NodeJs
基于vue实现分页效果
2017/11/06 Javascript
详解vue.js之props传递参数
2017/12/12 Javascript
js实现把时间戳转换为yyyy-MM-dd hh:mm 格式(es6语法)
2017/12/28 Javascript
解决vue-router中的query动态传参问题
2018/03/20 Javascript
vue3.0 CLI - 3.2 路由的初级使用教程
2018/09/20 Javascript
小程序采集录音并上传到后台
2019/11/22 Javascript
javascript自定义右键菜单插件
2019/12/16 Javascript
vue vantUI tab切换时 list组件不触发load事件的问题及解决方法
2020/02/14 Javascript
一篇超完整的Vue新手入门指导教程
2020/11/18 Vue.js
rhythmbox中文名乱码问题解决方法
2008/09/06 Python
Python中的rjust()方法使用详解
2015/05/19 Python
一个Python最简单的接口自动化框架
2018/01/02 Python
Python 通过截图匹配原图中的位置(opencv)实例
2019/08/27 Python
python实现多线程端口扫描
2019/08/31 Python
python web框架Flask实现图形验证码及验证码的动态刷新实例
2019/10/14 Python
python 视频逐帧保存为图片的完整实例
2019/12/10 Python
Python实现自动访问网页的例子
2020/02/21 Python
Python selenium 加载并保存QQ群成员,去除其群主、管理员信息的示例代码
2020/05/28 Python
移动Web—CSS为Retina屏幕替换更高质量的图片
2012/12/24 HTML / CSS
css3 实现圆形旋转倒计时
2018/02/24 HTML / CSS
微软中国官方商城:Microsoft Store中国
2018/10/12 全球购物
网络工程师自荐书范文
2014/04/01 职场文书
写给医生的感谢信
2015/01/22 职场文书
2015年春训学习心得体会范文
2015/03/09 职场文书
总经理致辞
2015/07/29 职场文书