Keras 使用 Lambda层详解


Posted in Python onJune 10, 2020

我就废话不多说了,大家还是直接看代码吧!

from tensorflow.python.keras.models import Sequential, Model
from tensorflow.python.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, Dropout, Conv2DTranspose, Lambda, Input, Reshape, Add, Multiply
from tensorflow.python.keras.optimizers import Adam
 
def deconv(x):
  height = x.get_shape()[1].value
  width = x.get_shape()[2].value
  
  new_height = height*2
  new_width = width*2
  
  x_resized = tf.image.resize_images(x, [new_height, new_width], tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  
  return x_resized
 
def Generator(scope='generator'):
  imgs_noise = Input(shape=inputs_shape)
  x = Conv2D(filters=32, kernel_size=(9,9), strides=(1,1), padding='same', activation='relu')(imgs_noise)
  x = Conv2D(filters=64, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x)
  x = Conv2D(filters=128, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x)
 
  x1 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x)
  x1 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x1)
  x2 = Add()([x1, x])
 
  x3 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x2)
  x3 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x3)
  x4 = Add()([x3, x2])
 
  x5 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x4)
  x5 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x5)
  x6 = Add()([x5, x4])
 
  x = MaxPool2D(pool_size=(2,2))(x6)
 
  x = Lambda(deconv)(x)
  x = Conv2D(filters=64, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x)
  x = Lambda(deconv)(x)
  x = Conv2D(filters=32, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x)
  x = Lambda(deconv)(x)
  x = Conv2D(filters=3, kernel_size=(3, 3), strides=(1, 1), padding='same',activation='tanh')(x)
 
  x = Lambda(lambda x: x+1)(x)
  y = Lambda(lambda x: x*127.5)(x)
  
  model = Model(inputs=imgs_noise, outputs=y)
  model.summary()
  
  return model
 
my_generator = Generator()
my_generator.compile(loss='binary_crossentropy', optimizer=Adam(0.7, decay=1e-3), metrics=['accuracy'])

补充知识:含有Lambda自定义层keras模型,保存遇到的问题及解决方案

一,许多应用,keras含有的层已经不能满足要求,需要透过Lambda自定义层来实现一些layer,这个情况下,只能保存模型的权重,无法使用model.save来保存模型。保存时会报

TypeError: can't pickle _thread.RLock objects

Keras 使用 Lambda层详解

二,解决方案,为了便于后续的部署,可以转成tensorflow的PB进行部署。

from keras.models import load_model
import tensorflow as tf
import os, sys
from keras import backend as K
from tensorflow.python.framework import graph_util, graph_io

def h5_to_pb(h5_weight_path, output_dir, out_prefix="output_", log_tensorboard=True):
  if not os.path.exists(output_dir):
    os.mkdir(output_dir)
  h5_model = build_model()
  h5_model.load_weights(h5_weight_path)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  model_name = os.path.splitext(os.path.split(h5_weight_path)[-1])[0] + '.pb'
  sess = K.get_session()
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

def build_model():
  inputs = Input(shape=(784,), name='input_img')
  x = Dense(64, activation='relu')(inputs)
  x = Dense(64, activation='relu')(x)
  y = Dense(10, activation='softmax')(x)
  h5_model = Model(inputs=inputs, outputs=y)
  return h5_model

if __name__ == '__main__':
  if len(sys.argv) == 3:
    # usage: python3 h5_to_pb.py h5_weight_path output_dir
    h5_to_pb(h5_weight_path=sys.argv[1], output_dir=sys.argv[2])

以上这篇Keras 使用 Lambda层详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中文字符串截取问题
Jun 15 Python
如何使用七牛Python SDK写一个同步脚本及使用教程
Aug 23 Python
Python实现对字符串的加密解密方法示例
Apr 29 Python
高效测试用例组织算法pairwise之Python实现方法
Jul 19 Python
Python实现图片转字符画的示例
Aug 22 Python
python正则中最短匹配实现代码
Jan 16 Python
python实现电脑自动关机
Jun 20 Python
解决Python中定时任务线程无法自动退出的问题
Feb 18 Python
python celery分布式任务队列的使用详解
Jul 08 Python
如何在python中实现随机选择
Nov 02 Python
python-numpy-指数分布实例详解
Dec 07 Python
python GUI库图形界面开发之PyQt5开发环境配置与基础使用
Feb 25 Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
You might like
php完全过滤HTML,JS,CSS等标签
2009/01/16 PHP
PHP 第二节 数据类型之数组
2012/04/28 PHP
javascript中有趣的反柯里化深入分析
2012/12/05 Javascript
JavaScript 操作table,可以新增行和列并且隔一行换背景色代码分享
2013/07/05 Javascript
全面兼容的javascript时间格式化函数(比较实用)
2014/05/14 Javascript
jQuery仿Flash上下翻动的中英文导航菜单实例
2015/03/10 Javascript
nodejs通过phantomjs实现下载网页
2015/05/04 NodeJs
关于验证码在IE中不刷新的快速解决方法
2016/09/23 Javascript
JavaScript实现数组降维详解
2017/01/05 Javascript
在百度搜索结果中去除掉一些网站的资料(通过js控制不让显示)
2017/05/02 Javascript
微信小程序扫描二维码获取信息实例详解
2019/05/07 Javascript
非常漂亮的js烟花效果
2020/03/10 Javascript
JS实现滑动拼图验证功能完整示例
2020/03/29 Javascript
Python入门篇之对象类型
2014/10/17 Python
python使用Queue在多个子进程间交换数据的方法
2015/04/18 Python
python 定时修改数据库的示例代码
2018/04/08 Python
Django如何配置mysql数据库
2018/05/04 Python
python获取中文字符串长度的方法
2018/11/14 Python
Python3实现mysql连接和数据框的形成(实例代码)
2020/01/17 Python
基于python 等频分箱qcut问题的解决
2020/03/03 Python
Python爬虫实现模拟点击动态页面
2020/03/05 Python
Python下划线5种含义代码实例解析
2020/07/10 Python
如何用Python编写一个电子考勤系统
2021/02/08 Python
基于Pytorch版yolov5的滑块验证码破解思路详解
2021/02/25 Python
你不知道的5个HTML5新功能
2016/06/28 HTML / CSS
Html5 new XMLHttpRequest()监听附件上传进度
2021/01/14 HTML / CSS
Ruby如何创建一个线程
2013/03/10 面试题
机关节能减排实施方案
2014/03/17 职场文书
小学生家长寄语
2014/04/02 职场文书
机械操作工岗位职责
2014/08/08 职场文书
群众路线自我剖析及整改措施
2014/11/04 职场文书
2015年全国保险公众宣传日活动方案
2015/05/06 职场文书
在职证明书模板
2015/06/15 职场文书
python 远程执行命令的详细代码
2022/02/15 Python
MySQL优化及索引解析
2022/03/17 MySQL
java获取一个文本文件的编码(格式)信息
2022/09/23 Java/Android