Keras 使用 Lambda层详解


Posted in Python onJune 10, 2020

我就废话不多说了,大家还是直接看代码吧!

from tensorflow.python.keras.models import Sequential, Model
from tensorflow.python.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, Dropout, Conv2DTranspose, Lambda, Input, Reshape, Add, Multiply
from tensorflow.python.keras.optimizers import Adam
 
def deconv(x):
  height = x.get_shape()[1].value
  width = x.get_shape()[2].value
  
  new_height = height*2
  new_width = width*2
  
  x_resized = tf.image.resize_images(x, [new_height, new_width], tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  
  return x_resized
 
def Generator(scope='generator'):
  imgs_noise = Input(shape=inputs_shape)
  x = Conv2D(filters=32, kernel_size=(9,9), strides=(1,1), padding='same', activation='relu')(imgs_noise)
  x = Conv2D(filters=64, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x)
  x = Conv2D(filters=128, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x)
 
  x1 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x)
  x1 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x1)
  x2 = Add()([x1, x])
 
  x3 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x2)
  x3 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x3)
  x4 = Add()([x3, x2])
 
  x5 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x4)
  x5 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x5)
  x6 = Add()([x5, x4])
 
  x = MaxPool2D(pool_size=(2,2))(x6)
 
  x = Lambda(deconv)(x)
  x = Conv2D(filters=64, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x)
  x = Lambda(deconv)(x)
  x = Conv2D(filters=32, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x)
  x = Lambda(deconv)(x)
  x = Conv2D(filters=3, kernel_size=(3, 3), strides=(1, 1), padding='same',activation='tanh')(x)
 
  x = Lambda(lambda x: x+1)(x)
  y = Lambda(lambda x: x*127.5)(x)
  
  model = Model(inputs=imgs_noise, outputs=y)
  model.summary()
  
  return model
 
my_generator = Generator()
my_generator.compile(loss='binary_crossentropy', optimizer=Adam(0.7, decay=1e-3), metrics=['accuracy'])

补充知识:含有Lambda自定义层keras模型,保存遇到的问题及解决方案

一,许多应用,keras含有的层已经不能满足要求,需要透过Lambda自定义层来实现一些layer,这个情况下,只能保存模型的权重,无法使用model.save来保存模型。保存时会报

TypeError: can't pickle _thread.RLock objects

Keras 使用 Lambda层详解

二,解决方案,为了便于后续的部署,可以转成tensorflow的PB进行部署。

from keras.models import load_model
import tensorflow as tf
import os, sys
from keras import backend as K
from tensorflow.python.framework import graph_util, graph_io

def h5_to_pb(h5_weight_path, output_dir, out_prefix="output_", log_tensorboard=True):
  if not os.path.exists(output_dir):
    os.mkdir(output_dir)
  h5_model = build_model()
  h5_model.load_weights(h5_weight_path)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  model_name = os.path.splitext(os.path.split(h5_weight_path)[-1])[0] + '.pb'
  sess = K.get_session()
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

def build_model():
  inputs = Input(shape=(784,), name='input_img')
  x = Dense(64, activation='relu')(inputs)
  x = Dense(64, activation='relu')(x)
  y = Dense(10, activation='softmax')(x)
  h5_model = Model(inputs=inputs, outputs=y)
  return h5_model

if __name__ == '__main__':
  if len(sys.argv) == 3:
    # usage: python3 h5_to_pb.py h5_weight_path output_dir
    h5_to_pb(h5_weight_path=sys.argv[1], output_dir=sys.argv[2])

以上这篇Keras 使用 Lambda层详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 类与元类的深度挖掘 II【经验】
May 06 Python
Python编码爬坑指南(必看)
Jun 10 Python
Linux下python与C++使用dlib实现人脸检测
Jun 29 Python
PyQt5 实现给窗口设置背景图片的方法
Jun 13 Python
基于python生成英文版词云图代码实例
May 16 Python
使用Python构造hive insert语句说明
Jun 06 Python
使用Python中tkinter库简单gui界面制作及打包成exe的操作方法(二)
Oct 12 Python
PyCharm配置KBEngine快速处理代码提示冲突、配置命令问题
Apr 03 Python
pytorch 如何使用batch训练lstm网络
May 28 Python
Python答题卡识别并给出分数的实现代码
Jun 22 Python
Python Matplotlib绘制条形图的全过程
Oct 24 Python
Python matplotlib绘制条形统计图 处理多个实验多组观测值
Apr 21 Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
You might like
PHP SPL标准库之接口(Interface)详解
2015/05/11 PHP
php文件扩展名判断及获取文件扩展名的N种方法
2015/09/12 PHP
PHP中PDO连接数据库中各种DNS设置方法小结
2016/05/13 PHP
php使用glob函数遍历文件和目录详解
2016/09/23 PHP
php+Memcached实现简单留言板功能示例
2017/02/15 PHP
JSON PHP中,Json字符串反序列化成对象/数组的方法
2018/05/31 PHP
使用laravel和ECharts实现折线图效果的例子
2019/10/09 PHP
Select标签下拉列表二级联动级联实例代码
2014/02/07 Javascript
jquery使用hide方法隐藏指定id的元素
2015/03/30 Javascript
简单的分页代码js实现
2016/05/17 Javascript
jQuery 3.0十大新特性
2016/07/06 Javascript
原生态js,鼠标按下后,经过了那些单元格的简单实例
2016/08/11 Javascript
JavaScript组成、引入、输出、运算符基础知识讲解
2016/12/08 Javascript
AngularJS中的Promise详细介绍及实例代码
2016/12/13 Javascript
bootstrap按钮插件(Button)使用方法解析
2017/01/13 Javascript
vue表单绑定实现多选框和下拉列表的实例
2017/08/12 Javascript
vue给对象动态添加属性和值的实例
2019/09/09 Javascript
vue+echarts实现中国地图流动效果(步骤详解)
2021/01/27 Vue.js
[01:42]TI4西雅图DOTA2前线报道 第一顿早饭哦
2014/07/08 DOTA
[54:26]完美世界DOTA2联赛PWL S3 Forest vs Rebirth 第一场 12.10
2020/12/12 DOTA
Pandas 重塑(stack)和轴向旋转(pivot)的实现
2019/07/22 Python
python打造爬虫代理池过程解析
2019/08/15 Python
Django框架ORM数据库操作实例详解
2019/11/07 Python
Django 简单实现分页与搜索功能的示例代码
2019/11/07 Python
python对Excel按条件进行内容补充(推荐)
2019/11/24 Python
python的列表List求均值和中位数实例
2020/03/03 Python
Python变量及数据类型用法原理汇总
2020/08/06 Python
Python性能测试工具Locust安装及使用
2020/12/01 Python
python 基于Apscheduler实现定时任务
2020/12/15 Python
官方授权图形T恤和服装:Fifth Sun
2019/06/12 全球购物
西班牙Polo衫品牌:Polo Club
2020/08/09 全球购物
旷课检讨书2000字
2014/01/14 职场文书
解决Mysql的left join无效及使用的注意事项说明
2021/07/01 MySQL
centos8安装MongoDB的详细过程
2021/10/24 MongoDB
用PYTHON去计算88键钢琴的琴键频率和音高
2022/04/10 Python
Java异常体系非正常停止和分类
2022/06/14 Java/Android