Keras 使用 Lambda层详解


Posted in Python onJune 10, 2020

我就废话不多说了,大家还是直接看代码吧!

from tensorflow.python.keras.models import Sequential, Model
from tensorflow.python.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, Dropout, Conv2DTranspose, Lambda, Input, Reshape, Add, Multiply
from tensorflow.python.keras.optimizers import Adam
 
def deconv(x):
  height = x.get_shape()[1].value
  width = x.get_shape()[2].value
  
  new_height = height*2
  new_width = width*2
  
  x_resized = tf.image.resize_images(x, [new_height, new_width], tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  
  return x_resized
 
def Generator(scope='generator'):
  imgs_noise = Input(shape=inputs_shape)
  x = Conv2D(filters=32, kernel_size=(9,9), strides=(1,1), padding='same', activation='relu')(imgs_noise)
  x = Conv2D(filters=64, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x)
  x = Conv2D(filters=128, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x)
 
  x1 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x)
  x1 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x1)
  x2 = Add()([x1, x])
 
  x3 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x2)
  x3 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x3)
  x4 = Add()([x3, x2])
 
  x5 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x4)
  x5 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x5)
  x6 = Add()([x5, x4])
 
  x = MaxPool2D(pool_size=(2,2))(x6)
 
  x = Lambda(deconv)(x)
  x = Conv2D(filters=64, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x)
  x = Lambda(deconv)(x)
  x = Conv2D(filters=32, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x)
  x = Lambda(deconv)(x)
  x = Conv2D(filters=3, kernel_size=(3, 3), strides=(1, 1), padding='same',activation='tanh')(x)
 
  x = Lambda(lambda x: x+1)(x)
  y = Lambda(lambda x: x*127.5)(x)
  
  model = Model(inputs=imgs_noise, outputs=y)
  model.summary()
  
  return model
 
my_generator = Generator()
my_generator.compile(loss='binary_crossentropy', optimizer=Adam(0.7, decay=1e-3), metrics=['accuracy'])

补充知识:含有Lambda自定义层keras模型,保存遇到的问题及解决方案

一,许多应用,keras含有的层已经不能满足要求,需要透过Lambda自定义层来实现一些layer,这个情况下,只能保存模型的权重,无法使用model.save来保存模型。保存时会报

TypeError: can't pickle _thread.RLock objects

Keras 使用 Lambda层详解

二,解决方案,为了便于后续的部署,可以转成tensorflow的PB进行部署。

from keras.models import load_model
import tensorflow as tf
import os, sys
from keras import backend as K
from tensorflow.python.framework import graph_util, graph_io

def h5_to_pb(h5_weight_path, output_dir, out_prefix="output_", log_tensorboard=True):
  if not os.path.exists(output_dir):
    os.mkdir(output_dir)
  h5_model = build_model()
  h5_model.load_weights(h5_weight_path)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  model_name = os.path.splitext(os.path.split(h5_weight_path)[-1])[0] + '.pb'
  sess = K.get_session()
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

def build_model():
  inputs = Input(shape=(784,), name='input_img')
  x = Dense(64, activation='relu')(inputs)
  x = Dense(64, activation='relu')(x)
  y = Dense(10, activation='softmax')(x)
  h5_model = Model(inputs=inputs, outputs=y)
  return h5_model

if __name__ == '__main__':
  if len(sys.argv) == 3:
    # usage: python3 h5_to_pb.py h5_weight_path output_dir
    h5_to_pb(h5_weight_path=sys.argv[1], output_dir=sys.argv[2])

以上这篇Keras 使用 Lambda层详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python list使用示例 list中找连续的数字
Jan 27 Python
讲解Python中的标识运算符
May 14 Python
python检索特定内容的文本文件实例
Jun 05 Python
Python线程下使用锁的技巧分享
Sep 13 Python
基于Python实现用户管理系统
Feb 26 Python
PyQt5 在label显示的图片中绘制矩形的方法
Jun 17 Python
Pandas分组与排序的实现
Jul 23 Python
Python如何调用JS文件中的函数
Aug 16 Python
浅谈Python_Openpyxl使用(最全总结)
Sep 05 Python
python自动发微信监控报警
Sep 06 Python
Python实现一个简单的递归下降分析器
Aug 01 Python
matplotlib设置颜色、标记、线条,让你的图像更加丰富(推荐)
Sep 25 Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
You might like
同时提取多条新闻中的文本一例
2006/10/09 PHP
浅析php创建者模式
2014/11/25 PHP
浅谈php冒泡排序
2014/12/30 PHP
深入理解PHP中的empty和isset函数
2016/05/26 PHP
PHP实现的随机红包算法示例
2017/08/14 PHP
PHP实现的微信公众号扫码模拟登录功能示例
2019/05/30 PHP
js中有关IE版本检测
2012/01/04 Javascript
jQuery通过扩展实现抖动效果的方法
2015/03/11 Javascript
AngularJs Forms详解及简单示例
2016/09/01 Javascript
AngularJS自定义插件实现网站用户引导功能示例
2016/11/07 Javascript
详解为Angular.js内置$http服务添加拦截器的方法
2016/12/20 Javascript
React进阶学习之组件的解耦之道
2017/08/07 Javascript
在微信小程序里使用watch和computed的方法
2018/08/02 Javascript
Angular PWA使用的Demo示例
2019/01/31 Javascript
vue移动端实现手机左右滑动入场动画
2020/06/17 Javascript
vue实现简单学生信息管理
2020/05/30 Javascript
jQuery实现简单全选框
2020/09/13 jQuery
基于vuex实现购物车功能
2021/01/10 Vue.js
ubuntu环境下python虚拟环境的安装过程
2018/01/07 Python
python如何在循环引用中管理内存
2018/03/20 Python
如何在Python中实现goto语句的方法
2019/05/18 Python
python读取word 中指定位置的表格及表格数据
2019/10/23 Python
Python 执行矩阵与线性代数运算
2020/08/01 Python
夏威夷航空官网:Hawaiian Airlines
2016/09/11 全球购物
小狗电器官方商城:中国高端吸尘器品牌
2017/03/29 全球购物
爱普生美国官网:Epson美国
2018/11/05 全球购物
向全球直邮输送天然健康产品:iHerb.com
2020/05/03 全球购物
旅游网创业计划书
2014/01/31 职场文书
企业总经理岗位职责
2014/02/13 职场文书
公务员上班玩游戏检讨书
2014/09/17 职场文书
2014年班务工作总结
2014/12/02 职场文书
环保守法证明
2015/06/24 职场文书
企业催款函范本
2015/06/24 职场文书
《田忌赛马》教学反思
2016/02/19 职场文书
Python使用protobuf序列化和反序列化的实现
2021/05/19 Python
手写实现JS中的new
2021/11/07 Javascript