keras实现theano和tensorflow训练的模型相互转换


Posted in Python onJune 19, 2020

我就废话不多说了,大家还是直接看代码吧~

</pre><pre code_snippet_id="1947416" snippet_file_name="blog_20161025_1_3331239" name="code" class="python">

# coding:utf-8
"""
If you want to load pre-trained weights that include convolutions (layers Convolution2D or Convolution1D),
be mindful of this: Theano and TensorFlow implement convolution in different ways (TensorFlow actually implements correlation, much like Caffe),
and thus, convolution kernels trained with Theano (resp. TensorFlow) need to be converted before being with TensorFlow (resp. Theano).
"""
from keras import backend as K
from keras.utils.np_utils import convert_kernel
from text_classifier import keras_text_classifier
import sys
 
def th2tf( model):
  import tensorflow as tf
  ops = []
  for layer in model.layers:
    if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
      original_w = K.get_value(layer.W)
      converted_w = convert_kernel(original_w)
      ops.append(tf.assign(layer.W, converted_w).op)
  K.get_session().run(ops)
  return model
 
def tf2th(model):
  for layer in model.layers:
    if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
      original_w = K.get_value(layer.W)
      converted_w = convert_kernel(original_w)
      K.set_value(layer.W, converted_w)
  return model
 
def conv_layer_converted(tf_weights, th_weights, m = 0):
  """
  :param tf_weights:
  :param th_weights:
  :param m: 0-tf2th, 1-th2tf
  :return:
  """
  if m == 0: # tf2th
    tc = keras_text_classifier(weights_path=tf_weights)
    model = tc.loadmodel()
    model = tf2th(model)
    model.save_weights(th_weights)
  elif m == 1: # th2tf
    tc = keras_text_classifier(weights_path=th_weights)
    model = tc.loadmodel()
    model = th2tf(model)
    model.save_weights(tf_weights)
  else:
    print("0-tf2th, 1-th2tf")
    return
if __name__ == '__main__':
  if len(sys.argv) < 4:
    print("python tf_weights th_weights <0|1>\n0-tensorflow to theano\n1-theano to tensorflow")
    sys.exit(0)
  tf_weights = sys.argv[1]
  th_weights = sys.argv[2]
  m = int(sys.argv[3])
  conv_layer_converted(tf_weights, th_weights, m)

补充知识:keras学习之修改底层为TensorFlow还是theano

我们知道,keras的底层是TensorFlow或者theano

要知道我们是用的哪个为底层,只需要import keras即可显示

修改方法:

打开

keras实现theano和tensorflow训练的模型相互转换

修改

keras实现theano和tensorflow训练的模型相互转换

以上这篇keras实现theano和tensorflow训练的模型相互转换就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python脚本实现代码行数统计代码分享
Mar 10 Python
Python2.x版本中maketrans()方法的使用介绍
May 19 Python
利用Python的Django框架生成PDF文件的教程
Jul 22 Python
Python3.6实现连接mysql或mariadb的方法分析
May 18 Python
python使用zip将list转为json的方法
Dec 31 Python
Pandas中resample方法详解
Jul 02 Python
详解如何减少python内存的消耗
Aug 09 Python
TensorFlow 输出checkpoint 中的变量名与变量值方式
Feb 11 Python
Python3创建Django项目的几种方法(3种)
Jun 03 Python
python 爬取百度文库并下载(免费文章限定)
Dec 04 Python
PyCharm 解决找不到新打开项目的窗口问题
Jan 15 Python
Python利用机器学习算法实现垃圾邮件的识别
Jun 28 Python
Keras 切换后端方式(Theano和TensorFlow)
Jun 19 #Python
python中怎么表示空值
Jun 19 #Python
Python调用OpenCV实现图像平滑代码实例
Jun 19 #Python
使用OpenCV对车道进行实时检测的实现示例代码
Jun 19 #Python
为什么python比较流行
Jun 19 #Python
查看keras的默认backend实现方式
Jun 19 #Python
Python图像阈值化处理及算法比对实例解析
Jun 19 #Python
You might like
php学习之 认清变量的作用范围
2010/01/26 PHP
PHP循环结构实例讲解
2014/02/10 PHP
PHP、Nginx、Apache中禁止网页被iframe引用的方法
2020/10/01 PHP
跟我学Laravel之路由
2014/10/15 PHP
Linux系统下php获得系统分区信息的方法
2015/03/30 PHP
实例讲解PHP页面静态化
2018/02/05 PHP
document对象execCommand的command参数介绍
2006/08/01 Javascript
很全的显示阴历(农历)日期的js代码
2009/01/01 Javascript
Mootools 1.2教程 设置和获取样式表属性
2009/09/15 Javascript
jQuery多项选项卡的实现思路附样式及代码
2014/06/03 Javascript
js实现touch移动触屏滑动事件
2015/04/17 Javascript
JS实现滑动菜单效果代码(包括Tab,选项卡,横向等效果)
2015/09/24 Javascript
详解在Angularjs中ui-sref和$state.go如何传递参数
2017/04/24 Javascript
Angular实现响应式表单
2017/08/04 Javascript
vue 实现Web端的定位功能 获取经纬度
2019/08/08 Javascript
使用 Vue 实现一个虚拟列表的方法
2019/08/20 Javascript
jquery获取input输入框中的值
2019/11/13 jQuery
vue3弹出层V3Popup实例详解
2021/01/04 Vue.js
Python变量和字符串详解
2017/04/29 Python
Python基础学习之函数方法实例详解
2019/06/18 Python
Python代码生成视频的缩略图的实例讲解
2019/12/22 Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
2020/09/17 Python
Python如何访问字符串中的值
2020/02/09 Python
python 如何调用远程接口
2020/09/11 Python
python 动态渲染 mysql 配置文件的示例
2020/11/20 Python
HTML5添加鼠标悬浮音响效果不使用FLASH
2014/04/23 HTML / CSS
Superdry瑞典官网:英国日本街头风品牌
2017/05/17 全球购物
ECCO爱步官方旗舰店:丹麦鞋履品牌
2018/01/02 全球购物
百思买加拿大:Best Buy Canada
2018/03/20 全球购物
临床医学大学生求职信
2013/09/28 职场文书
银行职员思想汇报
2013/12/31 职场文书
小学教研工作制度
2014/01/15 职场文书
村捐赠仪式答谢词
2014/01/21 职场文书
集体备课反思
2014/02/12 职场文书
团支部书记竞选稿
2015/11/21 职场文书
Python matplotlib绘制雷达图
2022/04/13 Python