tensorflow 1.0用CNN进行图像分类


Posted in Python onApril 15, 2018

tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化。

任务:花卉分类

版本:tensorflow 1.0

数据:flower-photos

花总共有五类,分别放在5个文件夹下。

闲话不多说,直接上代码,希望大家能看懂:)

复制代码

# -*- coding: utf-8 -*-

from skimage import io,transform
import glob
import os
import tensorflow as tf
import numpy as np
import time

path='e:/flower/'

#将所有的图片resize成100*100
w=100
h=100
c=3


#读取图片
def read_img(path):
 cate=[path+x for x in os.listdir(path) if os.path.isdir(path+x)]
 imgs=[]
 labels=[]
 for idx,folder in enumerate(cate):
  for im in glob.glob(folder+'/*.jpg'):
   print('reading the images:%s'%(im))
   img=io.imread(im)
   img=transform.resize(img,(w,h))
   imgs.append(img)
   labels.append(idx)
 return np.asarray(imgs,np.float32),np.asarray(labels,np.int32)
data,label=read_img(path)


#打乱顺序
num_example=data.shape[0]
arr=np.arange(num_example)
np.random.shuffle(arr)
data=data[arr]
label=label[arr]


#将所有数据分为训练集和验证集
ratio=0.8
s=np.int(num_example*ratio)
x_train=data[:s]
y_train=label[:s]
x_val=data[s:]
y_val=label[s:]

#-----------------构建网络----------------------
#占位符
x=tf.placeholder(tf.float32,shape=[None,w,h,c],name='x')
y_=tf.placeholder(tf.int32,shape=[None,],name='y_')

#第一个卷积层(100——>50)
conv1=tf.layers.conv2d(
  inputs=x,
  filters=32,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool1=tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

#第二个卷积层(50->25)
conv2=tf.layers.conv2d(
  inputs=pool1,
  filters=64,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool2=tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

#第三个卷积层(25->12)
conv3=tf.layers.conv2d(
  inputs=pool2,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool3=tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)

#第四个卷积层(12->6)
conv4=tf.layers.conv2d(
  inputs=pool3,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool4=tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=2)

re1 = tf.reshape(pool4, [-1, 6 * 6 * 128])

#全连接层
dense1 = tf.layers.dense(inputs=re1, 
      units=1024, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
dense2= tf.layers.dense(inputs=dense1, 
      units=512, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
logits= tf.layers.dense(inputs=dense2, 
      units=5, 
      activation=None,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
#---------------------------网络结束---------------------------

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) 
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


#定义一个函数,按批次取数据
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):
 assert len(inputs) == len(targets)
 if shuffle:
  indices = np.arange(len(inputs))
  np.random.shuffle(indices)
 for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
  if shuffle:
   excerpt = indices[start_idx:start_idx + batch_size]
  else:
   excerpt = slice(start_idx, start_idx + batch_size)
  yield inputs[excerpt], targets[excerpt]


#训练和测试数据,可将n_epoch设置更大一些

n_epoch=10
batch_size=64
sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
for epoch in range(n_epoch):
 start_time = time.time()
 
 #training
 train_loss, train_acc, n_batch = 0, 0, 0
 for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):
  _,err,ac=sess.run([train_op,loss,acc], feed_dict={x: x_train_a, y_: y_train_a})
  train_loss += err; train_acc += ac; n_batch += 1
 print(" train loss: %f" % (train_loss/ n_batch))
 print(" train acc: %f" % (train_acc/ n_batch))
 
 #validation
 val_loss, val_acc, n_batch = 0, 0, 0
 for x_val_a, y_val_a in minibatches(x_val, y_val, batch_size, shuffle=False):
  err, ac = sess.run([loss,acc], feed_dict={x: x_val_a, y_: y_val_a})
  val_loss += err; val_acc += ac; n_batch += 1
 print(" validation loss: %f" % (val_loss/ n_batch))
 print(" validation acc: %f" % (val_acc/ n_batch))

sess.close()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python调用Moxa PCOMM Lite通过串口Ymodem协议实现发送文件
Aug 15 Python
Python写的一个定时重跑获取数据库数据
Dec 28 Python
Python基于回溯法子集树模板解决全排列问题示例
Sep 07 Python
python利用paramiko连接远程服务器执行命令的方法
Oct 16 Python
Python数据结构与算法之常见的分配排序法示例【桶排序与基数排序】
Dec 15 Python
python如何生成网页验证码
Jul 28 Python
python NumPy ndarray二维数组 按照行列求平均实例
Nov 26 Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 Python
python列表的逆序遍历实现
Apr 20 Python
Python pymysql模块安装并操作过程解析
Oct 13 Python
python mock测试的示例
Oct 19 Python
浅谈Python列表嵌套字典转化的问题
Apr 07 Python
tensorflow学习笔记之mnist的卷积神经网络实例
Apr 15 #Python
tensorflow学习笔记之简单的神经网络训练和测试
Apr 15 #Python
Pytorch入门之mnist分类实例
Apr 14 #Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
You might like
用PHP动态创建Flash动画
2006/10/09 PHP
用函数读出数据表内容放入二维数组
2006/10/09 PHP
PHP中redis的用法深入解析
2014/02/20 PHP
PHP 常用的header头部定义汇总
2015/06/19 PHP
Laravel如何实现适合Api的异常处理响应格式
2020/06/14 PHP
基于JQuery的cookie插件
2010/04/07 Javascript
div层的移动及性能优化
2010/11/16 Javascript
轻松创建nodejs服务器(9):实现非阻塞操作
2014/12/18 NodeJs
jquery代码实现简单的随机图片瀑布流效果
2015/04/20 Javascript
zepto中使用swipe.js制作轮播图附swipeUp,swipeDown不起效果问题
2015/08/27 Javascript
AngularJS在IE下取数据总是缓存问题的解决方法
2016/08/05 Javascript
Bootstrap基本插件学习笔记之模态对话框(16)
2016/12/08 Javascript
JS操作xml对象转换为Json对象示例
2017/03/25 Javascript
简单谈谈React中的路由系统
2017/07/25 Javascript
微信小程序-滚动消息通知的实例代码
2017/08/03 Javascript
JavaScript中关于class的调用方法
2017/11/28 Javascript
关于vue 结合原生js 解决echarts resize问题
2020/07/26 Javascript
解决ant-design-vue中menu菜单无法默认展开的问题
2020/10/31 Javascript
[09:23]国际邀请赛采访专栏:iG战队VK,Tongfu战队Cu
2013/08/05 DOTA
[01:06:43]完美世界DOTA2联赛PWL S3 PXG vs GXR 第二场 12.19
2020/12/24 DOTA
Python实现遍历数据库并获取key的值
2015/05/17 Python
python微信跳一跳系列之色块轮廓定位棋盘
2018/02/26 Python
django反向解析URL和URL命名空间的方法
2018/06/05 Python
PyCharm中Matplotlib绘图不能显示UI效果的问题解决
2020/03/12 Python
服务器端jupyter notebook映射到本地浏览器的操作
2020/04/14 Python
Python Pillow(PIL)库的用法详解
2020/09/19 Python
Python模拟登录和登录跳转的参考示例
2020/10/30 Python
日本最大美瞳直送网:Morecontact(中文)
2019/04/03 全球购物
Made in Design英国:设计家具、照明、家庭装饰和花园家具
2019/09/24 全球购物
机电职业生涯规划书范文
2014/03/08 职场文书
连带责任保证书
2014/04/29 职场文书
银行贷款收入证明
2014/10/17 职场文书
党员个人整改措施
2014/10/24 职场文书
2015年公路路政个人工作总结
2015/07/24 职场文书
tensorflow学习笔记之tfrecord文件的生成与读取
2021/03/31 Python
一次项目中Thinkphp绕过禁用函数的实战记录
2021/11/17 PHP