keras实现theano和tensorflow训练的模型相互转换


Posted in Python onJune 19, 2020

我就废话不多说了,大家还是直接看代码吧~

</pre><pre code_snippet_id="1947416" snippet_file_name="blog_20161025_1_3331239" name="code" class="python">

# coding:utf-8
"""
If you want to load pre-trained weights that include convolutions (layers Convolution2D or Convolution1D),
be mindful of this: Theano and TensorFlow implement convolution in different ways (TensorFlow actually implements correlation, much like Caffe),
and thus, convolution kernels trained with Theano (resp. TensorFlow) need to be converted before being with TensorFlow (resp. Theano).
"""
from keras import backend as K
from keras.utils.np_utils import convert_kernel
from text_classifier import keras_text_classifier
import sys
 
def th2tf( model):
  import tensorflow as tf
  ops = []
  for layer in model.layers:
    if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
      original_w = K.get_value(layer.W)
      converted_w = convert_kernel(original_w)
      ops.append(tf.assign(layer.W, converted_w).op)
  K.get_session().run(ops)
  return model
 
def tf2th(model):
  for layer in model.layers:
    if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
      original_w = K.get_value(layer.W)
      converted_w = convert_kernel(original_w)
      K.set_value(layer.W, converted_w)
  return model
 
def conv_layer_converted(tf_weights, th_weights, m = 0):
  """
  :param tf_weights:
  :param th_weights:
  :param m: 0-tf2th, 1-th2tf
  :return:
  """
  if m == 0: # tf2th
    tc = keras_text_classifier(weights_path=tf_weights)
    model = tc.loadmodel()
    model = tf2th(model)
    model.save_weights(th_weights)
  elif m == 1: # th2tf
    tc = keras_text_classifier(weights_path=th_weights)
    model = tc.loadmodel()
    model = th2tf(model)
    model.save_weights(tf_weights)
  else:
    print("0-tf2th, 1-th2tf")
    return
if __name__ == '__main__':
  if len(sys.argv) < 4:
    print("python tf_weights th_weights <0|1>\n0-tensorflow to theano\n1-theano to tensorflow")
    sys.exit(0)
  tf_weights = sys.argv[1]
  th_weights = sys.argv[2]
  m = int(sys.argv[3])
  conv_layer_converted(tf_weights, th_weights, m)

补充知识:keras学习之修改底层为TensorFlow还是theano

我们知道,keras的底层是TensorFlow或者theano

要知道我们是用的哪个为底层,只需要import keras即可显示

修改方法:

打开

keras实现theano和tensorflow训练的模型相互转换

修改

keras实现theano和tensorflow训练的模型相互转换

以上这篇keras实现theano和tensorflow训练的模型相互转换就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python时间模块datetime、time、calendar的使用方法
Jan 13 Python
Python实现选择排序
Jun 04 Python
使用Python写一个贪吃蛇游戏实例代码
Aug 21 Python
windows 下python+numpy安装实用教程
Dec 23 Python
对numpy中轴与维度的理解
Apr 18 Python
Python3.5面向对象程序设计之类的继承和多态详解
Apr 24 Python
windows上安装python3教程以及环境变量配置详解
Jul 18 Python
python数据持久存储 pickle模块的基本使用方法解析
Aug 30 Python
使用Python和OpenCV检测图像中的物体并将物体裁剪下来
Oct 30 Python
Python 剪绳子的多种思路实现(动态规划和贪心)
Feb 24 Python
在python中使用pymysql往mysql数据库中插入(insert)数据实例
Mar 02 Python
Keras 切换后端方式(Theano和TensorFlow)
Jun 19 #Python
python中怎么表示空值
Jun 19 #Python
Python调用OpenCV实现图像平滑代码实例
Jun 19 #Python
使用OpenCV对车道进行实时检测的实现示例代码
Jun 19 #Python
为什么python比较流行
Jun 19 #Python
查看keras的默认backend实现方式
Jun 19 #Python
Python图像阈值化处理及算法比对实例解析
Jun 19 #Python
You might like
php 操作excel文件的方法小结
2009/12/31 PHP
PHP5与MySQL数据库操作常用代码 收集
2010/03/21 PHP
PHP隐形一句话后门,和ThinkPHP框架加密码程序(base64_decode)
2011/11/02 PHP
php删除数组元素示例分享
2014/02/17 PHP
php源码分析之DZX1.5加密解密函数authcode用法
2015/06/17 PHP
gridpanel动态加载数据的实例代码
2013/07/18 Javascript
用js判断是否为360浏览器的实现代码
2015/01/15 Javascript
JavaScript制作简易的微信打飞机
2015/03/31 Javascript
JS+CSS实现精美的二级导航效果代码
2015/09/17 Javascript
浅析JS异步加载进度条
2016/05/05 Javascript
基于MVC+EasyUI的web开发框架之使用云打印控件C-Lodop打印页面或套打报关运单信息
2016/08/29 Javascript
JavaScript中利用for循环遍历数组
2017/01/15 Javascript
Node.js连接mongodb实例代码
2017/06/06 Javascript
在Vue组件上动态添加和删除属性方法
2018/02/23 Javascript
vue 做移动端微信公众号采坑经验记录
2018/04/26 Javascript
PostgreSQL Node.js实现函数计算方法示例
2019/02/12 Javascript
vue多层嵌套路由实例分析
2019/03/19 Javascript
javascript异步处理与Jquery deferred对象用法总结
2019/06/04 jQuery
Vue 实现复制功能,不需要任何结构内容直接复制方式
2019/11/09 Javascript
Python操作MySQL简单实现方法
2015/01/26 Python
Python实现在某个数组中查找一个值的算法示例
2018/06/27 Python
win10系统下Anaconda3安装配置方法图文教程
2018/09/19 Python
numpy concatenate数组拼接方法示例介绍
2019/05/27 Python
python中sort和sorted排序的实例方法
2019/08/26 Python
GitHub上值得推荐的8个python 项目
2020/10/30 Python
美国最受欢迎的度假目的地优惠套餐:BookVIP
2018/09/27 全球购物
最便宜促销价格订机票:Airpaz(总部设在印尼,支持中文)
2018/11/13 全球购物
如何用Java实现列出某个目录下的所有子目录
2015/07/20 面试题
新闻学专业应届生求职信
2013/11/08 职场文书
2014年售后服务工作总结
2014/11/18 职场文书
孕妇离婚协议书范本
2014/11/20 职场文书
2014个人年度工作总结
2014/12/15 职场文书
大学生自荐信范文
2015/03/05 职场文书
开天辟地观后感
2015/06/09 职场文书
教师年度考核自我评鉴
2015/08/11 职场文书
关于MySQL中的 like操作符详情
2021/11/17 MySQL