使用TensorFlow-Slim进行图像分类的实现


Posted in Python onDecember 31, 2019

参考 https://github.com/tensorflow/models/tree/master/slim

使用TensorFlow-Slim进行图像分类

准备

安装TensorFlow

参考 https://www.tensorflow.org/install/

如在Ubuntu下安装TensorFlow with GPU support, python 2.7版本

wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl

下载TF-slim图像模型库

cd $WORKSPACE
git clone https://github.com/tensorflow/models/

准备数据

有不少公开数据集,这里以官网提供的Flowers为例。

官网提供了下载和转换数据的代码,为了理解代码并能使用自己的数据,这里参考官方提供的代码进行修改。

cd $WORKSPACE/data
wget http://download.tensorflow.org/example_images/flower_photos.tgz
tar zxf flower_photos.tgz

数据集文件夹结构如下:

flower_photos
├── daisy
│  ├── 100080576_f52e8ee070_n.jpg
│  └── ...
├── dandelion
├── LICENSE.txt
├── roses
├── sunflowers
└── tulips

由于实际情况中我们自己的数据集并不一定把图片按类别放在不同的文件夹里,故我们生成list.txt来表示图片路径与标签的关系。

Python代码:

import os

class_names_to_ids = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
data_dir = 'flower_photos/'
output_path = 'list.txt'

fd = open(output_path, 'w')
for class_name in class_names_to_ids.keys():
  images_list = os.listdir(data_dir + class_name)
  for image_name in images_list:
    fd.write('{}/{} {}\n'.format(class_name, image_name, class_names_to_ids[class_name]))

fd.close()

为了方便后期查看label标签,也可以定义labels.txt:

daisy
dandelion
roses
sunflowers
tulips

随机生成训练集与验证集:

Python代码:

import random

_NUM_VALIDATION = 350
_RANDOM_SEED = 0
list_path = 'list.txt'
train_list_path = 'list_train.txt'
val_list_path = 'list_val.txt'

fd = open(list_path)
lines = fd.readlines()
fd.close()
random.seed(_RANDOM_SEED)
random.shuffle(lines)

fd = open(train_list_path, 'w')
for line in lines[_NUM_VALIDATION:]:
  fd.write(line)

fd.close()
fd = open(val_list_path, 'w')
for line in lines[:_NUM_VALIDATION]:
  fd.write(line)

fd.close()

生成TFRecord数据:

Python代码:

import sys
sys.path.insert(0, '../models/slim/')
from datasets import dataset_utils
import math
import os
import tensorflow as tf

def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=5):
  fd = open(list_path)
  lines = [line.split() for line in fd]
  fd.close()
  num_per_shard = int(math.ceil(len(lines) / float(_NUM_SHARDS)))
  with tf.Graph().as_default():
    decode_jpeg_data = tf.placeholder(dtype=tf.string)
    decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)
    with tf.Session('') as sess:
      for shard_id in range(_NUM_SHARDS):
        output_path = os.path.join(output_dir,
          'data_{:05}-of-{:05}.tfrecord'.format(shard_id, _NUM_SHARDS))
        tfrecord_writer = tf.python_io.TFRecordWriter(output_path)
        start_ndx = shard_id * num_per_shard
        end_ndx = min((shard_id + 1) * num_per_shard, len(lines))
        for i in range(start_ndx, end_ndx):
          sys.stdout.write('\r>> Converting image {}/{} shard {}'.format(
            i + 1, len(lines), shard_id))
          sys.stdout.flush()
          image_data = tf.gfile.FastGFile(os.path.join(data_dir, lines[i][0]), 'rb').read()
          image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data})
          height, width = image.shape[0], image.shape[1]
          example = dataset_utils.image_to_tfexample(
            image_data, b'jpg', height, width, int(lines[i][1]))
          tfrecord_writer.write(example.SerializeToString())
        tfrecord_writer.close()
  sys.stdout.write('\n')
  sys.stdout.flush()

os.system('mkdir -p train')
convert_dataset('list_train.txt', 'flower_photos', 'train/')
os.system('mkdir -p val')
convert_dataset('list_val.txt', 'flower_photos', 'val/')

得到的文件夹结构如下:

data
├── flower_photos
├── labels.txt
├── list_train.txt
├── list.txt
├── list_val.txt
├── train
│  ├── data_00000-of-00005.tfrecord
│  ├── ...
│  └── data_00004-of-00005.tfrecord
└── val
  ├── data_00000-of-00005.tfrecord
  ├── ...
  └── data_00004-of-00005.tfrecord

(可选)下载模型

官方提供了不少预训练模型,这里以Inception-ResNet-v2以例。

cd $WORKSPACE/checkpoints
wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
tar zxf inception_resnet_v2_2016_08_30.tar.gz

训练

读入数据

官方提供了读入Flowers数据集的代码models/slim/datasets/flowers.py,同样这里也是参考并修改成能读入上面定义的通用数据集。

把下面代码写入models/slim/datasets/dataset_classification.py。

import os
import tensorflow as tf
slim = tf.contrib.slim

def get_dataset(dataset_dir, num_samples, num_classes, labels_to_names_path=None, file_pattern='*.tfrecord'):
  file_pattern = os.path.join(dataset_dir, file_pattern)
  keys_to_features = {
    'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
    'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
    'image/class/label': tf.FixedLenFeature(
      [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }
  items_to_handlers = {
    'image': slim.tfexample_decoder.Image(),
    'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }
  decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
  items_to_descriptions = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and ' + str(num_classes - 1),
  }
  labels_to_names = None
  if labels_to_names_path is not None:
    fd = open(labels_to_names_path)
    labels_to_names = {i : line.strip() for i, line in enumerate(fd)}
    fd.close()
  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=tf.TFRecordReader,
      decoder=decoder,
      num_samples=num_samples,
      items_to_descriptions=items_to_descriptions,
      num_classes=num_classes,
      labels_to_names=labels_to_names)

构建模型

官方提供了许多模型在models/slim/nets/。

如需要自定义模型,则参考官方提供的模型并放在对应的文件夹即可。

开始训练

官方提供了训练脚本,如果使用官方的数据读入和处理,可使用以下方式开始训练。

cd $WORKSPACE/models/slim
CUDA_VISIBLE_DEVICES="0" python train_image_classifier.py \
  --train_dir=train_logs \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=../../data/flowers \
  --model_name=inception_resnet_v2 \
  --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \
  --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
  --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
  --max_number_of_steps=1000 \
  --batch_size=32 \
  --learning_rate=0.01 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=10 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

不fine-tune把--checkpoint_path, --checkpoint_exclude_scopes和--trainable_scopes删掉。

fine-tune所有层把--checkpoint_exclude_scopes和--trainable_scopes删掉。

如果只使用CPU则加上--clone_on_cpu=True。

其它参数可删掉用默认值或自行修改。

使用自己的数据则需要修改models/slim/train_image_classifier.py:

from datasets import dataset_factory

修改为

from datasets import dataset_classification

dataset = dataset_factory.get_dataset(
  FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

修改为

dataset = dataset_classification.get_dataset(
  FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)

tf.app.flags.DEFINE_string(
  'dataset_dir', None, 'The directory where the dataset files are stored.')

后加入

tf.app.flags.DEFINE_integer(
  'num_samples', 3320, 'Number of samples.')

tf.app.flags.DEFINE_integer(
  'num_classes', 5, 'Number of classes.')

tf.app.flags.DEFINE_string(
  'labels_to_names_path', None, 'Label names file path.')

训练时执行以下命令即可:

cd $WORKSPACE/models/slim
python train_image_classifier.py \
  --train_dir=train_logs \
  --dataset_dir=../../data/train \
  --num_samples=3320 \
  --num_classes=5 \
  --labels_to_names_path=../../data/labels.txt \
  --model_name=inception_resnet_v2 \
  --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \
  --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
  --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits

可视化log

可一边训练一边可视化训练的log,可看到Loss趋势。

tensorboard --logdir train_logs/

验证

官方提供了验证脚本。

python eval_image_classifier.py \
  --checkpoint_path=train_logs \
  --eval_dir=eval_logs \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=../../data/flowers \
  --model_name=inception_resnet_v2

同样,如果是使用自己的数据集,则需要修改models/slim/eval_image_classifier.py:

from datasets import dataset_factory

修改为

from datasets import dataset_classification

dataset = dataset_factory.get_dataset(
  FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

修改为

dataset = dataset_classification.get_dataset(
  FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)

tf.app.flags.DEFINE_string(
  'dataset_dir', None, 'The directory where the dataset files are stored.')

后加入

tf.app.flags.DEFINE_integer(
  'num_samples', 350, 'Number of samples.')

tf.app.flags.DEFINE_integer(
  'num_classes', 5, 'Number of classes.')

tf.app.flags.DEFINE_string(
  'labels_to_names_path', None, 'Label names file path.')

验证时执行以下命令即可:

python eval_image_classifier.py \
  --checkpoint_path=train_logs \
  --eval_dir=eval_logs \
  --dataset_dir=../../data/val \
  --num_samples=350 \
  --num_classes=5 \
  --model_name=inception_resnet_v2

可以一边训练一边验证,,注意使用其它的GPU或合理分配显存。

同样也可以可视化log,如果已经在可视化训练的log则建议使用其它端口,如:

tensorboard --logdir eval_logs/ --port 6007

测试

参考models/slim/eval_image_classifier.py,可编写读取图片用模型进行推导的脚本models/slim/test_image_classifier.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import math
import tensorflow as tf

from nets import nets_factory
from preprocessing import preprocessing_factory

slim = tf.contrib.slim

tf.app.flags.DEFINE_string(
  'master', '', 'The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(
  'checkpoint_path', '/tmp/tfmodel/',
  'The directory where the model was written to or an absolute path to a '
  'checkpoint file.')

tf.app.flags.DEFINE_string(
  'test_path', '', 'Test image path.')

tf.app.flags.DEFINE_integer(
  'num_classes', 5, 'Number of classes.')

tf.app.flags.DEFINE_integer(
  'labels_offset', 0,
  'An offset for the labels in the dataset. This flag is primarily used to '
  'evaluate the VGG and ResNet architectures which do not use a background '
  'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
  'model_name', 'inception_v3', 'The name of the architecture to evaluate.')

tf.app.flags.DEFINE_string(
  'preprocessing_name', None, 'The name of the preprocessing to use. If left '
  'as `None`, then the model_name flag is used.')

tf.app.flags.DEFINE_integer(
  'test_image_size', None, 'Eval image size')

FLAGS = tf.app.flags.FLAGS


def main(_):
  if not FLAGS.test_list:
    raise ValueError('You must supply the test list with --test_list')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    tf_global_step = slim.get_or_create_global_step()

    ####################
    # Select the model #
    ####################
    network_fn = nets_factory.get_network_fn(
      FLAGS.model_name,
      num_classes=(FLAGS.num_classes - FLAGS.labels_offset),
      is_training=False)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
      preprocessing_name,
      is_training=False)

    test_image_size = FLAGS.test_image_size or network_fn.default_image_size

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.Graph().as_default()
    with tf.Session() as sess:
      image = open(FLAGS.test_path, 'rb').read()
      image = tf.image.decode_jpeg(image, channels=3)
      processed_image = image_preprocessing_fn(image, test_image_size, test_image_size)
      processed_images = tf.expand_dims(processed_image, 0)
      logits, _ = network_fn(processed_images)
      predictions = tf.argmax(logits, 1)
      saver = tf.train.Saver()
      saver.restore(sess, checkpoint_path)
      np_image, network_input, predictions = sess.run([image, processed_image, predictions])
      print('{} {}'.format(FLAGS.test_path, predictions[0]))

if __name__ == '__main__':
  tf.app.run()

测试时执行以下命令即可:

python test_image_classifier.py \
  --checkpoint_path=train_logs/ \
  --test_path=../../data/flower_photos/tulips/6948239566_0ac0a124ee_n.jpg \
  --num_classes=5 \
  --model_name=inception_resnet_v2

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python入门学习之字符串与比较运算符
Oct 12 Python
Windows安装Python、pip、easy_install的方法
Mar 05 Python
Python3实战之爬虫抓取网易云音乐的热门评论
Oct 09 Python
对numpy的array和python中自带的list之间相互转化详解
Apr 13 Python
解决Django的request.POST获取不到内容的问题
May 28 Python
elasticsearch python 查询的两种方法
Aug 04 Python
有关Tensorflow梯度下降常用的优化方法分享
Feb 04 Python
Jupyter Notebook的连接密码 token查询方式
Apr 21 Python
基于python生成英文版词云图代码实例
May 16 Python
python爬虫 requests-html的使用
Nov 30 Python
Python趣味挑战之用pygame实现简单的金币旋转效果
May 31 Python
python 爬取天气网卫星图片
Jun 07 Python
Pytorch之view及view_as使用详解
Dec 31 #Python
window环境pip切换国内源(pip安装异常缓慢的问题)
Dec 31 #Python
如何基于Python创建目录文件夹
Dec 31 #Python
Pytorch之contiguous的用法
Dec 31 #Python
python实现将json多行数据传入到mysql中使用
Dec 31 #Python
Pytorch之Variable的用法
Dec 31 #Python
Pytorch 多块GPU的使用详解
Dec 31 #Python
You might like
php长字符串定义方法
2012/07/12 PHP
制作个性化的WordPress登陆界面的实例教程
2016/05/21 PHP
PHP静态成员变量和非静态成员变量详解
2017/02/14 PHP
PHP实现的Redis多库选择功能单例类
2017/07/27 PHP
JQuery 应用 JQuery.groupTable.js
2010/12/15 Javascript
数组方法解决JS字符串连接性能问题有争议
2011/01/12 Javascript
jQuery基本过滤选择器使用介绍
2013/04/18 Javascript
使用JavaScript和C#中获得referer
2014/11/14 Javascript
有效提高JavaScript执行效率的几点知识
2015/01/31 Javascript
JS实现的左侧竖向滑动菜单效果代码
2015/10/19 Javascript
如何使用PHP+jQuery+MySQL实现异步加载ECharts地图数据(附源码下载)
2016/02/23 Javascript
jQuery的Each比JS原生for循环性能慢很多的原因
2016/07/05 Javascript
功能强大的Bootstrap组件(结合js)
2016/08/03 Javascript
react实现pure render时bind(this)隐患需注意!
2017/03/09 Javascript
vant 中van-list的用法说明
2020/11/11 Javascript
python计算最大优先级队列实例
2013/12/18 Python
python之matplotlib学习绘制动态更新图实例代码
2018/01/23 Python
Windows下安装Scrapy
2018/10/17 Python
修改python plot折线图的坐标轴刻度方法
2018/12/13 Python
python palywright库基本使用
2021/01/21 Python
python如何修改文件时间属性
2021/02/05 Python
python推导式的使用方法实例
2021/02/28 Python
通过css3动画和opacity透明度实现呼吸灯效果
2019/08/09 HTML / CSS
HTML5 Canvas中绘制矩形实例
2015/01/01 HTML / CSS
HTML5中的websocket实现直播功能
2018/05/21 HTML / CSS
AmazeUI 导航条的实现示例
2020/08/14 HTML / CSS
美国孕妇装购物网站:Motherhood Maternity
2019/09/22 全球购物
大四自我鉴定范文
2013/10/06 职场文书
大学生如何写自荐信
2014/01/08 职场文书
旅游安全协议书
2014/04/21 职场文书
纪检干部对照检查材料
2014/08/22 职场文书
2015年十一国庆节演讲稿
2015/03/20 职场文书
社区服务理念口号
2015/12/25 职场文书
雄兵连:第三季先行图公开,天使恶魔联合,银河之力的新力量
2021/06/11 国漫
Python借助with语句实现代码段只执行有限次
2022/03/23 Python
Python使用Opencv打开笔记本电脑摄像头报错解问题及解决
2022/06/21 Python