Keras 使用 Lambda层详解


Posted in Python onJune 10, 2020

我就废话不多说了,大家还是直接看代码吧!

from tensorflow.python.keras.models import Sequential, Model
from tensorflow.python.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, Dropout, Conv2DTranspose, Lambda, Input, Reshape, Add, Multiply
from tensorflow.python.keras.optimizers import Adam
 
def deconv(x):
  height = x.get_shape()[1].value
  width = x.get_shape()[2].value
  
  new_height = height*2
  new_width = width*2
  
  x_resized = tf.image.resize_images(x, [new_height, new_width], tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  
  return x_resized
 
def Generator(scope='generator'):
  imgs_noise = Input(shape=inputs_shape)
  x = Conv2D(filters=32, kernel_size=(9,9), strides=(1,1), padding='same', activation='relu')(imgs_noise)
  x = Conv2D(filters=64, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x)
  x = Conv2D(filters=128, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x)
 
  x1 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x)
  x1 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x1)
  x2 = Add()([x1, x])
 
  x3 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x2)
  x3 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x3)
  x4 = Add()([x3, x2])
 
  x5 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x4)
  x5 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x5)
  x6 = Add()([x5, x4])
 
  x = MaxPool2D(pool_size=(2,2))(x6)
 
  x = Lambda(deconv)(x)
  x = Conv2D(filters=64, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x)
  x = Lambda(deconv)(x)
  x = Conv2D(filters=32, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x)
  x = Lambda(deconv)(x)
  x = Conv2D(filters=3, kernel_size=(3, 3), strides=(1, 1), padding='same',activation='tanh')(x)
 
  x = Lambda(lambda x: x+1)(x)
  y = Lambda(lambda x: x*127.5)(x)
  
  model = Model(inputs=imgs_noise, outputs=y)
  model.summary()
  
  return model
 
my_generator = Generator()
my_generator.compile(loss='binary_crossentropy', optimizer=Adam(0.7, decay=1e-3), metrics=['accuracy'])

补充知识:含有Lambda自定义层keras模型,保存遇到的问题及解决方案

一,许多应用,keras含有的层已经不能满足要求,需要透过Lambda自定义层来实现一些layer,这个情况下,只能保存模型的权重,无法使用model.save来保存模型。保存时会报

TypeError: can't pickle _thread.RLock objects

Keras 使用 Lambda层详解

二,解决方案,为了便于后续的部署,可以转成tensorflow的PB进行部署。

from keras.models import load_model
import tensorflow as tf
import os, sys
from keras import backend as K
from tensorflow.python.framework import graph_util, graph_io

def h5_to_pb(h5_weight_path, output_dir, out_prefix="output_", log_tensorboard=True):
  if not os.path.exists(output_dir):
    os.mkdir(output_dir)
  h5_model = build_model()
  h5_model.load_weights(h5_weight_path)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  model_name = os.path.splitext(os.path.split(h5_weight_path)[-1])[0] + '.pb'
  sess = K.get_session()
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

def build_model():
  inputs = Input(shape=(784,), name='input_img')
  x = Dense(64, activation='relu')(inputs)
  x = Dense(64, activation='relu')(x)
  y = Dense(10, activation='softmax')(x)
  h5_model = Model(inputs=inputs, outputs=y)
  return h5_model

if __name__ == '__main__':
  if len(sys.argv) == 3:
    # usage: python3 h5_to_pb.py h5_weight_path output_dir
    h5_to_pb(h5_weight_path=sys.argv[1], output_dir=sys.argv[2])

以上这篇Keras 使用 Lambda层详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python传递参数方式小结
Apr 17 Python
python3抓取中文网页的方法
Jul 28 Python
python之文件的读写和文件目录以及文件夹的操作实现代码
Aug 28 Python
Python+OpenCV让电脑帮你玩微信跳一跳
Jan 04 Python
tensorflow训练中出现nan问题的解决
Feb 10 Python
Python numpy中矩阵的基本用法汇总
Feb 12 Python
Python实现去除列表中重复元素的方法总结【7种方法】
Feb 16 Python
浅谈pyqt5在QMainWindow中布局的问题
Jun 21 Python
django认证系统 Authentication使用详解
Jul 22 Python
pytorch 准备、训练和测试自己的图片数据的方法
Jan 10 Python
对tensorflow中cifar-10文档的Read操作详解
Feb 10 Python
python中如何使用虚拟环境
Oct 14 Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
You might like
PHP中魔术变量__METHOD__与__FUNCTION__的区别
2014/09/29 PHP
推荐一本PHP程序猿都应该拜读的书
2014/12/31 PHP
thinkPHP实现将excel导入到数据库中的方法
2016/04/22 PHP
利用php-cli和任务计划实现订单同步功能的方法
2017/05/03 PHP
Yii2配置Nginx伪静态的方法
2017/05/05 PHP
解决form中action属性后面?传递参数 获取不到的问题
2017/07/21 PHP
ThinkPHP框架中使用Memcached缓存数据的方法
2018/03/31 PHP
JavaScript几种形式的树结构菜单
2010/05/10 Javascript
javascript 单例/单体模式(Singleton)
2011/04/07 Javascript
javascript学习笔记(九) js对象 设计模式
2012/06/19 Javascript
javascript中节点的最近的相关节点访问方法
2013/03/20 Javascript
jquery分页插件jquery.pagination.js使用方法解析
2016/04/01 Javascript
js实现简易垂直滚动条
2017/02/22 Javascript
详解AngularJS之$window窗口对象
2018/01/17 Javascript
vue观察模式浅析
2018/09/25 Javascript
vue-cli脚手架打包静态资源请求出错的原因与解决
2019/06/06 Javascript
vue2路由基本用法实例分析
2020/03/06 Javascript
[06:25]第二届DOTA2亚洲邀请赛主赛事第二天比赛集锦.mp4
2017/04/03 DOTA
python Selenium爬取内容并存储至MySQL数据库的实现代码
2017/03/16 Python
Python实现PS滤镜特效之扇形变换效果示例
2018/01/26 Python
Flask和Django框架中自定义模型类的表名、父类相关问题分析
2018/07/19 Python
Python和Go语言的区别总结
2019/02/20 Python
Python Pandas分组聚合的实现方法
2019/07/02 Python
Python爬虫爬取百度搜索内容代码实例
2020/06/05 Python
html5页面结构_动力节点Java学院整理
2017/07/10 HTML / CSS
购买澳大利亚最好的服装和内衣在线:BONDS
2016/10/14 全球购物
英国最大的手表网站:The Watch Hut
2017/03/31 全球购物
银行演讲稿范文
2014/01/03 职场文书
旅游管理毕业生自荐信范文
2014/03/19 职场文书
质量保证书怎么写
2015/02/27 职场文书
大学生个人简历自我评价
2015/03/11 职场文书
党员心得体会范文2016
2016/01/23 职场文书
《百分数的认识》教学反思
2016/02/19 职场文书
2019年员工旷工保证书!
2019/06/28 职场文书
带你彻底理解JavaScript中的原型对象
2021/04/14 Javascript
mybatis使用oracle进行添加数据的方法
2021/04/27 Oracle