Keras 使用 Lambda层详解


Posted in Python onJune 10, 2020

我就废话不多说了,大家还是直接看代码吧!

from tensorflow.python.keras.models import Sequential, Model
from tensorflow.python.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, Dropout, Conv2DTranspose, Lambda, Input, Reshape, Add, Multiply
from tensorflow.python.keras.optimizers import Adam
 
def deconv(x):
  height = x.get_shape()[1].value
  width = x.get_shape()[2].value
  
  new_height = height*2
  new_width = width*2
  
  x_resized = tf.image.resize_images(x, [new_height, new_width], tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  
  return x_resized
 
def Generator(scope='generator'):
  imgs_noise = Input(shape=inputs_shape)
  x = Conv2D(filters=32, kernel_size=(9,9), strides=(1,1), padding='same', activation='relu')(imgs_noise)
  x = Conv2D(filters=64, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x)
  x = Conv2D(filters=128, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x)
 
  x1 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x)
  x1 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x1)
  x2 = Add()([x1, x])
 
  x3 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x2)
  x3 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x3)
  x4 = Add()([x3, x2])
 
  x5 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x4)
  x5 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x5)
  x6 = Add()([x5, x4])
 
  x = MaxPool2D(pool_size=(2,2))(x6)
 
  x = Lambda(deconv)(x)
  x = Conv2D(filters=64, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x)
  x = Lambda(deconv)(x)
  x = Conv2D(filters=32, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x)
  x = Lambda(deconv)(x)
  x = Conv2D(filters=3, kernel_size=(3, 3), strides=(1, 1), padding='same',activation='tanh')(x)
 
  x = Lambda(lambda x: x+1)(x)
  y = Lambda(lambda x: x*127.5)(x)
  
  model = Model(inputs=imgs_noise, outputs=y)
  model.summary()
  
  return model
 
my_generator = Generator()
my_generator.compile(loss='binary_crossentropy', optimizer=Adam(0.7, decay=1e-3), metrics=['accuracy'])

补充知识:含有Lambda自定义层keras模型,保存遇到的问题及解决方案

一,许多应用,keras含有的层已经不能满足要求,需要透过Lambda自定义层来实现一些layer,这个情况下,只能保存模型的权重,无法使用model.save来保存模型。保存时会报

TypeError: can't pickle _thread.RLock objects

Keras 使用 Lambda层详解

二,解决方案,为了便于后续的部署,可以转成tensorflow的PB进行部署。

from keras.models import load_model
import tensorflow as tf
import os, sys
from keras import backend as K
from tensorflow.python.framework import graph_util, graph_io

def h5_to_pb(h5_weight_path, output_dir, out_prefix="output_", log_tensorboard=True):
  if not os.path.exists(output_dir):
    os.mkdir(output_dir)
  h5_model = build_model()
  h5_model.load_weights(h5_weight_path)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  model_name = os.path.splitext(os.path.split(h5_weight_path)[-1])[0] + '.pb'
  sess = K.get_session()
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

def build_model():
  inputs = Input(shape=(784,), name='input_img')
  x = Dense(64, activation='relu')(inputs)
  x = Dense(64, activation='relu')(x)
  y = Dense(10, activation='softmax')(x)
  h5_model = Model(inputs=inputs, outputs=y)
  return h5_model

if __name__ == '__main__':
  if len(sys.argv) == 3:
    # usage: python3 h5_to_pb.py h5_weight_path output_dir
    h5_to_pb(h5_weight_path=sys.argv[1], output_dir=sys.argv[2])

以上这篇Keras 使用 Lambda层详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现树莓派WiFi断线自动重连的实例代码
Mar 16 Python
Python进度条实时显示处理进度的示例代码
Jan 30 Python
Tensorflow之构建自己的图片数据集TFrecords的方法
Feb 07 Python
python生成tensorflow输入输出的图像格式的方法
Feb 12 Python
pyqt 实现QlineEdit 输入密码显示成圆点的方法
Jun 24 Python
Python 多个图同时在不同窗口显示的实现方法
Jul 07 Python
python多线程分块读取文件
Aug 29 Python
Python telnet登陆功能实现代码
Apr 16 Python
Python实现电视里的5毛特效实例代码详解
May 15 Python
利用python清除移动硬盘中的临时文件
Oct 28 Python
numba提升python运行速度的实例方法
Jan 25 Python
python正则表达式re.match()匹配多个字符方法的实现
Jan 27 Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
You might like
PHP版国家代码、缩写查询函数代码
2011/08/14 PHP
关于PHPDocument 代码注释规范的总结
2013/06/25 PHP
浅谈PHP中try{}catch{}的使用方法
2016/12/09 PHP
一个js实现的所谓的滑动门
2007/05/23 Javascript
js 获取页面高度和宽度兼容 ie firefox chrome等
2014/05/14 Javascript
json属性名为什么要双引号(个人猜测)
2014/07/31 Javascript
js单独获取一个checkbox看其是否被选中
2014/09/22 Javascript
学习JavaScript设计模式之迭代器模式
2016/01/19 Javascript
angularjs实现文字上下无缝滚动特效代码
2016/09/04 Javascript
Angularjs 实现分页功能及示例代码
2016/09/14 Javascript
vue.js学习笔记:如何加载本地json文件
2017/01/17 Javascript
ThinkPHP+jquery实现“加载更多”功能代码
2017/03/11 Javascript
Mongoose经常返回e11000 error的原因分析
2017/03/29 Javascript
jquery写出PC端轮播图实例
2018/01/26 jQuery
基于mpvue的简单弹窗组件mptoast使用详解
2019/08/02 Javascript
vue将data恢复到初始状态 && 重新渲染组件实例
2020/09/04 Javascript
Eclipse + Python 的安装与配置流程
2013/03/05 Python
Python IDLE入门简介
2017/12/08 Python
对Python定时任务的启动和停止方法详解
2019/02/19 Python
python用match()函数爬数据方法详解
2019/07/23 Python
JupyterNotebook设置Python环境的方法步骤
2019/12/03 Python
500行python代码实现飞机大战
2020/04/24 Python
关于tf.matmul() 和tf.multiply() 的区别说明
2020/06/18 Python
python virtualenv虚拟环境配置与使用教程详解
2020/07/13 Python
详解Python 最短匹配模式
2020/07/29 Python
Python使用urlretrieve实现直接远程下载图片的示例代码
2020/08/17 Python
用HTML5实现网站在windows8中贴靠的方法
2013/04/21 HTML / CSS
美国时尚孕妇装品牌:A Pea in the Pod
2017/07/16 全球购物
商超业务员岗位职责
2014/03/12 职场文书
机关门卫的岗位职责
2014/04/29 职场文书
给校长的建议书100字
2014/05/16 职场文书
论文答谢词
2015/01/20 职场文书
增值税发票丢失证明
2015/06/19 职场文书
大卫科波菲尔读书笔记
2015/06/30 职场文书
muduo TcpServer模块源码分析
2022/04/26 Redis
Go gorilla/sessions库安装使用
2022/08/14 Golang