keras自动编码器实现系列之卷积自动编码器操作


Posted in Python onJuly 03, 2020

图片的自动编码很容易就想到用卷积神经网络做为编码-解码器。在实际的操作中,

也经常使用卷积自动编码器去解决图像编码问题,而且非常有效。

下面通过**keras**完成简单的卷积自动编码。 编码器有堆叠的卷积层和池化层(max pooling用于空间降采样)组成。 对应的解码器由卷积层和上采样层组成。

@requires_authorization
# -*- coding:utf-8 -*-

from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model
from keras import backend as K
import os

## 网络结构 ##
input_img = Input(shape=(28,28,1)) # Tensorflow后端, 注意要用channel_last
# 编码器部分
x = Conv2D(16, (3,3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2,2), padding='same')(x)
x = Conv2D(8,(3,3), activation='relu', padding='same')(x)
x = MaxPooling2D((2,2), padding='same')(x)
x = Conv2D(8, (3,3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2,2), padding='same')(x)

# 解码器部分
x = Conv2D(8, (3,3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(8, (3,3), activation='relu', padding='same')(x) 
x = UpSampling2D((2, 2))(x)
x = Conv2D(16, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

# 得到编码层的输出
encoder_model = Model(inputs=autoencoder.input, outputs=autoencoder.get_layer('encoder_out').output)

## 导入数据, 使用常用的手写识别数据集
def load_mnist(dataset_name):
'''
load the data
'''
  data_dir = os.path.join("./data", dataset_name)
  f = np.load(os.path.join(data_dir, 'mnist.npz'))
  train_data = f['train'].T
  trX = train_data.reshape((-1, 28, 28, 1)).astype(np.float32)
  trY = f['train_labels'][-1].astype(np.float32)
  test_data = f['test'].T
  teX = test_data.reshape((-1, 28, 28, 1)).astype(np.float32)
  teY = f['test_labels'][-1].astype(np.float32)

  # one-hot 
  # y_vec = np.zeros((len(y), 10), dtype=np.float32)
  # for i, label in enumerate(y):
  #   y_vec[i, y[i]] = 1
  # keras.utils里带的有one-hot的函数, 就直接用那个了
  return trX / 255., trY, teX/255., teY

# 开始导入数据
x_train, _ , x_test, _= load_mnist('mnist')

# 可视化训练结果, 我们打开终端, 使用tensorboard
# tensorboard --logdir=/tmp/autoencoder # 注意这里是打开一个终端, 在终端里运行

# 训练模型, 并且在callbacks中使用tensorBoard实例, 写入训练日志 http://0.0.0.0:6006
from keras.callbacks import TensorBoard
autoencoder.fit(x_train, x_train,
        epochs=50,
        batch_size=128,
        shuffle=True,
        validation_data=(x_test, x_test),
        callbacks=[TensorBoard(log_dir='/tmp/autoencoder')])

# 重建图片
import matplotlib.pyplot as plt 
decoded_imgs = autoencoder.predict(x_test)
encoded_imgs = encoder_model.predict(x_test)
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
  k = i + 1
  # 画原始图片
  ax = plt.subplot(2, n, k)
  plt.imshow(x_test[k].reshape(28, 28))
  plt.gray()
  ax.get_xaxis().set_visible(False)
  # 画重建图片
  ax = plt.subplot(2, n, k + n)
  plt.imshow(decoded_imgs[i].reshape(28, 28))
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

# 编码得到的特征
n = 10
plt.figure(figsize=(20, 8))
for i in range(n):
  k = i + 1
  ax = plt.subplot(1, n, k)
  plt.imshow(encoded[k].reshape(4, 4 * 8).T)
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

补充知识:keras搬砖系列-单层卷积自编码器

考试成绩出来了,竟然有一门出奇的差,只是有点意外。

觉得应该不错的,竟然考差了,它估计写了个随机数吧。

头文件

from keras.layers import Input,Dense
from keras.models import Model 
from keras.datasets import mnist
import numpy as np 
import matplotlib.pyplot as plt

导入数据

(X_train,_),(X_test,_) = mnist.load_data()
 
X_train = X_train.astype('float32')/255.
X_test = X_test.astype('float32')/255.
X_train = X_train.reshape((len(X_train),-1))
X_test = X_test.reshape((len(X_test),-1))

这里的X_train和X_test的维度分别为(60000L,784L),(10000L,784L)

这里进行了归一化,将所有的数值除上255.

设定编码的维数与输入数据的维数

encoding_dim = 32

input_img = Input(shape=(784,))

构建模型

encoded = Dense(encoding_dim,activation='relu')(input_img)
decoded = Dense(784,activation='relu')(encoded)
 
autoencoder = Model(inputs = input_img,outputs=decoded)
encoder = Model(inputs=input_img,outputs=encoded)
 
encoded_input = Input(shape=(encoding_dim,))
decoder_layer = autoencoder.layers[-1]
deconder = Model(inputs=encoded_input,outputs = decoder_layer(encoded_input))

模型编译

autoencoder.compile(optimizer='adadelta',loss='binary_crossentropy')

模型训练

autoencoder.fit(X_train,X_train,epochs=50,batch_size=256,shuffle=True,validation_data=(X_test,X_test))

预测

encoded_imgs = encoder.predict(X_test)

decoded_imgs = deconder.predict(encoded_imgs)

数据可视化

n = 10
for i in range(n):
 ax = plt.subplot(2,n,i+1)
 plt.imshow(X_test[i].reshape(28,28))
 plt.gray()
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
 
 ax = plt.subplot(2,n,i+1+n)
 plt.imshow(decoded_imgs[i].reshape(28,28))
 plt.gray()
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
plt.show()

完成代码

from keras.layers import Input,Dense
from keras.models import Model 
from keras.datasets import mnist
import numpy as np 
import matplotlib.pyplot as plt 
 
(X_train,_),(X_test,_) = mnist.load_data()
 
X_train = X_train.astype('float32')/255.
X_test = X_test.astype('float32')/255.
X_train = X_train.reshape((len(X_train),-1))
X_test = X_test.reshape((len(X_test),-1))
 
encoding_dim = 32
input_img = Input(shape=(784,))
 
encoded = Dense(encoding_dim,activation='relu')(input_img)
decoded = Dense(784,activation='relu')(encoded)
 
autoencoder = Model(inputs = input_img,outputs=decoded)
encoder = Model(inputs=input_img,outputs=encoded)
 
encoded_input = Input(shape=(encoding_dim,))
decoder_layer = autoencoder.layers[-1]
deconder = Model(inputs=encoded_input,outputs = decoder_layer(encoded_input))
 
autoencoder.compile(optimizer='adadelta',loss='binary_crossentropy')
autoencoder.fit(X_train,X_train,epochs=50,batch_size=256,shuffle=True,validation_data=(X_test,X_test))
 
encoded_imgs = encoder.predict(X_test)
decoded_imgs = deconder.predict(encoded_imgs)
 
##via
n = 10
for i in range(n):
 ax = plt.subplot(2,n,i+1)
 plt.imshow(X_test[i].reshape(28,28))
 plt.gray()
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
 
 ax = plt.subplot(2,n,i+1+n)
 plt.imshow(decoded_imgs[i].reshape(28,28))
 plt.gray()
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
plt.show()

以上这篇keras自动编码器实现系列之卷积自动编码器操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python算法学习之基数排序实例
Dec 18 Python
Python的Flask框架应用程序实现使用QQ账号登录的方法
Jun 07 Python
tensorflow TFRecords文件的生成和读取的方法
Feb 06 Python
基于pip install django失败时的解决方法
Jun 12 Python
Python+Pandas 获取数据库并加入DataFrame的实例
Jul 25 Python
Windows下python3.7安装教程
Jul 31 Python
python 弹窗提示警告框MessageBox的实例
Jun 18 Python
python中通过selenium简单操作及元素定位知识点总结
Sep 10 Python
python列表插入append(), extend(), insert()用法详解
Sep 14 Python
Python调用Windows API函数编写录音机和音乐播放器功能
Jan 05 Python
python代码中怎么换行
Jun 17 Python
Python 读取千万级数据自动写入 MySQL 数据库
Jun 28 Python
Python with语句用法原理详解
Jul 03 #Python
Keras搭建自编码器操作
Jul 03 #Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 #Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 #Python
keras topN显示,自编写代码案例
Jul 03 #Python
python如何使用代码运行助手
Jul 03 #Python
Python 3.10 的首个 PEP 诞生,内置类型 zip() 迎来新特性(推荐)
Jul 03 #Python
You might like
利用php的ob缓存机制实现页面静态化方法
2017/07/09 PHP
PHP小白必须要知道的php基础知识(超实用)
2017/10/10 PHP
PHP 范围解析操作符(::)用法分析【访问静态成员和类常量】
2020/04/14 PHP
javascript 禁止复制网页
2009/06/11 Javascript
一个可绑定数据源的jQuery数据表格插件
2010/07/17 Javascript
jquery实现通用版鼠标经过淡入淡出效果
2014/06/15 Javascript
jQuery中$.each使用详解
2015/01/29 Javascript
jQuery实现ajax调用WCF服务的方法(附带demo下载)
2015/12/04 Javascript
浅析BootStrap Treeview的简单使用
2016/10/12 Javascript
微信小程序 选择器(时间,日期,地区)实例详解
2016/11/16 Javascript
Bootstrap笔记—折叠实例代码
2017/03/13 Javascript
详解JS中的柯里化(currying)
2017/08/17 Javascript
React学习笔记之高阶组件应用
2018/06/02 Javascript
纯JS实现的读取excel文件内容功能示例【支持所有浏览器】
2018/06/23 Javascript
vue input实现点击按钮文字增删功能示例
2019/01/29 Javascript
解决在layer.open中使用时间控件laydate失败的问题
2019/09/11 Javascript
Vue快速实现通用表单验证功能
2019/12/05 Javascript
小程序登录之支付宝授权的实现示例
2019/12/13 Javascript
[44:26]DOTA2上海特级锦标赛主赛事日 - 2 胜者组第一轮#4EG VS Fnatic第二局
2016/03/03 DOTA
[51:06]2018DOTA2亚洲邀请赛3月29日 小组赛A组 KG VS Liquid
2018/03/30 DOTA
初步剖析C语言编程中的结构体
2016/01/16 Python
Django rest framework基本介绍与代码示例
2018/01/26 Python
wxpython+pymysql实现用户登陆功能
2019/11/19 Python
pytorch 实现将自己的图片数据处理成可以训练的图片类型
2020/01/08 Python
ansible动态Inventory主机清单配置遇到的坑
2020/01/19 Python
基于Python把网站域名解析成ip地址
2020/05/25 Python
trivago美国:全球最大的酒店价格比较网站
2018/01/18 全球购物
化学实验员岗位职责
2013/12/28 职场文书
师德标兵事迹材料
2014/12/19 职场文书
活动费用申请报告
2015/05/15 职场文书
实习单位鉴定意见
2015/06/04 职场文书
2016春季运动会通讯稿
2015/07/18 职场文书
大学开学感言
2015/08/01 职场文书
2016党员干部政治学习心得体会
2016/01/23 职场文书
2019个人工作计划书的格式及范文!
2019/07/04 职场文书
python使用shell脚本创建kafka连接器
2022/04/29 Python