Tensorflow分类器项目自定义数据读入的实现


Posted in Python onFebruary 05, 2019

在照着Tensorflow官网的demo敲了一遍分类器项目的代码后,运行倒是成功了,结果也不错。但是最终还是要训练自己的数据,所以尝试准备加载自定义的数据,然而demo中只是出现了fashion_mnist.load_data()并没有详细的读取过程,随后我又找了些资料,把读取的过程记录在这里。

首先提一下需要用到的模块:

import os
import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

图片分类器项目,首先确定你要处理的图片分辨率将是多少,这里的例子为30像素:

IMG_SIZE_X = 30
IMG_SIZE_Y = 30

其次确定你图片的方式目录:

image_path = r'D:\Projects\ImageClassifier\data\set'
path = ".\data"
# 你也可以使用相对路径的方式
# image_path =os.path.join(path, "set")

目录下的结构如下:

Tensorflow分类器项目自定义数据读入的实现

相应的label.txt如下:

动漫
风景
美女
物语
樱花

接下来是接在labels.txt,如下:

label_name = "labels.txt"
label_path = os.path.join(path, label_name)
class_names = np.loadtxt(label_path, type(""))

这里简便起见,直接利用了numpy的loadtxt函数直接加载。

之后便是正式处理图片数据了,注释就写在里面了:

re_load = False
re_build = False
# re_load = True
re_build = True

data_name = "data.npz"
data_path = os.path.join(path, data_name)
model_name = "model.h5"
model_path = os.path.join(path, model_name)

count = 0

# 这里判断是否存在序列化之后的数据,re_load是一个开关,是否强制重新处理,测试用,可以去除。
if not os.path.exists(data_path) or re_load:
  labels = []
  images = []
  print('Handle images')
  # 由于label.txt是和图片防止目录的分类目录一一对应的,即每个子目录的目录名就是labels.txt里的一个label,所以这里可以通过读取class_names的每一项去拼接path后读取
  for index, name in enumerate(class_names):
    # 这里是拼接后的子目录path
    classpath = os.path.join(image_path, name)
    # 先判断一下是否是目录
    if not os.path.isdir(classpath):
      continue
    # limit是测试时候用的这里可以去除
    limit = 0
    for image_name in os.listdir(classpath):
      if limit >= max_size:
        break
      # 这里是拼接后的待处理的图片path
      imagepath = os.path.join(classpath, image_name)
      count = count + 1
      limit = limit + 1
      # 利用Image打开图片
      img = Image.open(imagepath)
      # 缩放到你最初确定要处理的图片分辨率大小
      img = img.resize((IMG_SIZE_X, IMG_SIZE_Y))
      # 转为灰度图片,这里彩色通道会干扰结果,并且会加大计算量
      img = img.convert("L")
      # 转为numpy数组
      img = np.array(img)
      # 由(30,30)转为(1,30,30)(即`channels_first`),当然你也可以转换为(30,30,1)(即`channels_last`)但为了之后预览处理后的图片方便这里采用了(1,30,30)的格式存放
      img = np.reshape(img, (1, IMG_SIZE_X, IMG_SIZE_Y))
      # 这里利用循环生成labels数据,其中存放的实际是class_names中对应元素的索引
      labels.append([index])
      # 添加到images中,最后统一处理
      images.append(img)
      # 循环中一些状态的输出,可以去除
      print("{} class: {} {} limit: {} {}"
         .format(count, index + 1, class_names[index], limit, imagepath))
  # 最后一次性将images和labels都转换成numpy数组
  npy_data = np.array(images)
  npy_labels = np.array(labels)
  # 处理数据只需要一次,所以我们选择在这里利用numpy自带的方法将处理之后的数据序列化存储
  np.savez(data_path, x=npy_data, y=npy_labels)
  print("Save images by npz")
else:
  # 如果存在序列化号的数据,便直接读取,提高速度
  npy_data = np.load(data_path)["x"]
  npy_labels = np.load(data_path)["y"]
  print("Load images by npz")
image_data = npy_data
labels_data = npy_labels

到了这里原始数据的加工预处理便已经完成,只需要最后一步,就和demo中fashion_mnist.load_data()返回的结果一样了。代码如下:

# 最后一步就是将原始数据分成训练数据和测试数据
train_images, test_images, train_labels, test_labels = \
  train_test_split(image_data, labels_data, test_size=0.2, random_state=6)

这里将相关信息打印的方法也附上:

print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Image Data", image_data.shape))
print("%-28s %-s" % ("Labels Data", labels_data.shape))
print("=================================================================")

print('Split train and test data,p=%')
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Train Images", train_images.shape))
print("%-28s %-s" % ("Test Images", test_images.shape))
print("%-28s %-s" % ("Train Labels", train_labels.shape))
print("%-28s %-s" % ("Test Labels", test_labels.shape))
print("=================================================================")

之后别忘了归一化哟:

print("Normalize images")
train_images = train_images / 255.0
test_images = test_images / 255.0

最后附上读取自定义数据的完整代码:

import os

import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.layers import *
from keras.models import *
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
re_load = False
re_build = False
# re_load = True
re_build = True
epochs = 50
batch_size = 5
count = 0
max_size = 2000000000

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用优化器来提升Python程序的执行效率的教程
Apr 02 Python
Python编程中的文件读写及相关的文件对象方法讲解
Jan 19 Python
Python 中 list 的各项操作技巧
Apr 13 Python
利用matplotlib+numpy绘制多种绘图的方法实例
May 03 Python
vue.js实现输入框输入值内容实时响应变化示例
Jul 07 Python
浅谈dataframe中更改列属性的方法
Jul 10 Python
Python 移动光标位置的方法
Jan 20 Python
Python2 Selenium元素定位的实现(8种)
Feb 25 Python
numpy.random模块用法总结
May 27 Python
python 数据生成excel导出(xlwt,wlsxwrite)代码实例
Aug 23 Python
Python中Flask-RESTful编写API接口(小白入门)
Dec 11 Python
Keras中的多分类损失函数用法categorical_crossentropy
Jun 11 Python
在Python 字典中一键对应多个值的实例
Feb 03 #Python
Django csrf 两种方法设置form的实例
Feb 03 #Python
解决django前后端分离csrf验证的问题
Feb 03 #Python
Python利用heapq实现一个优先级队列的方法
Feb 03 #Python
对Python3中dict.keys()转换成list类型的方法详解
Feb 03 #Python
对python中字典keys,values,items的使用详解
Feb 03 #Python
python生成带有表格的图片实例
Feb 03 #Python
You might like
php 什么是PEAR?
2009/03/19 PHP
PHP中获取变量的变量名的一段代码的bug分析
2011/07/07 PHP
WordPress中查询文章的循环Loop结构及用法分析
2015/12/17 PHP
php数据库的增删改查 php与javascript之间的交互
2017/08/31 PHP
Sample script that displays all of the users in a given SQL Server DB
2007/06/16 Javascript
基于jQuery的投票系统显示结果插件
2011/08/12 Javascript
IE下JS读取xml文件示例代码
2013/08/05 Javascript
JQuery Tips相关(1)----关于$.Ready()
2014/08/14 Javascript
avascript中的自执行匿名函数应用示例
2014/09/15 Javascript
javascript定义变量时加var与不加var的区别
2014/12/22 Javascript
自己编写的支持Ajax验证的JS表单验证插件
2015/05/15 Javascript
浅析JavaScript中的变量复制、参数传递和作用域链
2016/01/13 Javascript
JavaScript常用代码书写规范的超全面总结
2016/09/11 Javascript
javascript 单例模式详解及简单实例
2017/02/14 Javascript
Javascript实现一个简单的输入关键字添加标签效果实例
2017/06/01 Javascript
JS对象序列化成json数据和json数据转化为JS对象的代码
2017/08/23 Javascript
解决vue-cli创建项目的loader问题
2018/03/13 Javascript
基于vue-cli npm run build之后vendor.js文件过大的解决方法
2018/09/27 Javascript
mock.js模拟数据实现前后端分离
2019/07/24 Javascript
浅析webpack-bundle-analyzer在vue-cli3中的使用
2019/10/23 Javascript
[06:01]刀塔次级联赛top10第一期
2014/11/07 DOTA
[01:08:43]DOTA2-DPC中国联赛定级赛 Phoenix vs DLG BO3第一场 1月9日
2021/03/11 DOTA
Python HTMLParser模块解析html获取url实例
2015/04/08 Python
简单的Apache+FastCGI+Django配置指南
2015/07/22 Python
Python根据成绩分析系统浅析
2019/02/11 Python
Windows系统Python直接调用C++ DLL的方法
2019/08/01 Python
Python 将json序列化后的字符串转换成字典(推荐)
2020/01/06 Python
基于Python第三方插件实现西游记章节标注汉语拼音的方法
2020/05/22 Python
css3实现文字首尾衔接跑马灯的示例代码
2020/10/16 HTML / CSS
阿波罗盒子:Apollo Box
2017/08/14 全球购物
为什么要做架构设计
2015/07/08 面试题
宿舍违规用电检讨书
2014/02/16 职场文书
银行求职自荐信
2014/06/30 职场文书
计算机实训报告总结
2014/11/05 职场文书
MySQL索引知识的一些小妙招总结
2021/05/10 MySQL
mysql timestamp比较查询遇到的坑及解决
2021/11/27 MySQL