tensorflow 1.0用CNN进行图像分类


Posted in Python onApril 15, 2018

tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化。

任务:花卉分类

版本:tensorflow 1.0

数据:flower-photos

花总共有五类,分别放在5个文件夹下。

闲话不多说,直接上代码,希望大家能看懂:)

复制代码

# -*- coding: utf-8 -*-

from skimage import io,transform
import glob
import os
import tensorflow as tf
import numpy as np
import time

path='e:/flower/'

#将所有的图片resize成100*100
w=100
h=100
c=3


#读取图片
def read_img(path):
 cate=[path+x for x in os.listdir(path) if os.path.isdir(path+x)]
 imgs=[]
 labels=[]
 for idx,folder in enumerate(cate):
  for im in glob.glob(folder+'/*.jpg'):
   print('reading the images:%s'%(im))
   img=io.imread(im)
   img=transform.resize(img,(w,h))
   imgs.append(img)
   labels.append(idx)
 return np.asarray(imgs,np.float32),np.asarray(labels,np.int32)
data,label=read_img(path)


#打乱顺序
num_example=data.shape[0]
arr=np.arange(num_example)
np.random.shuffle(arr)
data=data[arr]
label=label[arr]


#将所有数据分为训练集和验证集
ratio=0.8
s=np.int(num_example*ratio)
x_train=data[:s]
y_train=label[:s]
x_val=data[s:]
y_val=label[s:]

#-----------------构建网络----------------------
#占位符
x=tf.placeholder(tf.float32,shape=[None,w,h,c],name='x')
y_=tf.placeholder(tf.int32,shape=[None,],name='y_')

#第一个卷积层(100——>50)
conv1=tf.layers.conv2d(
  inputs=x,
  filters=32,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool1=tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

#第二个卷积层(50->25)
conv2=tf.layers.conv2d(
  inputs=pool1,
  filters=64,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool2=tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

#第三个卷积层(25->12)
conv3=tf.layers.conv2d(
  inputs=pool2,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool3=tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)

#第四个卷积层(12->6)
conv4=tf.layers.conv2d(
  inputs=pool3,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool4=tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=2)

re1 = tf.reshape(pool4, [-1, 6 * 6 * 128])

#全连接层
dense1 = tf.layers.dense(inputs=re1, 
      units=1024, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
dense2= tf.layers.dense(inputs=dense1, 
      units=512, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
logits= tf.layers.dense(inputs=dense2, 
      units=5, 
      activation=None,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
#---------------------------网络结束---------------------------

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) 
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


#定义一个函数,按批次取数据
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):
 assert len(inputs) == len(targets)
 if shuffle:
  indices = np.arange(len(inputs))
  np.random.shuffle(indices)
 for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
  if shuffle:
   excerpt = indices[start_idx:start_idx + batch_size]
  else:
   excerpt = slice(start_idx, start_idx + batch_size)
  yield inputs[excerpt], targets[excerpt]


#训练和测试数据,可将n_epoch设置更大一些

n_epoch=10
batch_size=64
sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
for epoch in range(n_epoch):
 start_time = time.time()
 
 #training
 train_loss, train_acc, n_batch = 0, 0, 0
 for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):
  _,err,ac=sess.run([train_op,loss,acc], feed_dict={x: x_train_a, y_: y_train_a})
  train_loss += err; train_acc += ac; n_batch += 1
 print(" train loss: %f" % (train_loss/ n_batch))
 print(" train acc: %f" % (train_acc/ n_batch))
 
 #validation
 val_loss, val_acc, n_batch = 0, 0, 0
 for x_val_a, y_val_a in minibatches(x_val, y_val, batch_size, shuffle=False):
  err, ac = sess.run([loss,acc], feed_dict={x: x_val_a, y_: y_val_a})
  val_loss += err; val_acc += ac; n_batch += 1
 print(" validation loss: %f" % (val_loss/ n_batch))
 print(" validation acc: %f" % (val_acc/ n_batch))

sess.close()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用minidom读写xml的方法
Jun 03 Python
Python做简单的字符串匹配详解
Mar 21 Python
Python 中urls.py:URL dispatcher(路由配置文件)详解
Mar 24 Python
numpy添加新的维度:newaxis的方法
Aug 02 Python
详解python运行三种方式
May 13 Python
Python3之手动创建迭代器的实例代码
May 22 Python
python selenium登录豆瓣网过程解析
Aug 10 Python
python调用matplotlib模块绘制柱状图
Oct 18 Python
Python django搭建layui提交表单,表格,图标的实例
Nov 18 Python
Python unittest工作原理和使用过程解析
Feb 24 Python
python 读txt文件,按‘,’分割每行数据操作
Jul 05 Python
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
May 28 Python
tensorflow学习笔记之mnist的卷积神经网络实例
Apr 15 #Python
tensorflow学习笔记之简单的神经网络训练和测试
Apr 15 #Python
Pytorch入门之mnist分类实例
Apr 14 #Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
You might like
深入浅析PHP7.0新特征(五大新特征)
2015/10/29 PHP
twig模板常用语句实例小结
2016/02/04 PHP
IIS 7.5 asp Session超时时间设置方法
2017/04/17 PHP
PHP生成指定范围内的N个不重复的随机数
2019/03/18 PHP
使用Modello编写JavaScript类
2006/12/22 Javascript
用JavaScrpt实现文件夹简单轻松加密的实现方法图文
2008/09/08 Javascript
获得所有表单值的JQuery实现代码[IE暂不支持]
2012/05/24 Javascript
js事件冒泡实例分享(已测试)
2013/04/23 Javascript
JavaScript将页面表格导出为Excel的具体实现
2013/12/27 Javascript
jquery实现在页面加载完毕后获取图片高度或宽度
2014/06/16 Javascript
javascript操作表格排序实例分析
2015/05/06 Javascript
异步JS框架的作用以及实现方法
2015/10/29 Javascript
BootStrap入门教程(二)之固定的内置样式
2016/09/19 Javascript
微信小程序 教程之WXML
2016/10/18 Javascript
微信小程序 仿美团分类菜单 swiper分类菜单
2017/04/12 Javascript
使用Bootstrap打造特色进度条效果
2017/05/02 Javascript
vue的基本用法与常见指令
2017/08/15 Javascript
基于JavaScript实现前端数据多条件筛选功能
2020/08/19 Javascript
微信小程序6位或多位验证码密码输入框功能的实现代码
2018/05/29 Javascript
JS代码屏蔽F12,右键,粘贴,复制,剪切,选中,操作实例
2019/09/17 Javascript
layui-table获得当前行的上/下一行数据的例子
2019/09/24 Javascript
Python 异常处理实例详解
2014/03/12 Python
Python设计模式之单例模式实例
2014/04/26 Python
使用Python脚本在Linux下实现部分Bash Shell的教程
2015/04/17 Python
深入讲解Python编程中的字符串
2015/10/14 Python
python matplotlib绘图,修改坐标轴刻度为文字的实例
2018/05/25 Python
纯css3实现的竖形无限级导航
2014/12/10 HTML / CSS
瑞贝卡·明可弗包包官网:Rebecca Minkoff
2016/07/21 全球购物
XML文档定义有几种形式?它们之间有何本质区别?解析XML文档有哪几种方式?
2016/01/12 面试题
成人继续教育实施方案
2014/03/01 职场文书
领导干部廉政承诺书
2014/03/27 职场文书
单位一把手群众路线四风问题整改措施
2014/09/25 职场文书
圣诞节开幕词
2015/01/29 职场文书
2016年优秀共产党员先进事迹材料
2016/02/29 职场文书
咖啡厅里的创业计划书
2019/08/21 职场文书
Java日常练习题,每天进步一点点(38)
2021/07/26 Java/Android