详解tensorflow训练自己的数据集实现CNN图像分类


Posted in Python onFebruary 07, 2018

利用卷积神经网络训练图像数据分为以下几个步骤

1.读取图片文件
2.产生用于训练的批次
3.定义训练的模型(包括初始化参数,卷积、池化层等参数、网络)
4.训练

1 读取图片文件

def get_files(filename):
 class_train = []
 label_train = []
 for train_class in os.listdir(filename):
  for pic in os.listdir(filename+train_class):
   class_train.append(filename+train_class+'/'+pic)
   label_train.append(train_class)
 temp = np.array([class_train,label_train])
 temp = temp.transpose()
 #shuffle the samples
 np.random.shuffle(temp)
 #after transpose, images is in dimension 0 and label in dimension 1
 image_list = list(temp[:,0])
 label_list = list(temp[:,1])
 label_list = [int(i) for i in label_list]
 #print(label_list)
 return image_list,label_list

这里文件名作为标签,即类别(其数据类型要确定,后面要转为tensor类型数据)。

然后将image和label转为list格式数据,因为后边用到的的一些tensorflow函数接收的是list格式数据。

2 产生用于训练的批次

def get_batches(image,label,resize_w,resize_h,batch_size,capacity):
 #convert the list of images and labels to tensor
 image = tf.cast(image,tf.string)
 label = tf.cast(label,tf.int64)
 queue = tf.train.slice_input_producer([image,label])
 label = queue[1]
 image_c = tf.read_file(queue[0])
 image = tf.image.decode_jpeg(image_c,channels = 3)
 #resize
 image = tf.image.resize_image_with_crop_or_pad(image,resize_w,resize_h)
 #(x - mean) / adjusted_stddev
 image = tf.image.per_image_standardization(image)
 
 image_batch,label_batch = tf.train.batch([image,label],
            batch_size = batch_size,
            num_threads = 64,
            capacity = capacity)
 images_batch = tf.cast(image_batch,tf.float32)
 labels_batch = tf.reshape(label_batch,[batch_size])
 return images_batch,labels_batch

首先使用tf.cast转化为tensorflow数据格式,使用tf.train.slice_input_producer实现一个输入的队列。

label不需要处理,image存储的是路径,需要读取为图片,接下来的几步就是读取路径转为图片,用于训练。

CNN对图像大小是敏感的,第10行图片resize处理为大小一致,12行将其标准化,即减去所有图片的均值,方便训练。

接下来使用tf.train.batch函数产生训练的批次。

最后将产生的批次做数据类型的转换和shape的处理即可产生用于训练的批次。

3 定义训练的模型

(1)训练参数的定义及初始化

def init_weights(shape):
 return tf.Variable(tf.random_normal(shape,stddev = 0.01))
#init weights
weights = {
 "w1":init_weights([3,3,3,16]),
 "w2":init_weights([3,3,16,128]),
 "w3":init_weights([3,3,128,256]),
 "w4":init_weights([4096,4096]),
 "wo":init_weights([4096,2])
 }

#init biases
biases = {
 "b1":init_weights([16]),
 "b2":init_weights([128]),
 "b3":init_weights([256]),
 "b4":init_weights([4096]),
 "bo":init_weights([2])
 }

CNN的每层是y=wx+b的决策模型,卷积层产生特征向量,根据这些特征向量带入x进行计算,因此,需要定义卷积层的初始化参数,包括权重和偏置。其中第8行的参数形状后边再解释。

(2)定义不同层的操作

def conv2d(x,w,b):
 x = tf.nn.conv2d(x,w,strides = [1,1,1,1],padding = "SAME")
 x = tf.nn.bias_add(x,b)
 return tf.nn.relu(x)

def pooling(x):
 return tf.nn.max_pool(x,ksize = [1,2,2,1],strides = [1,2,2,1],padding = "SAME")

def norm(x,lsize = 4):
 return tf.nn.lrn(x,depth_radius = lsize,bias = 1,alpha = 0.001/9.0,beta = 0.75)

这里只定义了三种层,即卷积层、池化层和正则化层

(3)定义训练模型

def mmodel(images):
 l1 = conv2d(images,weights["w1"],biases["b1"])
 l2 = pooling(l1)
 l2 = norm(l2)
 l3 = conv2d(l2,weights["w2"],biases["b2"])
 l4 = pooling(l3)
 l4 = norm(l4)
 l5 = conv2d(l4,weights["w3"],biases["b3"])
 #same as the batch size
 l6 = pooling(l5)
 l6 = tf.reshape(l6,[-1,weights["w4"].get_shape().as_list()[0]])
 l7 = tf.nn.relu(tf.matmul(l6,weights["w4"])+biases["b4"])
 soft_max = tf.add(tf.matmul(l7,weights["wo"]),biases["bo"])
 return soft_max

模型比较简单,使用三层卷积,第11行使用全连接,需要对特征向量进行reshape,其中l6的形状为[-1,w4的第1维的参数],因此,将其按照“w4”reshape的时候,要使得-1位置的大小为batch_size,这样,最终再乘以“wo”时,最终的输出大小为[batch_size,class_num]

(4)定义评估量

def loss(logits,label_batches):
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,labels=label_batches)
  cost = tf.reduce_mean(cross_entropy)
  return cost

首先定义损失函数,这是用于训练最小化损失的必需量
 def get_accuracy(logits,labels):
  acc = tf.nn.in_top_k(logits,labels,1)
  acc = tf.cast(acc,tf.float32)
  acc = tf.reduce_mean(acc)
  return acc

评价分类准确率的量,训练时,需要loss值减小,准确率增加,这样的训练才是收敛的。

(5)定义训练方式

def training(loss,lr):
  train_op = tf.train.RMSPropOptimizer(lr,0.9).minimize(loss)
  return train_op

有很多种训练方式,可以自行去官网查看,但是不同的训练方式可能对应前面的参数定义不一样,需要另行处理,否则可能报错。

 4 训练

def run_training():
 data_dir = 'C:/Users/wk/Desktop/bky/dataSet/'
 image,label = inputData.get_files(data_dir)
 image_batches,label_batches = inputData.get_batches(image,label,32,32,16,20)
 p = model.mmodel(image_batches)
 cost = model.loss(p,label_batches)
 train_op = model.training(cost,0.001)
 acc = model.get_accuracy(p,label_batches)
 
 sess = tf.Session()
 init = tf.global_variables_initializer()
 sess.run(init)
 
 coord = tf.train.Coordinator()
 threads = tf.train.start_queue_runners(sess = sess,coord = coord)
 
 try:
  for step in np.arange(1000):
   print(step)
   if coord.should_stop():
    break
   _,train_acc,train_loss = sess.run([train_op,acc,cost])
   print("loss:{} accuracy:{}".format(train_loss,train_acc))
 except tf.errors.OutOfRangeError:
  print("Done!!!")
 finally:
  coord.request_stop()
 coord.join(threads)
 sess.close()

神经网络训练的时候,我们需要将模型保存下来,方便后面继续训练或者用训练好的模型进行测试。因此,我们需要创建一个saver保存模型。

def run_training():
 data_dir = 'C:/Users/wk/Desktop/bky/dataSet/'
 log_dir = 'C:/Users/wk/Desktop/bky/log/'
 image,label = inputData.get_files(data_dir)
 image_batches,label_batches = inputData.get_batches(image,label,32,32,16,20)
 print(image_batches.shape)
 p = model.mmodel(image_batches,16)
 cost = model.loss(p,label_batches)
 train_op = model.training(cost,0.001)
 acc = model.get_accuracy(p,label_batches)
 
 sess = tf.Session()
 init = tf.global_variables_initializer()
 sess.run(init)
 saver = tf.train.Saver()
 coord = tf.train.Coordinator()
 threads = tf.train.start_queue_runners(sess = sess,coord = coord)
 
 try:
  for step in np.arange(1000):
   print(step)
   if coord.should_stop():
    break
   _,train_acc,train_loss = sess.run([train_op,acc,cost])
   print("loss:{} accuracy:{}".format(train_loss,train_acc))
   if step % 100 == 0:
    check = os.path.join(log_dir,"model.ckpt")
    saver.save(sess,check,global_step = step)
 except tf.errors.OutOfRangeError:
  print("Done!!!")
 finally:
  coord.request_stop()
 coord.join(threads)
 sess.close()

训练好的模型信息会记录在checkpoint文件中,大致如下: 

model_checkpoint_path: "C:/Users/wk/Desktop/bky/log/model.ckpt-100"
all_model_checkpoint_paths: "C:/Users/wk/Desktop/bky/log/model.ckpt-0"
all_model_checkpoint_paths: "C:/Users/wk/Desktop/bky/log/model.ckpt-100"

其余还会生成一些文件,分别记录了模型参数等信息,后边测试的时候程序会读取checkpoint文件去加载这些真正的数据文件

详解tensorflow训练自己的数据集实现CNN图像分类

构建好神经网络进行训练完成后,如果用之前的代码直接进行测试,会报shape不符合的错误,大致是卷积层的输入与图像的shape不一致,这是因为上篇的代码,将weights和biases定义在了模型的外面,调用模型的时候,出现valueError的错误。

详解tensorflow训练自己的数据集实现CNN图像分类

因此,我们需要将参数定义在模型里面,加载训练好的模型参数时,训练好的参数才能够真正初始化模型。重写模型函数如下

def mmodel(images,batch_size):
 with tf.variable_scope('conv1') as scope:
  weights = tf.get_variable('weights', 
         shape = [3,3,3, 16],
         dtype = tf.float32, 
         initializer=tf.truncated_normal_initializer(stddev=0.1,dtype=tf.float32))
  biases = tf.get_variable('biases', 
         shape=[16],
         dtype=tf.float32,
         initializer=tf.constant_initializer(0.1))
  conv = tf.nn.conv2d(images, weights, strides=[1,1,1,1], padding='SAME')
  pre_activation = tf.nn.bias_add(conv, biases)
  conv1 = tf.nn.relu(pre_activation, name= scope.name)
 with tf.variable_scope('pooling1_lrn') as scope:
  pool1 = tf.nn.max_pool(conv1, ksize=[1,2,2,1],strides=[1,2,2,1],
        padding='SAME', name='pooling1')
  norm1 = tf.nn.lrn(pool1, depth_radius=4, bias=1.0, alpha=0.001/9.0,
       beta=0.75,name='norm1')
 with tf.variable_scope('conv2') as scope:
  weights = tf.get_variable('weights',
         shape=[3,3,16,128],
         dtype=tf.float32,
         initializer=tf.truncated_normal_initializer(stddev=0.1,dtype=tf.float32))
  biases = tf.get_variable('biases',
         shape=[128], 
         dtype=tf.float32,
         initializer=tf.constant_initializer(0.1))
  conv = tf.nn.conv2d(norm1, weights, strides=[1,1,1,1],padding='SAME')
  pre_activation = tf.nn.bias_add(conv, biases)
  conv2 = tf.nn.relu(pre_activation, name='conv2') 
 with tf.variable_scope('pooling2_lrn') as scope:
  norm2 = tf.nn.lrn(conv2, depth_radius=4, bias=1.0, alpha=0.001/9.0,
       beta=0.75,name='norm2')
  pool2 = tf.nn.max_pool(norm2, ksize=[1,2,2,1], strides=[1,1,1,1],
        padding='SAME',name='pooling2')
 with tf.variable_scope('local3') as scope:
  reshape = tf.reshape(pool2, shape=[batch_size, -1])
  dim = reshape.get_shape()[1].value
  weights = tf.get_variable('weights',
         shape=[dim,4096],
         dtype=tf.float32,
         initializer=tf.truncated_normal_initializer(stddev=0.005,dtype=tf.float32))
  biases = tf.get_variable('biases',
         shape=[4096],
         dtype=tf.float32, 
         initializer=tf.constant_initializer(0.1))
  local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name) 
 with tf.variable_scope('softmax_linear') as scope:
  weights = tf.get_variable('softmax_linear',
         shape=[4096, 2],
         dtype=tf.float32,
         initializer=tf.truncated_normal_initializer(stddev=0.005,dtype=tf.float32))
  biases = tf.get_variable('biases', 
         shape=[2],
         dtype=tf.float32, 
         initializer=tf.constant_initializer(0.1))
  softmax_linear = tf.add(tf.matmul(local3, weights), biases, name='softmax_linear')
 return softmax_linear

测试训练好的模型

首先获取一张测试图像

def get_one_image(img_dir):
  image = Image.open(img_dir)
  plt.imshow(image)
  image = image.resize([32, 32])
  image_arr = np.array(image)
  return image_arr

加载模型,计算测试结果

def test(test_file):
 log_dir = 'C:/Users/wk/Desktop/bky/log/'
 image_arr = get_one_image(test_file)
 
 with tf.Graph().as_default():
  image = tf.cast(image_arr, tf.float32)
  image = tf.image.per_image_standardization(image)
  image = tf.reshape(image, [1,32, 32, 3])
  print(image.shape)
  p = model.mmodel(image,1)
  logits = tf.nn.softmax(p)
  x = tf.placeholder(tf.float32,shape = [32,32,3])
  saver = tf.train.Saver()
  with tf.Session() as sess:
   ckpt = tf.train.get_checkpoint_state(log_dir)
   if ckpt and ckpt.model_checkpoint_path:
    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    saver.restore(sess, ckpt.model_checkpoint_path)
    print('Loading success)
   else:
    print('No checkpoint')
   prediction = sess.run(logits, feed_dict={x: image_arr})
   max_index = np.argmax(prediction)
   print(max_index)

前面主要是将测试图片标准化为网络的输入图像,15-19是加载模型文件,然后将图像输入到模型里即可

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python编写生成验证码的脚本的教程
May 04 Python
用python实现简单EXCEL数据统计的实例
Jan 24 Python
python操作oracle的完整教程分享
Jan 30 Python
python opencv 读取本地视频文件 修改ffmpeg的方法
Jan 26 Python
Python按钮的响应事件详解
Mar 04 Python
python turtle库画一个方格和圆实例
Jun 27 Python
Django缓存系统实现过程解析
Aug 02 Python
Python enumerate函数遍历数据对象组合过程解析
Dec 11 Python
python pycharm最新版本激活码(永久有效)附python安装教程
Sep 18 Python
tensorflow使用range_input_producer多线程读取数据实例
Jan 20 Python
Python openpyxl 插入折线图实例
Apr 17 Python
python中Array和DataFrame相互转换的实例讲解
Feb 03 Python
全面分析Python的优点和缺点
Feb 07 #Python
Tensorflow环境搭建的方法步骤
Feb 07 #Python
Python pandas常用函数详解
Feb 07 #Python
详解python字节码
Feb 07 #Python
Tensorflow之构建自己的图片数据集TFrecords的方法
Feb 07 #Python
python深度优先搜索和广度优先搜索
Feb 07 #Python
Python Flask基础教程示例代码
Feb 07 #Python
You might like
PHP仿盗链代码
2012/06/03 PHP
php连接oracle数据库的核心步骤
2016/05/26 PHP
PHP自定义函数实现格式化秒的方法
2016/09/14 PHP
PHP利用超级全局变量$_POST来接收表单数据的实例
2016/11/05 PHP
php str_getcsv把字符串解析为数组的实现方法
2017/04/05 PHP
PHP实现批量修改文件名的方法示例
2019/09/18 PHP
JavaScript CSS修改学习第一章 查找位置
2010/02/19 Javascript
jQuery each()小议
2010/03/18 Javascript
jquery获取下拉列表的值为null的解决方法
2011/03/18 Javascript
12306验证码破解思路分享
2015/03/25 Javascript
jQuery实现仿新浪微博浮动的消息提示框(可智能定位)
2015/10/10 Javascript
原生JavaScript实现滚动条效果
2020/03/24 Javascript
js实现模糊匹配功能
2017/02/15 Javascript
浅谈原型对象的常用开发模式
2017/07/22 Javascript
使用JS模拟锚点跳转的实例
2018/02/01 Javascript
angularJs-$http实现百度搜索时的动态下拉框示例
2018/02/27 Javascript
vue.js实现的绑定class操作示例
2018/07/06 Javascript
BootStrap table实现表格行拖拽效果
2018/12/01 Javascript
微信小程序实现购物页面左右联动
2019/02/15 Javascript
H5实现手机拍照和选择上传功能
2019/12/18 Javascript
用云开发Cloudbase实现小程序多图片内容安全监测的代码详解
2020/06/07 Javascript
11个Javascript小技巧帮你提升代码质量(小结)
2020/12/28 Javascript
Python的Flask框架中的Jinja2模板引擎学习教程
2016/06/30 Python
关于反爬虫的一些简单总结
2017/12/13 Python
Python中反射和描述器总结
2018/09/23 Python
python入门之基础语法学习笔记
2020/02/08 Python
Python统计学一数据的概括性度量详解
2020/03/03 Python
python实现图片,视频人脸识别(dlib版)
2020/11/18 Python
25个CSS3动画按钮和菜单教程分享
2012/10/03 HTML / CSS
突破canvas语法限制 让他支持链式语法
2012/12/24 HTML / CSS
世界上最大的售后摩托车零配件超市:J&P Cycles
2017/12/08 全球购物
PHP经典面试题
2016/09/03 面试题
写好自荐信需做到的5要点
2014/03/07 职场文书
竞选学生会主席演讲稿
2014/04/24 职场文书
2014年护理工作总结范文
2014/11/14 职场文书
Python3.8官网文档之类的基础语法阅读
2021/09/04 Python