keras实现theano和tensorflow训练的模型相互转换


Posted in Python onJune 19, 2020

我就废话不多说了,大家还是直接看代码吧~

</pre><pre code_snippet_id="1947416" snippet_file_name="blog_20161025_1_3331239" name="code" class="python">

# coding:utf-8
"""
If you want to load pre-trained weights that include convolutions (layers Convolution2D or Convolution1D),
be mindful of this: Theano and TensorFlow implement convolution in different ways (TensorFlow actually implements correlation, much like Caffe),
and thus, convolution kernels trained with Theano (resp. TensorFlow) need to be converted before being with TensorFlow (resp. Theano).
"""
from keras import backend as K
from keras.utils.np_utils import convert_kernel
from text_classifier import keras_text_classifier
import sys
 
def th2tf( model):
  import tensorflow as tf
  ops = []
  for layer in model.layers:
    if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
      original_w = K.get_value(layer.W)
      converted_w = convert_kernel(original_w)
      ops.append(tf.assign(layer.W, converted_w).op)
  K.get_session().run(ops)
  return model
 
def tf2th(model):
  for layer in model.layers:
    if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
      original_w = K.get_value(layer.W)
      converted_w = convert_kernel(original_w)
      K.set_value(layer.W, converted_w)
  return model
 
def conv_layer_converted(tf_weights, th_weights, m = 0):
  """
  :param tf_weights:
  :param th_weights:
  :param m: 0-tf2th, 1-th2tf
  :return:
  """
  if m == 0: # tf2th
    tc = keras_text_classifier(weights_path=tf_weights)
    model = tc.loadmodel()
    model = tf2th(model)
    model.save_weights(th_weights)
  elif m == 1: # th2tf
    tc = keras_text_classifier(weights_path=th_weights)
    model = tc.loadmodel()
    model = th2tf(model)
    model.save_weights(tf_weights)
  else:
    print("0-tf2th, 1-th2tf")
    return
if __name__ == '__main__':
  if len(sys.argv) < 4:
    print("python tf_weights th_weights <0|1>\n0-tensorflow to theano\n1-theano to tensorflow")
    sys.exit(0)
  tf_weights = sys.argv[1]
  th_weights = sys.argv[2]
  m = int(sys.argv[3])
  conv_layer_converted(tf_weights, th_weights, m)

补充知识:keras学习之修改底层为TensorFlow还是theano

我们知道,keras的底层是TensorFlow或者theano

要知道我们是用的哪个为底层,只需要import keras即可显示

修改方法:

打开

keras实现theano和tensorflow训练的模型相互转换

修改

keras实现theano和tensorflow训练的模型相互转换

以上这篇keras实现theano和tensorflow训练的模型相互转换就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用自带的ConfigParser模块读写ini配置文件
Jun 26 Python
python使用logging模块发送邮件代码示例
Jan 18 Python
Python使用itertools模块实现排列组合功能示例
Jul 02 Python
不知道这5种下划线的含义,你就不算真的会Python!
Oct 09 Python
python3安装speech语音模块的方法
Dec 24 Python
Python 如何优雅的将数字转化为时间格式的方法
Sep 26 Python
使用Pandas将inf, nan转化成特定的值
Dec 19 Python
pyinstaller还原python代码过程图解
Jan 08 Python
解决Pycharm中恢复被exclude的项目问题(pycharm source root)
Feb 14 Python
Python正则表达式高级使用方法汇总
Jun 18 Python
详解python定时简单爬取网页新闻存入数据库并发送邮件
Nov 27 Python
使用python求解迷宫问题的三种实现方法
Mar 17 Python
Keras 切换后端方式(Theano和TensorFlow)
Jun 19 #Python
python中怎么表示空值
Jun 19 #Python
Python调用OpenCV实现图像平滑代码实例
Jun 19 #Python
使用OpenCV对车道进行实时检测的实现示例代码
Jun 19 #Python
为什么python比较流行
Jun 19 #Python
查看keras的默认backend实现方式
Jun 19 #Python
Python图像阈值化处理及算法比对实例解析
Jun 19 #Python
You might like
如何对PHP程序中的常见漏洞进行攻击(下)
2006/10/09 PHP
php ftp文件上传函数(基础版)
2010/06/03 PHP
解析在zend Farmework下如何创立一个FORM表单
2013/06/28 PHP
PHP简单实现“相关文章推荐”功能的方法
2014/07/19 PHP
Apache启动报错No space left on device: AH00023该怎么解决
2015/10/16 PHP
php、java、android、ios通用的3des方法(推荐)
2016/09/09 PHP
PHP正则删除HTML代码中宽高样式的方法
2017/06/12 PHP
在JavaScript中通过URL传递汉字的方法
2007/04/09 Javascript
浅析LigerUi开发中谨慎载入common.css文件
2013/07/09 Javascript
JS操作JSON要领详细总结
2013/08/25 Javascript
jquery处理json对象
2014/11/03 Javascript
js获取页面description的方法
2015/05/21 Javascript
轻松掌握JavaScript享元模式
2016/08/27 Javascript
Vue项目全局配置微信分享思路详解
2018/05/04 Javascript
微信小程序实现横向增长表格的方法
2018/07/24 Javascript
详解js 创建对象的几种方法
2019/03/08 Javascript
JavaScript中的垃圾回收与内存泄漏示例详解
2019/05/02 Javascript
小程序实现录音上传功能
2019/11/22 Javascript
详解vue中v-model和v-bind绑定数据的异同
2020/08/10 Javascript
浏览器JavaScript调试功能无法使用解决方案
2020/09/18 Javascript
JS如何调用WebAssembly编译出来的.wasm文件
2020/11/05 Javascript
python 从远程服务器下载东西的代码
2013/02/10 Python
python中黄金分割法实现方法
2015/05/06 Python
Python读取一个目录下所有目录和文件的方法
2016/07/15 Python
Python查找两个有序列表中位数的方法【基于归并算法】
2018/04/20 Python
在Python中Dataframe通过print输出多行时显示省略号的实例
2018/12/22 Python
TensorFlow绘制loss/accuracy曲线的实例
2020/01/21 Python
Keras SGD 随机梯度下降优化器参数设置方式
2020/06/19 Python
18-35岁旅游团的全球领导者:Contiki
2017/02/08 全球购物
SQL注入攻击的种类有哪些
2013/12/30 面试题
十八届三中全会个人学习材料
2014/02/13 职场文书
导游词范文
2015/02/13 职场文书
超级礼物观后感
2015/06/15 职场文书
我的中国梦主题班会
2015/08/14 职场文书
2016年万圣节活动个人总结
2016/04/05 职场文书
Python之Matplotlib绘制热力图和面积图
2022/04/13 Python