keras自动编码器实现系列之卷积自动编码器操作


Posted in Python onJuly 03, 2020

图片的自动编码很容易就想到用卷积神经网络做为编码-解码器。在实际的操作中,

也经常使用卷积自动编码器去解决图像编码问题,而且非常有效。

下面通过**keras**完成简单的卷积自动编码。 编码器有堆叠的卷积层和池化层(max pooling用于空间降采样)组成。 对应的解码器由卷积层和上采样层组成。

@requires_authorization
# -*- coding:utf-8 -*-

from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model
from keras import backend as K
import os

## 网络结构 ##
input_img = Input(shape=(28,28,1)) # Tensorflow后端, 注意要用channel_last
# 编码器部分
x = Conv2D(16, (3,3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2,2), padding='same')(x)
x = Conv2D(8,(3,3), activation='relu', padding='same')(x)
x = MaxPooling2D((2,2), padding='same')(x)
x = Conv2D(8, (3,3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2,2), padding='same')(x)

# 解码器部分
x = Conv2D(8, (3,3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(8, (3,3), activation='relu', padding='same')(x) 
x = UpSampling2D((2, 2))(x)
x = Conv2D(16, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

# 得到编码层的输出
encoder_model = Model(inputs=autoencoder.input, outputs=autoencoder.get_layer('encoder_out').output)

## 导入数据, 使用常用的手写识别数据集
def load_mnist(dataset_name):
'''
load the data
'''
  data_dir = os.path.join("./data", dataset_name)
  f = np.load(os.path.join(data_dir, 'mnist.npz'))
  train_data = f['train'].T
  trX = train_data.reshape((-1, 28, 28, 1)).astype(np.float32)
  trY = f['train_labels'][-1].astype(np.float32)
  test_data = f['test'].T
  teX = test_data.reshape((-1, 28, 28, 1)).astype(np.float32)
  teY = f['test_labels'][-1].astype(np.float32)

  # one-hot 
  # y_vec = np.zeros((len(y), 10), dtype=np.float32)
  # for i, label in enumerate(y):
  #   y_vec[i, y[i]] = 1
  # keras.utils里带的有one-hot的函数, 就直接用那个了
  return trX / 255., trY, teX/255., teY

# 开始导入数据
x_train, _ , x_test, _= load_mnist('mnist')

# 可视化训练结果, 我们打开终端, 使用tensorboard
# tensorboard --logdir=/tmp/autoencoder # 注意这里是打开一个终端, 在终端里运行

# 训练模型, 并且在callbacks中使用tensorBoard实例, 写入训练日志 http://0.0.0.0:6006
from keras.callbacks import TensorBoard
autoencoder.fit(x_train, x_train,
        epochs=50,
        batch_size=128,
        shuffle=True,
        validation_data=(x_test, x_test),
        callbacks=[TensorBoard(log_dir='/tmp/autoencoder')])

# 重建图片
import matplotlib.pyplot as plt 
decoded_imgs = autoencoder.predict(x_test)
encoded_imgs = encoder_model.predict(x_test)
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
  k = i + 1
  # 画原始图片
  ax = plt.subplot(2, n, k)
  plt.imshow(x_test[k].reshape(28, 28))
  plt.gray()
  ax.get_xaxis().set_visible(False)
  # 画重建图片
  ax = plt.subplot(2, n, k + n)
  plt.imshow(decoded_imgs[i].reshape(28, 28))
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

# 编码得到的特征
n = 10
plt.figure(figsize=(20, 8))
for i in range(n):
  k = i + 1
  ax = plt.subplot(1, n, k)
  plt.imshow(encoded[k].reshape(4, 4 * 8).T)
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

补充知识:keras搬砖系列-单层卷积自编码器

考试成绩出来了,竟然有一门出奇的差,只是有点意外。

觉得应该不错的,竟然考差了,它估计写了个随机数吧。

头文件

from keras.layers import Input,Dense
from keras.models import Model 
from keras.datasets import mnist
import numpy as np 
import matplotlib.pyplot as plt

导入数据

(X_train,_),(X_test,_) = mnist.load_data()
 
X_train = X_train.astype('float32')/255.
X_test = X_test.astype('float32')/255.
X_train = X_train.reshape((len(X_train),-1))
X_test = X_test.reshape((len(X_test),-1))

这里的X_train和X_test的维度分别为(60000L,784L),(10000L,784L)

这里进行了归一化,将所有的数值除上255.

设定编码的维数与输入数据的维数

encoding_dim = 32

input_img = Input(shape=(784,))

构建模型

encoded = Dense(encoding_dim,activation='relu')(input_img)
decoded = Dense(784,activation='relu')(encoded)
 
autoencoder = Model(inputs = input_img,outputs=decoded)
encoder = Model(inputs=input_img,outputs=encoded)
 
encoded_input = Input(shape=(encoding_dim,))
decoder_layer = autoencoder.layers[-1]
deconder = Model(inputs=encoded_input,outputs = decoder_layer(encoded_input))

模型编译

autoencoder.compile(optimizer='adadelta',loss='binary_crossentropy')

模型训练

autoencoder.fit(X_train,X_train,epochs=50,batch_size=256,shuffle=True,validation_data=(X_test,X_test))

预测

encoded_imgs = encoder.predict(X_test)

decoded_imgs = deconder.predict(encoded_imgs)

数据可视化

n = 10
for i in range(n):
 ax = plt.subplot(2,n,i+1)
 plt.imshow(X_test[i].reshape(28,28))
 plt.gray()
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
 
 ax = plt.subplot(2,n,i+1+n)
 plt.imshow(decoded_imgs[i].reshape(28,28))
 plt.gray()
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
plt.show()

完成代码

from keras.layers import Input,Dense
from keras.models import Model 
from keras.datasets import mnist
import numpy as np 
import matplotlib.pyplot as plt 
 
(X_train,_),(X_test,_) = mnist.load_data()
 
X_train = X_train.astype('float32')/255.
X_test = X_test.astype('float32')/255.
X_train = X_train.reshape((len(X_train),-1))
X_test = X_test.reshape((len(X_test),-1))
 
encoding_dim = 32
input_img = Input(shape=(784,))
 
encoded = Dense(encoding_dim,activation='relu')(input_img)
decoded = Dense(784,activation='relu')(encoded)
 
autoencoder = Model(inputs = input_img,outputs=decoded)
encoder = Model(inputs=input_img,outputs=encoded)
 
encoded_input = Input(shape=(encoding_dim,))
decoder_layer = autoencoder.layers[-1]
deconder = Model(inputs=encoded_input,outputs = decoder_layer(encoded_input))
 
autoencoder.compile(optimizer='adadelta',loss='binary_crossentropy')
autoencoder.fit(X_train,X_train,epochs=50,batch_size=256,shuffle=True,validation_data=(X_test,X_test))
 
encoded_imgs = encoder.predict(X_test)
decoded_imgs = deconder.predict(encoded_imgs)
 
##via
n = 10
for i in range(n):
 ax = plt.subplot(2,n,i+1)
 plt.imshow(X_test[i].reshape(28,28))
 plt.gray()
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
 
 ax = plt.subplot(2,n,i+1+n)
 plt.imshow(decoded_imgs[i].reshape(28,28))
 plt.gray()
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
plt.show()

以上这篇keras自动编码器实现系列之卷积自动编码器操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python ORM框架SQLAlchemy学习笔记之数据添加和事务回滚介绍
Jun 10 Python
分析在Python中何种情况下需要使用断言
Apr 01 Python
详解python里的命名规范
Jul 16 Python
python分布式计算dispy的使用详解
Dec 22 Python
pytorch 图像中的数据预处理和批标准化实例
Jan 15 Python
python 中的[:-1]和[::-1]的具体使用
Feb 13 Python
使用python实现飞机大战游戏
Mar 23 Python
jupyter 导入csv文件方式
Apr 21 Python
django在开发中取消外键约束的实现
May 20 Python
Pycharm生成可执行文件.exe的实现方法
Jun 02 Python
python3爬虫GIL修改多线程实例讲解
Nov 24 Python
Python移位密码、仿射变换解密实例代码
Jun 27 Python
Python with语句用法原理详解
Jul 03 #Python
Keras搭建自编码器操作
Jul 03 #Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 #Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 #Python
keras topN显示,自编写代码案例
Jul 03 #Python
python如何使用代码运行助手
Jul 03 #Python
Python 3.10 的首个 PEP 诞生,内置类型 zip() 迎来新特性(推荐)
Jul 03 #Python
You might like
PHP中include()与require()的区别说明
2010/03/10 PHP
php实现批量压缩图片文件大小的脚本
2014/07/04 PHP
thinkPHP多语言切换设置方法详解
2016/11/11 PHP
PHP+JS实现的商品秒杀倒计时用法示例
2016/11/15 PHP
php+Memcached实现简单留言板功能示例
2017/02/15 PHP
Laravel框架查询构造器简单示例
2019/05/08 PHP
jQuery Tools Dateinput使用介绍
2012/07/14 Javascript
Javascript控制页面链接在新窗口打开具体方法
2013/08/16 Javascript
JSON+HTML实现国家省市联动选择效果
2014/05/18 Javascript
PHP+jQuery+Ajax实现多图片上传效果
2015/03/14 Javascript
jquery实现漫天雪花飞舞的圣诞祝福雪花效果代码分享
2015/08/20 Javascript
JavaScript File分段上传
2016/03/10 Javascript
JS hashMap实例详解
2016/05/26 Javascript
JS动态给对象添加事件的简单方法
2016/07/19 Javascript
微信小程序购物商城系统开发系列-工具篇的介绍
2016/11/21 Javascript
JS经典正则表达式笔试题汇总
2016/12/15 Javascript
微信小程序去哪里找 小程序到底如何使用(附小程序名单)
2017/01/09 Javascript
jquery Ajax实现Select动态添加数据
2017/06/08 jQuery
详解Vue单元测试Karma+Mocha学习笔记
2018/01/31 Javascript
微信小程序首页的分类功能和搜索功能的实现思路及代码详解
2018/09/11 Javascript
实现elementUI表单的全局验证的方法步骤
2019/04/29 Javascript
简单通过settimeout看javascript的运行机制
2019/05/10 Javascript
VUE 自定义组件模板的方法详解
2019/08/30 Javascript
[00:52]DOTA2齐天大圣预告片
2016/08/13 DOTA
Python 调用PIL库失败的解决方法
2019/01/08 Python
基于python的selenium两种文件上传操作实现详解
2019/09/19 Python
基于python+selenium自动健康打卡的实现代码
2021/01/13 Python
TALLY WEiJL法国网上商店:服装、时装及配饰
2019/08/31 全球购物
NULL是什么,它是怎么定义的
2015/05/09 面试题
项目经理岗位职责
2013/11/11 职场文书
个人委托书怎么写
2014/09/17 职场文书
2014年科室工作总结
2014/11/20 职场文书
诚信承诺书
2015/01/19 职场文书
幼儿园老师新年寄语
2015/08/17 职场文书
万能密码的SQL注入漏洞其PHP环境搭建及防御手段
2021/09/04 SQL Server
PyTorch中的torch.cat简单介绍
2022/03/17 Python