Tensorflow分类器项目自定义数据读入的实现


Posted in Python onFebruary 05, 2019

在照着Tensorflow官网的demo敲了一遍分类器项目的代码后,运行倒是成功了,结果也不错。但是最终还是要训练自己的数据,所以尝试准备加载自定义的数据,然而demo中只是出现了fashion_mnist.load_data()并没有详细的读取过程,随后我又找了些资料,把读取的过程记录在这里。

首先提一下需要用到的模块:

import os
import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

图片分类器项目,首先确定你要处理的图片分辨率将是多少,这里的例子为30像素:

IMG_SIZE_X = 30
IMG_SIZE_Y = 30

其次确定你图片的方式目录:

image_path = r'D:\Projects\ImageClassifier\data\set'
path = ".\data"
# 你也可以使用相对路径的方式
# image_path =os.path.join(path, "set")

目录下的结构如下:

Tensorflow分类器项目自定义数据读入的实现

相应的label.txt如下:

动漫
风景
美女
物语
樱花

接下来是接在labels.txt,如下:

label_name = "labels.txt"
label_path = os.path.join(path, label_name)
class_names = np.loadtxt(label_path, type(""))

这里简便起见,直接利用了numpy的loadtxt函数直接加载。

之后便是正式处理图片数据了,注释就写在里面了:

re_load = False
re_build = False
# re_load = True
re_build = True

data_name = "data.npz"
data_path = os.path.join(path, data_name)
model_name = "model.h5"
model_path = os.path.join(path, model_name)

count = 0

# 这里判断是否存在序列化之后的数据,re_load是一个开关,是否强制重新处理,测试用,可以去除。
if not os.path.exists(data_path) or re_load:
  labels = []
  images = []
  print('Handle images')
  # 由于label.txt是和图片防止目录的分类目录一一对应的,即每个子目录的目录名就是labels.txt里的一个label,所以这里可以通过读取class_names的每一项去拼接path后读取
  for index, name in enumerate(class_names):
    # 这里是拼接后的子目录path
    classpath = os.path.join(image_path, name)
    # 先判断一下是否是目录
    if not os.path.isdir(classpath):
      continue
    # limit是测试时候用的这里可以去除
    limit = 0
    for image_name in os.listdir(classpath):
      if limit >= max_size:
        break
      # 这里是拼接后的待处理的图片path
      imagepath = os.path.join(classpath, image_name)
      count = count + 1
      limit = limit + 1
      # 利用Image打开图片
      img = Image.open(imagepath)
      # 缩放到你最初确定要处理的图片分辨率大小
      img = img.resize((IMG_SIZE_X, IMG_SIZE_Y))
      # 转为灰度图片,这里彩色通道会干扰结果,并且会加大计算量
      img = img.convert("L")
      # 转为numpy数组
      img = np.array(img)
      # 由(30,30)转为(1,30,30)(即`channels_first`),当然你也可以转换为(30,30,1)(即`channels_last`)但为了之后预览处理后的图片方便这里采用了(1,30,30)的格式存放
      img = np.reshape(img, (1, IMG_SIZE_X, IMG_SIZE_Y))
      # 这里利用循环生成labels数据,其中存放的实际是class_names中对应元素的索引
      labels.append([index])
      # 添加到images中,最后统一处理
      images.append(img)
      # 循环中一些状态的输出,可以去除
      print("{} class: {} {} limit: {} {}"
         .format(count, index + 1, class_names[index], limit, imagepath))
  # 最后一次性将images和labels都转换成numpy数组
  npy_data = np.array(images)
  npy_labels = np.array(labels)
  # 处理数据只需要一次,所以我们选择在这里利用numpy自带的方法将处理之后的数据序列化存储
  np.savez(data_path, x=npy_data, y=npy_labels)
  print("Save images by npz")
else:
  # 如果存在序列化号的数据,便直接读取,提高速度
  npy_data = np.load(data_path)["x"]
  npy_labels = np.load(data_path)["y"]
  print("Load images by npz")
image_data = npy_data
labels_data = npy_labels

到了这里原始数据的加工预处理便已经完成,只需要最后一步,就和demo中fashion_mnist.load_data()返回的结果一样了。代码如下:

# 最后一步就是将原始数据分成训练数据和测试数据
train_images, test_images, train_labels, test_labels = \
  train_test_split(image_data, labels_data, test_size=0.2, random_state=6)

这里将相关信息打印的方法也附上:

print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Image Data", image_data.shape))
print("%-28s %-s" % ("Labels Data", labels_data.shape))
print("=================================================================")

print('Split train and test data,p=%')
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Train Images", train_images.shape))
print("%-28s %-s" % ("Test Images", test_images.shape))
print("%-28s %-s" % ("Train Labels", train_labels.shape))
print("%-28s %-s" % ("Test Labels", test_labels.shape))
print("=================================================================")

之后别忘了归一化哟:

print("Normalize images")
train_images = train_images / 255.0
test_images = test_images / 255.0

最后附上读取自定义数据的完整代码:

import os

import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.layers import *
from keras.models import *
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
re_load = False
re_build = False
# re_load = True
re_build = True
epochs = 50
batch_size = 5
count = 0
max_size = 2000000000

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python读写Excel文件的实例
Nov 01 Python
17个Python小技巧分享
Jan 23 Python
使用Python的PEAK来适配协议的教程
Apr 14 Python
利用Python破解验证码实例详解
Dec 08 Python
一个基于flask的web应用诞生 组织结构调整(7)
Apr 11 Python
python实现QQ邮箱/163邮箱的邮件发送
Jan 22 Python
python 函数中的内置函数及用法详解
Jul 02 Python
pandas分区间,算频率的实例
Jul 04 Python
一篇文章搞定Python操作文件与目录
Aug 13 Python
Python: 传递列表副本方式
Dec 19 Python
在pytorch 中计算精度、回归率、F1 score等指标的实例
Jan 18 Python
Python 通过爬虫实现GitHub网页的模拟登录的示例代码
Aug 17 Python
在Python 字典中一键对应多个值的实例
Feb 03 #Python
Django csrf 两种方法设置form的实例
Feb 03 #Python
解决django前后端分离csrf验证的问题
Feb 03 #Python
Python利用heapq实现一个优先级队列的方法
Feb 03 #Python
对Python3中dict.keys()转换成list类型的方法详解
Feb 03 #Python
对python中字典keys,values,items的使用详解
Feb 03 #Python
python生成带有表格的图片实例
Feb 03 #Python
You might like
剧场版动画《PSYCHO-PASS 3 FIRST INSPECTOR》3月27日日本上映!
2020/03/06 日漫
[原创]PHP实现字节数Byte转换为KB、MB、GB、TB的方法
2017/08/31 PHP
Laravel中unique和exists验证规则的优化详解
2018/01/28 PHP
thinkphp框架实现路由重定义简化url访问地址的方法分析
2020/04/04 PHP
贴一个在Mozilla中常用的Javascript代码
2007/01/09 Javascript
jquery焦点图片切换(数字标注/手动/自动播放/横向滚动)
2013/01/24 Javascript
原生js的弹出层且其内的窗口居中
2014/05/14 Javascript
js模拟淘宝网的多级选择菜单实现方法
2015/08/18 Javascript
仅30行代码实现Javascript中的MVC
2016/02/15 Javascript
JavaScript跨域调用基于JSON的RESTful API
2016/07/09 Javascript
微信小程序 wx.uploadFile在安卓手机上面the same task is working问题解决
2016/12/14 Javascript
JavaScript中常见的八个陷阱总结
2017/06/28 Javascript
jQuery表单设置值的方法
2017/06/30 jQuery
ionic App问题总结系列之ionic点击系统返回键退出App
2017/08/19 Javascript
Koa代理Http请求的示例代码
2018/10/10 Javascript
一篇文章,教你学会Vue CLI 插件开发
2019/04/17 Javascript
Vue-input框checkbox强制刷新问题
2019/04/18 Javascript
JavaScript如何获取一个元素的样式信息
2019/07/29 Javascript
JS实现水平遍历和嵌套递归操作示例
2019/08/15 Javascript
Javascript实现鼠标移入方向感知
2020/06/24 Javascript
Vue(定时器)解决mounted不能获取到data中的数据问题
2020/07/30 Javascript
详解Python的Twisted框架中reactor事件管理器的用法
2016/05/25 Python
Python中enumerate函数代码解析
2017/10/31 Python
对numpy中shape的深入理解
2018/06/15 Python
78行Python代码实现现微信撤回消息功能
2018/07/26 Python
详解pytorch 0.4.0迁移指南
2019/06/16 Python
Pycharm 2020年最新激活码(亲测有效)
2020/09/18 Python
Keras自定义实现带masking的meanpooling层方式
2020/06/16 Python
python网络爬虫实现发送短信验证码的方法
2021/02/25 Python
Alba Moda德国网上商店:意大利时尚女装销售
2016/11/14 全球购物
高性能装备提升营地:Kammok
2019/02/27 全球购物
init进程的作用
2015/08/20 面试题
机械电子工程专业自荐书
2014/06/10 职场文书
安全知识竞赛主持词
2015/06/30 职场文书
2015年暑期社会实践报告
2015/07/13 职场文书
动漫APP软件排行榜前十名,半次元上榜,第一款由腾讯公司推出
2022/03/18 杂记