Tensorflow分类器项目自定义数据读入的实现


Posted in Python onFebruary 05, 2019

在照着Tensorflow官网的demo敲了一遍分类器项目的代码后,运行倒是成功了,结果也不错。但是最终还是要训练自己的数据,所以尝试准备加载自定义的数据,然而demo中只是出现了fashion_mnist.load_data()并没有详细的读取过程,随后我又找了些资料,把读取的过程记录在这里。

首先提一下需要用到的模块:

import os
import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

图片分类器项目,首先确定你要处理的图片分辨率将是多少,这里的例子为30像素:

IMG_SIZE_X = 30
IMG_SIZE_Y = 30

其次确定你图片的方式目录:

image_path = r'D:\Projects\ImageClassifier\data\set'
path = ".\data"
# 你也可以使用相对路径的方式
# image_path =os.path.join(path, "set")

目录下的结构如下:

Tensorflow分类器项目自定义数据读入的实现

相应的label.txt如下:

动漫
风景
美女
物语
樱花

接下来是接在labels.txt,如下:

label_name = "labels.txt"
label_path = os.path.join(path, label_name)
class_names = np.loadtxt(label_path, type(""))

这里简便起见,直接利用了numpy的loadtxt函数直接加载。

之后便是正式处理图片数据了,注释就写在里面了:

re_load = False
re_build = False
# re_load = True
re_build = True

data_name = "data.npz"
data_path = os.path.join(path, data_name)
model_name = "model.h5"
model_path = os.path.join(path, model_name)

count = 0

# 这里判断是否存在序列化之后的数据,re_load是一个开关,是否强制重新处理,测试用,可以去除。
if not os.path.exists(data_path) or re_load:
  labels = []
  images = []
  print('Handle images')
  # 由于label.txt是和图片防止目录的分类目录一一对应的,即每个子目录的目录名就是labels.txt里的一个label,所以这里可以通过读取class_names的每一项去拼接path后读取
  for index, name in enumerate(class_names):
    # 这里是拼接后的子目录path
    classpath = os.path.join(image_path, name)
    # 先判断一下是否是目录
    if not os.path.isdir(classpath):
      continue
    # limit是测试时候用的这里可以去除
    limit = 0
    for image_name in os.listdir(classpath):
      if limit >= max_size:
        break
      # 这里是拼接后的待处理的图片path
      imagepath = os.path.join(classpath, image_name)
      count = count + 1
      limit = limit + 1
      # 利用Image打开图片
      img = Image.open(imagepath)
      # 缩放到你最初确定要处理的图片分辨率大小
      img = img.resize((IMG_SIZE_X, IMG_SIZE_Y))
      # 转为灰度图片,这里彩色通道会干扰结果,并且会加大计算量
      img = img.convert("L")
      # 转为numpy数组
      img = np.array(img)
      # 由(30,30)转为(1,30,30)(即`channels_first`),当然你也可以转换为(30,30,1)(即`channels_last`)但为了之后预览处理后的图片方便这里采用了(1,30,30)的格式存放
      img = np.reshape(img, (1, IMG_SIZE_X, IMG_SIZE_Y))
      # 这里利用循环生成labels数据,其中存放的实际是class_names中对应元素的索引
      labels.append([index])
      # 添加到images中,最后统一处理
      images.append(img)
      # 循环中一些状态的输出,可以去除
      print("{} class: {} {} limit: {} {}"
         .format(count, index + 1, class_names[index], limit, imagepath))
  # 最后一次性将images和labels都转换成numpy数组
  npy_data = np.array(images)
  npy_labels = np.array(labels)
  # 处理数据只需要一次,所以我们选择在这里利用numpy自带的方法将处理之后的数据序列化存储
  np.savez(data_path, x=npy_data, y=npy_labels)
  print("Save images by npz")
else:
  # 如果存在序列化号的数据,便直接读取,提高速度
  npy_data = np.load(data_path)["x"]
  npy_labels = np.load(data_path)["y"]
  print("Load images by npz")
image_data = npy_data
labels_data = npy_labels

到了这里原始数据的加工预处理便已经完成,只需要最后一步,就和demo中fashion_mnist.load_data()返回的结果一样了。代码如下:

# 最后一步就是将原始数据分成训练数据和测试数据
train_images, test_images, train_labels, test_labels = \
  train_test_split(image_data, labels_data, test_size=0.2, random_state=6)

这里将相关信息打印的方法也附上:

print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Image Data", image_data.shape))
print("%-28s %-s" % ("Labels Data", labels_data.shape))
print("=================================================================")

print('Split train and test data,p=%')
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Train Images", train_images.shape))
print("%-28s %-s" % ("Test Images", test_images.shape))
print("%-28s %-s" % ("Train Labels", train_labels.shape))
print("%-28s %-s" % ("Test Labels", test_labels.shape))
print("=================================================================")

之后别忘了归一化哟:

print("Normalize images")
train_images = train_images / 255.0
test_images = test_images / 255.0

最后附上读取自定义数据的完整代码:

import os

import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.layers import *
from keras.models import *
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
re_load = False
re_build = False
# re_load = True
re_build = True
epochs = 50
batch_size = 5
count = 0
max_size = 2000000000

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python list语法学习(带例子)
Nov 01 Python
手动实现把python项目发布为exe可执行程序过程分享
Oct 23 Python
python实现感知器算法详解
Dec 19 Python
python实现RabbitMQ的消息队列的示例代码
Nov 08 Python
PyQt5 在label显示的图片中绘制矩形的方法
Jun 17 Python
python pytest进阶之xunit fixture详解
Jun 27 Python
python3 selenium自动化 下拉框定位的例子
Aug 23 Python
Python使用__new__()方法为对象分配内存及返回对象的引用示例
Sep 20 Python
Django实现auth模块下的登录注册与注销功能
Oct 10 Python
python返回数组的索引实例
Nov 28 Python
Python实现井字棋小游戏
Mar 09 Python
Python爬虫之Selenium多窗口切换的实现
Dec 04 Python
在Python 字典中一键对应多个值的实例
Feb 03 #Python
Django csrf 两种方法设置form的实例
Feb 03 #Python
解决django前后端分离csrf验证的问题
Feb 03 #Python
Python利用heapq实现一个优先级队列的方法
Feb 03 #Python
对Python3中dict.keys()转换成list类型的方法详解
Feb 03 #Python
对python中字典keys,values,items的使用详解
Feb 03 #Python
python生成带有表格的图片实例
Feb 03 #Python
You might like
解析PayPal支付接口的PHP开发方式
2010/11/28 PHP
JpGraph php柱状图使用介绍
2011/08/23 PHP
学习php设计模式 php实现工厂模式(factory)
2015/12/07 PHP
服务器迁移php版本不同可能诱发的问题
2015/12/22 PHP
弹出广告特效代码(一个IP只弹出一次)
2007/05/11 Javascript
TreeView 用法(有代码)(asp.net)
2011/07/15 Javascript
基于jquery的放大镜效果
2012/05/30 Javascript
Js数组的操作push,pop,shift,unshift等方法详细介绍
2012/12/28 Javascript
jquery如何通过name名称获取当前name的value值
2013/12/20 Javascript
javascript实现点击商品列表checkbox实时统计金额的方法
2015/05/15 Javascript
简单谈谈Vue 模板各类数据绑定
2016/09/25 Javascript
Bootstrap CSS组件之按钮下拉菜单
2016/12/17 Javascript
jquery实现简单实用的轮播器
2017/05/23 jQuery
webpack配置打包后图片路径出错的解决
2018/04/26 Javascript
vue脚手架搭建项目的兼容性配置详解
2018/07/17 Javascript
JS+HTML5 Canvas实现简单的写字板功能示例
2018/08/30 Javascript
vue路由切换之淡入淡出的简单实现
2019/10/31 Javascript
微信小程序实现横向滚动导航栏效果
2019/12/12 Javascript
elementUI同一页面展示多个Dialog的实现
2020/11/19 Javascript
[53:36]Liquid vs VP Supermajor决赛 BO 第三场 6.10
2018/07/05 DOTA
pyv8学习python和javascript变量进行交互
2013/12/04 Python
python3.3教程之模拟百度登陆代码分享
2014/01/16 Python
Python字典,函数,全局变量代码解析
2017/12/18 Python
python3.5+tesseract+adb实现西瓜视频或头脑王者辅助答题
2018/01/17 Python
Django框架模板注入操作示例【变量传递到模板】
2018/12/19 Python
零基础使用Python读写处理Excel表格的方法
2019/05/02 Python
PIL图像处理模块paste方法简单使用详解
2019/07/17 Python
Python队列、进程间通信、线程案例
2019/10/25 Python
HTML5实现晶莹剔透的雨滴特效
2014/05/14 HTML / CSS
在C中是否有模拟继承等面向对象程序设计特性的好方法
2012/05/22 面试题
什么是数组名
2012/05/10 面试题
搞笑爱情保证书
2014/04/29 职场文书
护理专科学生自荐书
2014/07/05 职场文书
销售代理协议书
2014/09/30 职场文书
不同意离婚上诉状
2015/05/23 职场文书
公司员工宿舍管理制度
2015/08/03 职场文书