tensorflow 1.0用CNN进行图像分类


Posted in Python onApril 15, 2018

tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化。

任务:花卉分类

版本:tensorflow 1.0

数据:flower-photos

花总共有五类,分别放在5个文件夹下。

闲话不多说,直接上代码,希望大家能看懂:)

复制代码

# -*- coding: utf-8 -*-

from skimage import io,transform
import glob
import os
import tensorflow as tf
import numpy as np
import time

path='e:/flower/'

#将所有的图片resize成100*100
w=100
h=100
c=3


#读取图片
def read_img(path):
 cate=[path+x for x in os.listdir(path) if os.path.isdir(path+x)]
 imgs=[]
 labels=[]
 for idx,folder in enumerate(cate):
  for im in glob.glob(folder+'/*.jpg'):
   print('reading the images:%s'%(im))
   img=io.imread(im)
   img=transform.resize(img,(w,h))
   imgs.append(img)
   labels.append(idx)
 return np.asarray(imgs,np.float32),np.asarray(labels,np.int32)
data,label=read_img(path)


#打乱顺序
num_example=data.shape[0]
arr=np.arange(num_example)
np.random.shuffle(arr)
data=data[arr]
label=label[arr]


#将所有数据分为训练集和验证集
ratio=0.8
s=np.int(num_example*ratio)
x_train=data[:s]
y_train=label[:s]
x_val=data[s:]
y_val=label[s:]

#-----------------构建网络----------------------
#占位符
x=tf.placeholder(tf.float32,shape=[None,w,h,c],name='x')
y_=tf.placeholder(tf.int32,shape=[None,],name='y_')

#第一个卷积层(100——>50)
conv1=tf.layers.conv2d(
  inputs=x,
  filters=32,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool1=tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

#第二个卷积层(50->25)
conv2=tf.layers.conv2d(
  inputs=pool1,
  filters=64,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool2=tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

#第三个卷积层(25->12)
conv3=tf.layers.conv2d(
  inputs=pool2,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool3=tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)

#第四个卷积层(12->6)
conv4=tf.layers.conv2d(
  inputs=pool3,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool4=tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=2)

re1 = tf.reshape(pool4, [-1, 6 * 6 * 128])

#全连接层
dense1 = tf.layers.dense(inputs=re1, 
      units=1024, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
dense2= tf.layers.dense(inputs=dense1, 
      units=512, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
logits= tf.layers.dense(inputs=dense2, 
      units=5, 
      activation=None,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
#---------------------------网络结束---------------------------

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) 
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


#定义一个函数,按批次取数据
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):
 assert len(inputs) == len(targets)
 if shuffle:
  indices = np.arange(len(inputs))
  np.random.shuffle(indices)
 for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
  if shuffle:
   excerpt = indices[start_idx:start_idx + batch_size]
  else:
   excerpt = slice(start_idx, start_idx + batch_size)
  yield inputs[excerpt], targets[excerpt]


#训练和测试数据,可将n_epoch设置更大一些

n_epoch=10
batch_size=64
sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
for epoch in range(n_epoch):
 start_time = time.time()
 
 #training
 train_loss, train_acc, n_batch = 0, 0, 0
 for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):
  _,err,ac=sess.run([train_op,loss,acc], feed_dict={x: x_train_a, y_: y_train_a})
  train_loss += err; train_acc += ac; n_batch += 1
 print(" train loss: %f" % (train_loss/ n_batch))
 print(" train acc: %f" % (train_acc/ n_batch))
 
 #validation
 val_loss, val_acc, n_batch = 0, 0, 0
 for x_val_a, y_val_a in minibatches(x_val, y_val, batch_size, shuffle=False):
  err, ac = sess.run([loss,acc], feed_dict={x: x_val_a, y_: y_val_a})
  val_loss += err; val_acc += ac; n_batch += 1
 print(" validation loss: %f" % (val_loss/ n_batch))
 print(" validation acc: %f" % (val_acc/ n_batch))

sess.close()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python多线程http下载实现示例
Dec 30 Python
Python中针对函数处理的特殊方法
Mar 06 Python
python中assert用法实例分析
Apr 30 Python
在Django的session中使用User对象的方法
Jul 23 Python
利用Python暴力破解zip文件口令的方法详解
Dec 21 Python
python监控键盘输入实例代码
Feb 09 Python
Python 查看list中是否含有某元素的方法
Jun 27 Python
使用python采集脚本之家电子书资源并自动下载到本地的实例脚本
Oct 23 Python
Python3.7 读取 mp3 音频文件生成波形图效果
Nov 05 Python
python批量处理txt文件的实例代码
Jan 13 Python
Python基于read(size)方法读取超大文件
Mar 12 Python
2021年最新用于图像处理的Python库总结
Jun 15 Python
tensorflow学习笔记之mnist的卷积神经网络实例
Apr 15 #Python
tensorflow学习笔记之简单的神经网络训练和测试
Apr 15 #Python
Pytorch入门之mnist分类实例
Apr 14 #Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
You might like
提高PHP性能的编码技巧以及性能优化详细解析
2013/08/24 PHP
10 个经典PHP函数
2013/10/17 PHP
Zend Framework框架路由机制代码分析
2016/03/22 PHP
PHP unlink与rmdir删除目录及目录下所有文件实例代码
2018/02/07 PHP
PHP7 echo和print语句实例用法
2019/02/15 PHP
PHP For循环字母A-Z当超过26个字母时输出AA,AB,AC
2020/02/16 PHP
js根据给定的日期计算当月有多少天实现思路及代码
2013/02/25 Javascript
js/jquery解析json和数组格式的方法详解
2014/01/09 Javascript
jQuery页面刷新(局部、全部)问题分析
2016/01/09 Javascript
JS hashMap实例详解
2016/05/26 Javascript
Vue.js进行查询操作的实例详解
2017/08/25 Javascript
jQuery实现的表格前端排序功能示例
2017/09/18 jQuery
基于vue的换肤功能的示例代码
2017/10/10 Javascript
jQuery实现页码跳转式动态数据分页
2017/12/31 jQuery
微信小程序页面间值传递的两种方法
2018/11/26 Javascript
vue生命周期的探索
2019/04/03 Javascript
详解vue-cli 脚手架 安装
2019/04/16 Javascript
Angular封装表单控件及思想总结
2019/12/11 Javascript
js实现石头剪刀布游戏
2020/10/11 Javascript
VUE前端从后台请求过来的数据进行转换数据结构操作
2020/11/11 Javascript
JavaScript用document.write()输出换行的示例代码
2020/11/26 Javascript
Python发送form-data请求及拼接form-data内容的方法
2016/03/05 Python
Python实现HTTP协议下的文件下载方法总结
2016/04/20 Python
python opencv之SIFT算法示例
2018/02/24 Python
Python基于滑动平均思想实现缺失数据填充的方法
2019/02/21 Python
详解Python Matplot中文显示完美解决方案
2019/03/07 Python
Pytorch之Variable的用法
2019/12/31 Python
Python如何执行系统命令
2020/09/23 Python
用纯css3和html制作泡沫对话框实现代码
2013/03/21 HTML / CSS
HTML5中的Scoped属性使用实例
2014/04/23 HTML / CSS
建筑设计所实习生自我鉴定
2013/09/25 职场文书
学生会宣传部部长竞选演讲稿
2014/04/25 职场文书
四群教育工作总结
2015/08/10 职场文书
2019个人工作态度自我评价
2019/04/24 职场文书
深入理解redis中multi与pipeline
2021/06/02 Redis
详解Python中下划线的5种含义
2021/07/15 Python