tensorflow 1.0用CNN进行图像分类


Posted in Python onApril 15, 2018

tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化。

任务:花卉分类

版本:tensorflow 1.0

数据:flower-photos

花总共有五类,分别放在5个文件夹下。

闲话不多说,直接上代码,希望大家能看懂:)

复制代码

# -*- coding: utf-8 -*-

from skimage import io,transform
import glob
import os
import tensorflow as tf
import numpy as np
import time

path='e:/flower/'

#将所有的图片resize成100*100
w=100
h=100
c=3


#读取图片
def read_img(path):
 cate=[path+x for x in os.listdir(path) if os.path.isdir(path+x)]
 imgs=[]
 labels=[]
 for idx,folder in enumerate(cate):
  for im in glob.glob(folder+'/*.jpg'):
   print('reading the images:%s'%(im))
   img=io.imread(im)
   img=transform.resize(img,(w,h))
   imgs.append(img)
   labels.append(idx)
 return np.asarray(imgs,np.float32),np.asarray(labels,np.int32)
data,label=read_img(path)


#打乱顺序
num_example=data.shape[0]
arr=np.arange(num_example)
np.random.shuffle(arr)
data=data[arr]
label=label[arr]


#将所有数据分为训练集和验证集
ratio=0.8
s=np.int(num_example*ratio)
x_train=data[:s]
y_train=label[:s]
x_val=data[s:]
y_val=label[s:]

#-----------------构建网络----------------------
#占位符
x=tf.placeholder(tf.float32,shape=[None,w,h,c],name='x')
y_=tf.placeholder(tf.int32,shape=[None,],name='y_')

#第一个卷积层(100——>50)
conv1=tf.layers.conv2d(
  inputs=x,
  filters=32,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool1=tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

#第二个卷积层(50->25)
conv2=tf.layers.conv2d(
  inputs=pool1,
  filters=64,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool2=tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

#第三个卷积层(25->12)
conv3=tf.layers.conv2d(
  inputs=pool2,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool3=tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)

#第四个卷积层(12->6)
conv4=tf.layers.conv2d(
  inputs=pool3,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool4=tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=2)

re1 = tf.reshape(pool4, [-1, 6 * 6 * 128])

#全连接层
dense1 = tf.layers.dense(inputs=re1, 
      units=1024, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
dense2= tf.layers.dense(inputs=dense1, 
      units=512, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
logits= tf.layers.dense(inputs=dense2, 
      units=5, 
      activation=None,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
#---------------------------网络结束---------------------------

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) 
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


#定义一个函数,按批次取数据
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):
 assert len(inputs) == len(targets)
 if shuffle:
  indices = np.arange(len(inputs))
  np.random.shuffle(indices)
 for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
  if shuffle:
   excerpt = indices[start_idx:start_idx + batch_size]
  else:
   excerpt = slice(start_idx, start_idx + batch_size)
  yield inputs[excerpt], targets[excerpt]


#训练和测试数据,可将n_epoch设置更大一些

n_epoch=10
batch_size=64
sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
for epoch in range(n_epoch):
 start_time = time.time()
 
 #training
 train_loss, train_acc, n_batch = 0, 0, 0
 for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):
  _,err,ac=sess.run([train_op,loss,acc], feed_dict={x: x_train_a, y_: y_train_a})
  train_loss += err; train_acc += ac; n_batch += 1
 print(" train loss: %f" % (train_loss/ n_batch))
 print(" train acc: %f" % (train_acc/ n_batch))
 
 #validation
 val_loss, val_acc, n_batch = 0, 0, 0
 for x_val_a, y_val_a in minibatches(x_val, y_val, batch_size, shuffle=False):
  err, ac = sess.run([loss,acc], feed_dict={x: x_val_a, y_: y_val_a})
  val_loss += err; val_acc += ac; n_batch += 1
 print(" validation loss: %f" % (val_loss/ n_batch))
 print(" validation acc: %f" % (val_acc/ n_batch))

sess.close()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中文件遍历的两种方法
Jun 16 Python
Python中实现结构相似的函数调用方法
Mar 10 Python
python执行外部程序的常用方法小结
Mar 21 Python
Python 用Redis简单实现分布式爬虫的方法
Nov 23 Python
python中将正则过滤的内容输出写入到文件中的实例
Oct 21 Python
pycharm在调试python时执行其他语句的方法
Nov 29 Python
python requests.post带head和body的实例
Jan 02 Python
python+selenium 定位到元素,无法点击的解决方法
Jan 30 Python
python安装gdal的两种方法
Oct 29 Python
Python 依赖库太多了该如何管理
Nov 08 Python
Python+Redis实现布隆过滤器
Dec 08 Python
python3.x中安装web.py步骤方法
Jun 23 Python
tensorflow学习笔记之mnist的卷积神经网络实例
Apr 15 #Python
tensorflow学习笔记之简单的神经网络训练和测试
Apr 15 #Python
Pytorch入门之mnist分类实例
Apr 14 #Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
You might like
2014最热门的24个php类库汇总
2014/12/18 PHP
thinkPHP2.1自定义标签库的导入方法详解
2016/07/20 PHP
Laravel6.0.4中将添加计划任务事件的方法步骤
2019/10/15 PHP
JavaScript延迟加载
2021/03/09 Javascript
Prototype 学习 工具函数学习($方法)
2009/07/12 Javascript
php上传图片并给图片打上透明水印的代码
2010/06/07 Javascript
jQuery队列控制方法详解queue()/dequeue()/clearQueue()
2010/12/02 Javascript
jquery实现保存已选用户
2014/07/21 Javascript
轻松创建nodejs服务器(2):nodejs服务器的构成分析
2014/12/18 NodeJs
js去除浏览器默认底图的方法
2015/06/08 Javascript
JavaScript中数组的合并以及排序实现示例
2015/10/24 Javascript
AngularJS ng-blur 指令详解及简单实例
2016/07/30 Javascript
功能强大的Bootstrap效果展示(二)
2016/08/03 Javascript
jquery 点击元素后,滚动条滚动至该元素位置的方法
2016/08/05 Javascript
Node.js数据库操作之查询MySQL数据库(二)
2017/03/04 Javascript
带你快速理解javascript中的事件模型
2017/08/14 Javascript
vue通过cookie获取用户登录信息的思路详解
2018/10/30 Javascript
微信小程序调用wx.getImageInfo遇到的坑解决
2020/05/31 Javascript
vue-quill-editor的使用及个性化定制操作
2020/08/04 Javascript
[48:45]Ti4 循环赛第二日 NEWBEE vs EG
2014/07/11 DOTA
编写同时兼容Python2.x与Python3.x版本的代码的几个示例
2015/03/30 Python
Python装饰器原理与用法分析
2018/04/30 Python
用python3教你任意Html主内容提取功能
2018/11/05 Python
详解python实现小波变换的一个简单例子
2019/07/18 Python
Python 字符串处理特殊空格\xc2\xa0\t\n Non-breaking space
2020/02/23 Python
css3实现针线缝合效果(图解步骤)
2013/02/04 HTML / CSS
英国50岁以上人群的交友网站:Ourtime
2018/03/28 全球购物
电影T恤、80年代T恤和80年代服装:TV Store Online
2020/01/05 全球购物
EJB的几种类型
2012/08/15 面试题
研究生自荐信
2013/10/09 职场文书
平面设计专业求职信
2014/08/09 职场文书
中学生旷课检讨书2篇
2014/10/09 职场文书
2014年社区妇联工作总结
2014/12/02 职场文书
2015年上半年党建工作总结
2015/03/30 职场文书
施工安全保证书
2015/05/09 职场文书
2015年保险公司个人工作总结
2015/05/22 职场文书