keras实现theano和tensorflow训练的模型相互转换


Posted in Python onJune 19, 2020

我就废话不多说了,大家还是直接看代码吧~

</pre><pre code_snippet_id="1947416" snippet_file_name="blog_20161025_1_3331239" name="code" class="python">

# coding:utf-8
"""
If you want to load pre-trained weights that include convolutions (layers Convolution2D or Convolution1D),
be mindful of this: Theano and TensorFlow implement convolution in different ways (TensorFlow actually implements correlation, much like Caffe),
and thus, convolution kernels trained with Theano (resp. TensorFlow) need to be converted before being with TensorFlow (resp. Theano).
"""
from keras import backend as K
from keras.utils.np_utils import convert_kernel
from text_classifier import keras_text_classifier
import sys
 
def th2tf( model):
  import tensorflow as tf
  ops = []
  for layer in model.layers:
    if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
      original_w = K.get_value(layer.W)
      converted_w = convert_kernel(original_w)
      ops.append(tf.assign(layer.W, converted_w).op)
  K.get_session().run(ops)
  return model
 
def tf2th(model):
  for layer in model.layers:
    if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
      original_w = K.get_value(layer.W)
      converted_w = convert_kernel(original_w)
      K.set_value(layer.W, converted_w)
  return model
 
def conv_layer_converted(tf_weights, th_weights, m = 0):
  """
  :param tf_weights:
  :param th_weights:
  :param m: 0-tf2th, 1-th2tf
  :return:
  """
  if m == 0: # tf2th
    tc = keras_text_classifier(weights_path=tf_weights)
    model = tc.loadmodel()
    model = tf2th(model)
    model.save_weights(th_weights)
  elif m == 1: # th2tf
    tc = keras_text_classifier(weights_path=th_weights)
    model = tc.loadmodel()
    model = th2tf(model)
    model.save_weights(tf_weights)
  else:
    print("0-tf2th, 1-th2tf")
    return
if __name__ == '__main__':
  if len(sys.argv) < 4:
    print("python tf_weights th_weights <0|1>\n0-tensorflow to theano\n1-theano to tensorflow")
    sys.exit(0)
  tf_weights = sys.argv[1]
  th_weights = sys.argv[2]
  m = int(sys.argv[3])
  conv_layer_converted(tf_weights, th_weights, m)

补充知识:keras学习之修改底层为TensorFlow还是theano

我们知道,keras的底层是TensorFlow或者theano

要知道我们是用的哪个为底层,只需要import keras即可显示

修改方法:

打开

keras实现theano和tensorflow训练的模型相互转换

修改

keras实现theano和tensorflow训练的模型相互转换

以上这篇keras实现theano和tensorflow训练的模型相互转换就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python的Django框架来制作一个RSS阅读器
Jul 22 Python
Python算法应用实战之栈详解
Feb 04 Python
Python使用OpenCV进行标定
May 08 Python
快速解决PyCharm无法引用matplotlib的问题
May 24 Python
python3实现SMTP发送邮件详细教程
Jun 19 Python
使用python itchat包爬取微信好友头像形成矩形头像集的方法
Feb 21 Python
python 并发编程 多路复用IO模型详解
Aug 20 Python
Python实现微信机器人的方法
Sep 06 Python
python3.8 微信发送服务器监控报警消息代码实现
Nov 05 Python
python通过链接抓取网站详解
Nov 20 Python
使用python接受tgam的脑波数据实例
Apr 09 Python
在Windows上安装和配置 Jupyter Lab 作为桌面级应用程序教程
Apr 22 Python
Keras 切换后端方式(Theano和TensorFlow)
Jun 19 #Python
python中怎么表示空值
Jun 19 #Python
Python调用OpenCV实现图像平滑代码实例
Jun 19 #Python
使用OpenCV对车道进行实时检测的实现示例代码
Jun 19 #Python
为什么python比较流行
Jun 19 #Python
查看keras的默认backend实现方式
Jun 19 #Python
Python图像阈值化处理及算法比对实例解析
Jun 19 #Python
You might like
自己做矿石收音机
2021/03/02 无线电
php中使用$_REQUEST需要注意的一个问题
2013/05/02 PHP
PHP实现适用于自定义的验证码类
2016/06/15 PHP
js prototype 格式化数字 By shawl.qiu
2007/04/02 Javascript
javascript showModalDialog 多层模态窗口实现页面提交及刷新的代码
2009/11/28 Javascript
js取值中form.all和不加all的区别介绍
2014/01/20 Javascript
jQuery中:button选择器用法实例
2015/01/04 Javascript
解决jquery无法找到其他父级子集问题的方法
2016/05/10 Javascript
Select下拉框模糊查询功能实现代码
2016/07/22 Javascript
关于JavaScript 原型链的一点个人理解
2016/07/31 Javascript
新入门node.js必须要知道的概念(必看篇)
2016/08/10 Javascript
jQuery点击导航栏选中更换样式的实现代码
2017/01/23 Javascript
Node.js websocket使用socket.io库实现实时聊天室
2017/02/20 Javascript
原生JS+Canvas实现五子棋游戏实例
2017/06/19 Javascript
JavaScript基于activexobject连接远程数据库SQL Server 2014的方法
2017/07/12 Javascript
JS函数节流和函数防抖问题分析
2017/12/18 Javascript
微信小程序实现录制、试听、上传音频功能(带波形图)
2020/02/27 Javascript
vue使用echarts画组织结构图
2021/02/06 Vue.js
Django验证码的生成与使用示例
2017/05/20 Python
python 输出上个月的月末日期实例
2018/04/11 Python
Python使用ConfigParser模块操作配置文件的方法
2018/06/29 Python
NLTK 3.2.4 环境搭建教程
2018/09/19 Python
在Python中使用defaultdict初始化字典以及应用方法
2018/10/31 Python
Python实现的爬取百度贴吧图片功能完整示例
2019/05/10 Python
Python OpenCV中的resize()函数的使用
2019/06/20 Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
2019/08/17 Python
Python 求数组局部最大值的实例
2019/11/26 Python
python实现简单的井字棋游戏(gui界面)
2021/01/22 Python
CSS3解析抖音LOGO制作的方法步骤
2019/04/11 HTML / CSS
SheIn俄罗斯:时尚女装网上商店
2017/02/28 全球购物
莫斯科制造商的廉价皮大衣:Fursk
2020/06/09 全球购物
迎接领导欢迎词
2014/01/11 职场文书
保险专业自荐信范文
2014/02/20 职场文书
建材投资建议书
2014/05/16 职场文书
Python快速实现一键抠图功能的全过程
2021/06/29 Python
SQL中的三种去重方法小结
2021/11/01 SQL Server