tensorflow 1.0用CNN进行图像分类


Posted in Python onApril 15, 2018

tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化。

任务:花卉分类

版本:tensorflow 1.0

数据:flower-photos

花总共有五类,分别放在5个文件夹下。

闲话不多说,直接上代码,希望大家能看懂:)

复制代码

# -*- coding: utf-8 -*-

from skimage import io,transform
import glob
import os
import tensorflow as tf
import numpy as np
import time

path='e:/flower/'

#将所有的图片resize成100*100
w=100
h=100
c=3


#读取图片
def read_img(path):
 cate=[path+x for x in os.listdir(path) if os.path.isdir(path+x)]
 imgs=[]
 labels=[]
 for idx,folder in enumerate(cate):
  for im in glob.glob(folder+'/*.jpg'):
   print('reading the images:%s'%(im))
   img=io.imread(im)
   img=transform.resize(img,(w,h))
   imgs.append(img)
   labels.append(idx)
 return np.asarray(imgs,np.float32),np.asarray(labels,np.int32)
data,label=read_img(path)


#打乱顺序
num_example=data.shape[0]
arr=np.arange(num_example)
np.random.shuffle(arr)
data=data[arr]
label=label[arr]


#将所有数据分为训练集和验证集
ratio=0.8
s=np.int(num_example*ratio)
x_train=data[:s]
y_train=label[:s]
x_val=data[s:]
y_val=label[s:]

#-----------------构建网络----------------------
#占位符
x=tf.placeholder(tf.float32,shape=[None,w,h,c],name='x')
y_=tf.placeholder(tf.int32,shape=[None,],name='y_')

#第一个卷积层(100——>50)
conv1=tf.layers.conv2d(
  inputs=x,
  filters=32,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool1=tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

#第二个卷积层(50->25)
conv2=tf.layers.conv2d(
  inputs=pool1,
  filters=64,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool2=tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

#第三个卷积层(25->12)
conv3=tf.layers.conv2d(
  inputs=pool2,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool3=tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)

#第四个卷积层(12->6)
conv4=tf.layers.conv2d(
  inputs=pool3,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool4=tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=2)

re1 = tf.reshape(pool4, [-1, 6 * 6 * 128])

#全连接层
dense1 = tf.layers.dense(inputs=re1, 
      units=1024, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
dense2= tf.layers.dense(inputs=dense1, 
      units=512, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
logits= tf.layers.dense(inputs=dense2, 
      units=5, 
      activation=None,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
#---------------------------网络结束---------------------------

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) 
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


#定义一个函数,按批次取数据
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):
 assert len(inputs) == len(targets)
 if shuffle:
  indices = np.arange(len(inputs))
  np.random.shuffle(indices)
 for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
  if shuffle:
   excerpt = indices[start_idx:start_idx + batch_size]
  else:
   excerpt = slice(start_idx, start_idx + batch_size)
  yield inputs[excerpt], targets[excerpt]


#训练和测试数据,可将n_epoch设置更大一些

n_epoch=10
batch_size=64
sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
for epoch in range(n_epoch):
 start_time = time.time()
 
 #training
 train_loss, train_acc, n_batch = 0, 0, 0
 for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):
  _,err,ac=sess.run([train_op,loss,acc], feed_dict={x: x_train_a, y_: y_train_a})
  train_loss += err; train_acc += ac; n_batch += 1
 print(" train loss: %f" % (train_loss/ n_batch))
 print(" train acc: %f" % (train_acc/ n_batch))
 
 #validation
 val_loss, val_acc, n_batch = 0, 0, 0
 for x_val_a, y_val_a in minibatches(x_val, y_val, batch_size, shuffle=False):
  err, ac = sess.run([loss,acc], feed_dict={x: x_val_a, y_: y_val_a})
  val_loss += err; val_acc += ac; n_batch += 1
 print(" validation loss: %f" % (val_loss/ n_batch))
 print(" validation acc: %f" % (val_acc/ n_batch))

sess.close()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python文件读写操作与linux shell变量命令交互执行的方法
Jan 14 Python
详解Python中最难理解的点-装饰器
Apr 03 Python
Python基于回溯法子集树模板解决找零问题示例
Sep 11 Python
python中 logging的使用详解
Oct 25 Python
Python异常处理操作实例详解
Aug 28 Python
Python 从相对路径下import的方法
Dec 04 Python
Django 框架模型操作入门教程
Nov 05 Python
python调用API接口实现登陆短信验证
May 10 Python
Python爬虫scrapy框架Cookie池(微博Cookie池)的使用
Jan 13 Python
深入理解Python变量的数据类型和存储
Feb 01 Python
python解决OpenCV在读取显示图片的时候闪退的问题
Feb 23 Python
FP-growth算法发现频繁项集——构建FP树
Jun 24 Python
tensorflow学习笔记之mnist的卷积神经网络实例
Apr 15 #Python
tensorflow学习笔记之简单的神经网络训练和测试
Apr 15 #Python
Pytorch入门之mnist分类实例
Apr 14 #Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
You might like
详解PHP中的null合并运算符
2015/12/30 PHP
php如何控制用户对图片的访问 PHP禁止图片盗链
2016/03/25 PHP
php7性能提升的原因详解
2019/10/13 PHP
让iframe自适应高度(支持XHTML,支持FF)
2007/07/24 Javascript
Dom操作之兼容技巧分享
2011/09/20 Javascript
js数字转换为float,取N位小数
2014/02/08 Javascript
在Firefox下js select标签点击无法弹出
2014/03/06 Javascript
javascript记录文本框内文字个数检测文字个数变化
2014/10/14 Javascript
js实现浏览器窗口大小被改变时触发事件的方法
2015/02/02 Javascript
javascript实现Email邮件显示与删除功能
2015/11/21 Javascript
JS给swf传参数的实现方法
2016/09/13 Javascript
JS实现十字坐标跟随鼠标效果
2017/12/25 Javascript
Angular单元测试之事件触发的实现
2020/01/20 Javascript
微信小程序 wx:for 与 wx:for-items 与 wx:key的正确用法
2020/05/19 Javascript
从零开始在vue-cli4配置自适应vw布局的实现
2020/06/08 Javascript
对python中数组的del,remove,pop区别详解
2018/11/07 Python
Python中栈、队列与优先级队列的实现方法
2019/06/30 Python
python使用递归的方式建立二叉树
2019/07/03 Python
python 如何将office文件转换为PDF
2020/09/22 Python
Python pip 常用命令汇总
2020/10/19 Python
python实现文件分片上传的接口自动化
2020/11/19 Python
Python爬取酷狗MP3音频的步骤
2021/02/26 Python
新闻专业个人自我评价
2013/09/21 职场文书
大学生村官工作感言
2014/01/10 职场文书
美发活动策划书
2014/01/14 职场文书
党员学习十八大感想
2014/01/17 职场文书
荷叶圆圆教学反思
2014/02/01 职场文书
安全生产活动月方案
2014/03/09 职场文书
2014年大学生四年规划书范文
2014/04/03 职场文书
合作协议书
2014/04/23 职场文书
入职担保书怎么写
2014/05/12 职场文书
群众路线教育党员自我剖析材料
2014/10/06 职场文书
大学生个人学习总结
2015/02/15 职场文书
诚信考试承诺书范文
2015/04/29 职场文书
Python中生成随机数据安全性、多功能性、用途和速度方面进行比较
2022/04/14 Python
MySQL数据库实验实现简单数据库应用系统设计
2022/06/21 MySQL