keras实现theano和tensorflow训练的模型相互转换


Posted in Python onJune 19, 2020

我就废话不多说了,大家还是直接看代码吧~

</pre><pre code_snippet_id="1947416" snippet_file_name="blog_20161025_1_3331239" name="code" class="python">

# coding:utf-8
"""
If you want to load pre-trained weights that include convolutions (layers Convolution2D or Convolution1D),
be mindful of this: Theano and TensorFlow implement convolution in different ways (TensorFlow actually implements correlation, much like Caffe),
and thus, convolution kernels trained with Theano (resp. TensorFlow) need to be converted before being with TensorFlow (resp. Theano).
"""
from keras import backend as K
from keras.utils.np_utils import convert_kernel
from text_classifier import keras_text_classifier
import sys
 
def th2tf( model):
  import tensorflow as tf
  ops = []
  for layer in model.layers:
    if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
      original_w = K.get_value(layer.W)
      converted_w = convert_kernel(original_w)
      ops.append(tf.assign(layer.W, converted_w).op)
  K.get_session().run(ops)
  return model
 
def tf2th(model):
  for layer in model.layers:
    if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
      original_w = K.get_value(layer.W)
      converted_w = convert_kernel(original_w)
      K.set_value(layer.W, converted_w)
  return model
 
def conv_layer_converted(tf_weights, th_weights, m = 0):
  """
  :param tf_weights:
  :param th_weights:
  :param m: 0-tf2th, 1-th2tf
  :return:
  """
  if m == 0: # tf2th
    tc = keras_text_classifier(weights_path=tf_weights)
    model = tc.loadmodel()
    model = tf2th(model)
    model.save_weights(th_weights)
  elif m == 1: # th2tf
    tc = keras_text_classifier(weights_path=th_weights)
    model = tc.loadmodel()
    model = th2tf(model)
    model.save_weights(tf_weights)
  else:
    print("0-tf2th, 1-th2tf")
    return
if __name__ == '__main__':
  if len(sys.argv) < 4:
    print("python tf_weights th_weights <0|1>\n0-tensorflow to theano\n1-theano to tensorflow")
    sys.exit(0)
  tf_weights = sys.argv[1]
  th_weights = sys.argv[2]
  m = int(sys.argv[3])
  conv_layer_converted(tf_weights, th_weights, m)

补充知识:keras学习之修改底层为TensorFlow还是theano

我们知道,keras的底层是TensorFlow或者theano

要知道我们是用的哪个为底层,只需要import keras即可显示

修改方法:

打开

keras实现theano和tensorflow训练的模型相互转换

修改

keras实现theano和tensorflow训练的模型相互转换

以上这篇keras实现theano和tensorflow训练的模型相互转换就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础教程之基本数据类型和变量声明介绍
Aug 29 Python
Python深入学习之装饰器
Aug 31 Python
python获取mp3文件信息的方法
Jun 15 Python
python自动翻译实现方法
May 28 Python
Django查询数据库的性能优化示例代码
Sep 24 Python
Python pandas DataFrame操作的实现代码
Jun 21 Python
Python使用matplotlib实现交换式图形显示功能示例
Sep 06 Python
Linux下升级安装python3.8并配置pip及yum的教程
Jan 02 Python
matplotlib jupyter notebook 图像可视化 plt show操作
Apr 24 Python
python矩阵运算,转置,逆运算,共轭矩阵实例
May 11 Python
Python如何把字典写入到CSV文件的方法示例
Aug 23 Python
十个Python自动化常用操作,即拿即用
May 10 Python
Keras 切换后端方式(Theano和TensorFlow)
Jun 19 #Python
python中怎么表示空值
Jun 19 #Python
Python调用OpenCV实现图像平滑代码实例
Jun 19 #Python
使用OpenCV对车道进行实时检测的实现示例代码
Jun 19 #Python
为什么python比较流行
Jun 19 #Python
查看keras的默认backend实现方式
Jun 19 #Python
Python图像阈值化处理及算法比对实例解析
Jun 19 #Python
You might like
PHP内置加密函数详解
2016/11/20 PHP
PHPCrawl爬虫库实现抓取酷狗歌单的方法示例
2017/12/21 PHP
PHP设计模式之迭代器模式Iterator实例分析【对象行为型】
2020/04/26 PHP
javascript 函数调用的对象和方法
2010/07/01 Javascript
计算新浪Weibo消息长度(还可以输入119字)
2013/07/02 Javascript
用jquery统计子菜单的条数示例代码
2013/10/18 Javascript
JS实现可调整倒计时间代码分享
2015/08/18 Javascript
推荐阅读的js快速判断IE浏览器(兼容IE10与IE11)
2015/12/13 Javascript
jquery UI Datepicker时间控件冲突问题解决
2016/12/16 Javascript
node.js实现复制文本到剪切板的功能
2017/01/23 Javascript
NodeJs的fs读写删除移动监听
2017/04/28 NodeJs
JS中获取 DOM 元素的绝对位置实例详解
2018/04/23 Javascript
angularjs结合html5实现拖拽功能
2018/06/25 Javascript
JavaScript中的一些实用小技巧总结
2019/04/07 Javascript
解决antd 表单设置默认值initialValue后验证失效的问题
2020/11/02 Javascript
VUE实现吸底按钮
2021/03/04 Vue.js
Python 返回汉字的汉语拼音
2009/02/27 Python
tornado捕获和处理404错误的方法
2014/02/26 Python
python实现进程间通信简单实例
2014/07/23 Python
使用PM2+nginx部署python项目的方法示例
2018/11/07 Python
解析Python的缩进规则的使用
2019/01/16 Python
说说如何遍历Python列表的方法示例
2019/02/11 Python
tensorflow如何批量读取图片
2019/08/29 Python
Django app配置多个数据库代码实例
2019/12/17 Python
Python文件操作函数用法实例详解
2019/12/24 Python
python 如何实现遗传算法
2020/09/22 Python
虚拟环境及venv和virtualenv的区别说明
2021/02/05 Python
俄罗斯首家面向中国消费者的一站式购物网站:Wruru
2020/05/08 全球购物
当一个对象被当作参数传递到一个方法后,此方法可改变这个对象的属性,并可返回变化后的结果,那么这里到底是值传递还是引用传递?
2014/09/09 面试题
优秀应届生推荐信
2013/11/09 职场文书
英语复习计划
2015/01/19 职场文书
公司放假通知范文
2015/04/14 职场文书
朝花夕拾读书笔记
2015/06/29 职场文书
2016年万圣节家长开放日活动总结
2016/04/05 职场文书
Python基础 括号()[]{}的详解
2021/11/07 Python
Flutter Navigator 实现路由传递参数
2022/04/22 Java/Android