Tensorflow分类器项目自定义数据读入的实现


Posted in Python onFebruary 05, 2019

在照着Tensorflow官网的demo敲了一遍分类器项目的代码后,运行倒是成功了,结果也不错。但是最终还是要训练自己的数据,所以尝试准备加载自定义的数据,然而demo中只是出现了fashion_mnist.load_data()并没有详细的读取过程,随后我又找了些资料,把读取的过程记录在这里。

首先提一下需要用到的模块:

import os
import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

图片分类器项目,首先确定你要处理的图片分辨率将是多少,这里的例子为30像素:

IMG_SIZE_X = 30
IMG_SIZE_Y = 30

其次确定你图片的方式目录:

image_path = r'D:\Projects\ImageClassifier\data\set'
path = ".\data"
# 你也可以使用相对路径的方式
# image_path =os.path.join(path, "set")

目录下的结构如下:

Tensorflow分类器项目自定义数据读入的实现

相应的label.txt如下:

动漫
风景
美女
物语
樱花

接下来是接在labels.txt,如下:

label_name = "labels.txt"
label_path = os.path.join(path, label_name)
class_names = np.loadtxt(label_path, type(""))

这里简便起见,直接利用了numpy的loadtxt函数直接加载。

之后便是正式处理图片数据了,注释就写在里面了:

re_load = False
re_build = False
# re_load = True
re_build = True

data_name = "data.npz"
data_path = os.path.join(path, data_name)
model_name = "model.h5"
model_path = os.path.join(path, model_name)

count = 0

# 这里判断是否存在序列化之后的数据,re_load是一个开关,是否强制重新处理,测试用,可以去除。
if not os.path.exists(data_path) or re_load:
  labels = []
  images = []
  print('Handle images')
  # 由于label.txt是和图片防止目录的分类目录一一对应的,即每个子目录的目录名就是labels.txt里的一个label,所以这里可以通过读取class_names的每一项去拼接path后读取
  for index, name in enumerate(class_names):
    # 这里是拼接后的子目录path
    classpath = os.path.join(image_path, name)
    # 先判断一下是否是目录
    if not os.path.isdir(classpath):
      continue
    # limit是测试时候用的这里可以去除
    limit = 0
    for image_name in os.listdir(classpath):
      if limit >= max_size:
        break
      # 这里是拼接后的待处理的图片path
      imagepath = os.path.join(classpath, image_name)
      count = count + 1
      limit = limit + 1
      # 利用Image打开图片
      img = Image.open(imagepath)
      # 缩放到你最初确定要处理的图片分辨率大小
      img = img.resize((IMG_SIZE_X, IMG_SIZE_Y))
      # 转为灰度图片,这里彩色通道会干扰结果,并且会加大计算量
      img = img.convert("L")
      # 转为numpy数组
      img = np.array(img)
      # 由(30,30)转为(1,30,30)(即`channels_first`),当然你也可以转换为(30,30,1)(即`channels_last`)但为了之后预览处理后的图片方便这里采用了(1,30,30)的格式存放
      img = np.reshape(img, (1, IMG_SIZE_X, IMG_SIZE_Y))
      # 这里利用循环生成labels数据,其中存放的实际是class_names中对应元素的索引
      labels.append([index])
      # 添加到images中,最后统一处理
      images.append(img)
      # 循环中一些状态的输出,可以去除
      print("{} class: {} {} limit: {} {}"
         .format(count, index + 1, class_names[index], limit, imagepath))
  # 最后一次性将images和labels都转换成numpy数组
  npy_data = np.array(images)
  npy_labels = np.array(labels)
  # 处理数据只需要一次,所以我们选择在这里利用numpy自带的方法将处理之后的数据序列化存储
  np.savez(data_path, x=npy_data, y=npy_labels)
  print("Save images by npz")
else:
  # 如果存在序列化号的数据,便直接读取,提高速度
  npy_data = np.load(data_path)["x"]
  npy_labels = np.load(data_path)["y"]
  print("Load images by npz")
image_data = npy_data
labels_data = npy_labels

到了这里原始数据的加工预处理便已经完成,只需要最后一步,就和demo中fashion_mnist.load_data()返回的结果一样了。代码如下:

# 最后一步就是将原始数据分成训练数据和测试数据
train_images, test_images, train_labels, test_labels = \
  train_test_split(image_data, labels_data, test_size=0.2, random_state=6)

这里将相关信息打印的方法也附上:

print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Image Data", image_data.shape))
print("%-28s %-s" % ("Labels Data", labels_data.shape))
print("=================================================================")

print('Split train and test data,p=%')
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Train Images", train_images.shape))
print("%-28s %-s" % ("Test Images", test_images.shape))
print("%-28s %-s" % ("Train Labels", train_labels.shape))
print("%-28s %-s" % ("Test Labels", test_labels.shape))
print("=================================================================")

之后别忘了归一化哟:

print("Normalize images")
train_images = train_images / 255.0
test_images = test_images / 255.0

最后附上读取自定义数据的完整代码:

import os

import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.layers import *
from keras.models import *
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
re_load = False
re_build = False
# re_load = True
re_build = True
epochs = 50
batch_size = 5
count = 0
max_size = 2000000000

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python删除文件示例分享
Jan 28 Python
python 字典(dict)按键和值排序
Jun 28 Python
python cx_Oracle模块的安装和使用详细介绍
Feb 13 Python
Python读csv文件去掉一列后再写入新的文件实例
Dec 28 Python
django 2.0更新的10条注意事项总结
Jan 05 Python
Python OOP类中的几种函数或方法总结
Feb 22 Python
Python常用数据类型之间的转换总结
Sep 06 Python
浅析使用Python搭建http服务器
Oct 27 Python
Django中提示消息messages的设置方式
Nov 15 Python
python为Django项目上的每个应用程序创建不同的自定义404页面(最佳答案)
Mar 09 Python
python如何运行js语句
Sep 09 Python
pandas 数据类型转换的实现
Dec 29 Python
在Python 字典中一键对应多个值的实例
Feb 03 #Python
Django csrf 两种方法设置form的实例
Feb 03 #Python
解决django前后端分离csrf验证的问题
Feb 03 #Python
Python利用heapq实现一个优先级队列的方法
Feb 03 #Python
对Python3中dict.keys()转换成list类型的方法详解
Feb 03 #Python
对python中字典keys,values,items的使用详解
Feb 03 #Python
python生成带有表格的图片实例
Feb 03 #Python
You might like
php实现的一个很好用HTML解析器类可用于采集数据
2013/09/23 PHP
PHP实现设计模式中的抽象工厂模式详解
2014/10/11 PHP
PHP利用hash冲突漏洞进行DDoS攻击的方法分析
2015/03/26 PHP
基于PHP生成简单的验证码
2016/06/01 PHP
PHP入门教程之PHP操作MySQL的方法分析
2016/09/11 PHP
PHP会话操作之cookie用法分析
2016/09/28 PHP
PHP执行shell脚本运行程序不产生core文件的方法
2016/12/28 PHP
Gambit vs CL BO3 第一场 2.13
2021/03/10 DOTA
jQuery 判断元素上是否绑定了事件
2009/10/28 Javascript
javascript中onmouse事件在div中失效问题的解决方法
2012/01/09 Javascript
jQuery中的select操作详解
2016/11/29 Javascript
移动端脚本框架Hammer.js
2016/12/15 Javascript
JS实现图片点击后出现模态框效果
2017/05/03 Javascript
Vue关于数据绑定出错解决办法
2017/05/15 Javascript
解决ie img标签内存泄漏的问题
2017/10/13 Javascript
Vue.js与 ASP.NET Core 服务端渲染功能整合
2017/11/16 Javascript
jQuery实现鼠标滑过商品小图片上显示对应大图片功能【测试可用】
2018/04/27 jQuery
JavaScript创建对象的常用方式总结
2018/08/10 Javascript
AngularJS使用$http配置对象方式与服务端交互方法
2018/08/13 Javascript
Node.js 如何利用异步提升任务处理速度
2019/01/07 Javascript
原生JavaScript实现的无缝滚动功能详解
2020/01/17 Javascript
Python标准库sched模块使用指南
2017/07/06 Python
Python使用pickle模块实现序列化功能示例
2018/07/13 Python
Python多线程正确用法实例解析
2020/05/30 Python
通过canvas转换颜色为RGBA格式及性能问题的解决
2019/11/22 HTML / CSS
匡威比利时官网:Converse Belgium
2017/04/13 全球购物
英国家庭珠宝商:T. H. Baker
2018/02/08 全球购物
资生堂英国官网:Shiseido英国
2020/12/30 全球购物
店长岗位的工作内容
2013/11/12 职场文书
大学生国家助学金感谢信
2015/01/23 职场文书
英语教师求职信范文
2015/03/20 职场文书
《倍数和因数》教学反思
2016/02/23 职场文书
CSS3实现的3D隧道效果
2021/04/27 HTML / CSS
Jpa Specification如何实现and和or同时使用查询
2021/11/23 Java/Android
Java获取字符串编码格式实现思路
2022/09/23 Java/Android
keepalived + nginx 实现高可用方案
2022/12/24 Servers