tensorflow 1.0用CNN进行图像分类


Posted in Python onApril 15, 2018

tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化。

任务:花卉分类

版本:tensorflow 1.0

数据:flower-photos

花总共有五类,分别放在5个文件夹下。

闲话不多说,直接上代码,希望大家能看懂:)

复制代码

# -*- coding: utf-8 -*-

from skimage import io,transform
import glob
import os
import tensorflow as tf
import numpy as np
import time

path='e:/flower/'

#将所有的图片resize成100*100
w=100
h=100
c=3


#读取图片
def read_img(path):
 cate=[path+x for x in os.listdir(path) if os.path.isdir(path+x)]
 imgs=[]
 labels=[]
 for idx,folder in enumerate(cate):
  for im in glob.glob(folder+'/*.jpg'):
   print('reading the images:%s'%(im))
   img=io.imread(im)
   img=transform.resize(img,(w,h))
   imgs.append(img)
   labels.append(idx)
 return np.asarray(imgs,np.float32),np.asarray(labels,np.int32)
data,label=read_img(path)


#打乱顺序
num_example=data.shape[0]
arr=np.arange(num_example)
np.random.shuffle(arr)
data=data[arr]
label=label[arr]


#将所有数据分为训练集和验证集
ratio=0.8
s=np.int(num_example*ratio)
x_train=data[:s]
y_train=label[:s]
x_val=data[s:]
y_val=label[s:]

#-----------------构建网络----------------------
#占位符
x=tf.placeholder(tf.float32,shape=[None,w,h,c],name='x')
y_=tf.placeholder(tf.int32,shape=[None,],name='y_')

#第一个卷积层(100——>50)
conv1=tf.layers.conv2d(
  inputs=x,
  filters=32,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool1=tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

#第二个卷积层(50->25)
conv2=tf.layers.conv2d(
  inputs=pool1,
  filters=64,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool2=tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

#第三个卷积层(25->12)
conv3=tf.layers.conv2d(
  inputs=pool2,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool3=tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)

#第四个卷积层(12->6)
conv4=tf.layers.conv2d(
  inputs=pool3,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool4=tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=2)

re1 = tf.reshape(pool4, [-1, 6 * 6 * 128])

#全连接层
dense1 = tf.layers.dense(inputs=re1, 
      units=1024, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
dense2= tf.layers.dense(inputs=dense1, 
      units=512, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
logits= tf.layers.dense(inputs=dense2, 
      units=5, 
      activation=None,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
#---------------------------网络结束---------------------------

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) 
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


#定义一个函数,按批次取数据
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):
 assert len(inputs) == len(targets)
 if shuffle:
  indices = np.arange(len(inputs))
  np.random.shuffle(indices)
 for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
  if shuffle:
   excerpt = indices[start_idx:start_idx + batch_size]
  else:
   excerpt = slice(start_idx, start_idx + batch_size)
  yield inputs[excerpt], targets[excerpt]


#训练和测试数据,可将n_epoch设置更大一些

n_epoch=10
batch_size=64
sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
for epoch in range(n_epoch):
 start_time = time.time()
 
 #training
 train_loss, train_acc, n_batch = 0, 0, 0
 for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):
  _,err,ac=sess.run([train_op,loss,acc], feed_dict={x: x_train_a, y_: y_train_a})
  train_loss += err; train_acc += ac; n_batch += 1
 print(" train loss: %f" % (train_loss/ n_batch))
 print(" train acc: %f" % (train_acc/ n_batch))
 
 #validation
 val_loss, val_acc, n_batch = 0, 0, 0
 for x_val_a, y_val_a in minibatches(x_val, y_val, batch_size, shuffle=False):
  err, ac = sess.run([loss,acc], feed_dict={x: x_val_a, y_: y_val_a})
  val_loss += err; val_acc += ac; n_batch += 1
 print(" validation loss: %f" % (val_loss/ n_batch))
 print(" validation acc: %f" % (val_acc/ n_batch))

sess.close()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python进阶教程之函数对象(函数也是对象)
Aug 30 Python
Python实现更改图片尺寸大小的方法(基于Pillow包)
Sep 19 Python
深入分析python数据挖掘 Json结构分析
Apr 21 Python
python生成lmdb格式的文件实例
Nov 08 Python
PyCharm 配置远程python解释器和在本地修改服务器代码
Jul 23 Python
对django layer弹窗组件的使用详解
Aug 31 Python
python 线性回归分析模型检验标准--拟合优度详解
Feb 24 Python
python针对Oracle常见查询操作实例分析
Apr 30 Python
python输出数学符号实例
May 11 Python
OpenCV Python实现图像指定区域裁剪
Mar 12 Python
解决Python 写文件报错TypeError的问题
Oct 23 Python
python+opencv实现车道线检测
Feb 19 Python
tensorflow学习笔记之mnist的卷积神经网络实例
Apr 15 #Python
tensorflow学习笔记之简单的神经网络训练和测试
Apr 15 #Python
Pytorch入门之mnist分类实例
Apr 14 #Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
You might like
PHP生成唯一订单号的方法汇总
2015/04/16 PHP
PHP实现获取文件后缀名的几种常用方法
2015/08/08 PHP
PHP实现通过二维数组键值获取一维键名操作示例
2019/10/11 PHP
laravel接管Dingo-api和默认的错误处理方式
2019/10/25 PHP
jQuery html() in Firefox (uses .innerHTML) ignores DOM changes
2010/03/05 Javascript
jQuery EasyUI中对表格进行编辑的实现代码
2010/06/10 Javascript
JS 控制小数位数的实现代码
2011/08/02 Javascript
Jquery EasyUI中弹出确认对话框以及加载效果示例代码
2014/02/13 Javascript
浅谈jQuery 选择器和dom操作
2016/06/07 Javascript
纯JS实现可拖拽表单的简单实例
2016/09/02 Javascript
bootstrap 下拉多选框进行多选传值问题代码分析
2017/02/14 Javascript
JS实现禁止用户使用Ctrl+鼠标滚轮缩放网页的方法
2017/04/28 Javascript
详解封装基础的angular4的request请求方法
2018/06/05 Javascript
JavaScript原型对象、构造函数和实例对象功能与用法详解
2018/08/04 Javascript
js异步上传多张图片插件的使用方法
2018/10/22 Javascript
vue2.0移动端滑动事件vue-touch的实例代码
2018/11/27 Javascript
vue elementUI使用tabs与导航栏联动
2019/06/21 Javascript
微信小程序基于Taro的分享图片功能实践详解
2019/07/12 Javascript
JavaScript运动原理基础知识详解
2020/04/02 Javascript
带你使用webpack快速构建web项目的方法
2020/11/12 Javascript
python转换摩斯密码示例
2014/02/16 Python
Python判断文件和文件夹是否存在的方法
2015/05/21 Python
浅析Python中的多条件排序实现
2016/06/07 Python
Python中时间datetime的处理与转换用法总结
2019/02/18 Python
python tkinter窗口最大化的实现
2019/07/15 Python
pandas DataFrame行或列的删除方法的实现示例
2019/08/02 Python
python中数字是否为可变类型
2020/07/08 Python
Python数据模型与Python对象模型的相关总结
2021/01/26 Python
HTML5 Notification(桌面提醒)功能使用实例
2014/03/17 HTML / CSS
海洋天堂观后感
2015/06/05 职场文书
数学备课组工作总结
2015/08/12 职场文书
运动会广播稿20字
2015/08/19 职场文书
公安干警正风肃纪心得体会
2016/01/15 职场文书
优秀新员工事迹材料
2019/05/13 职场文书
Golang之sync.Pool使用详解
2021/05/06 Golang
JavaScript模拟实现网易云轮播效果
2022/04/04 Javascript