tensorflow 1.0用CNN进行图像分类


Posted in Python onApril 15, 2018

tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化。

任务:花卉分类

版本:tensorflow 1.0

数据:flower-photos

花总共有五类,分别放在5个文件夹下。

闲话不多说,直接上代码,希望大家能看懂:)

复制代码

# -*- coding: utf-8 -*-

from skimage import io,transform
import glob
import os
import tensorflow as tf
import numpy as np
import time

path='e:/flower/'

#将所有的图片resize成100*100
w=100
h=100
c=3


#读取图片
def read_img(path):
 cate=[path+x for x in os.listdir(path) if os.path.isdir(path+x)]
 imgs=[]
 labels=[]
 for idx,folder in enumerate(cate):
  for im in glob.glob(folder+'/*.jpg'):
   print('reading the images:%s'%(im))
   img=io.imread(im)
   img=transform.resize(img,(w,h))
   imgs.append(img)
   labels.append(idx)
 return np.asarray(imgs,np.float32),np.asarray(labels,np.int32)
data,label=read_img(path)


#打乱顺序
num_example=data.shape[0]
arr=np.arange(num_example)
np.random.shuffle(arr)
data=data[arr]
label=label[arr]


#将所有数据分为训练集和验证集
ratio=0.8
s=np.int(num_example*ratio)
x_train=data[:s]
y_train=label[:s]
x_val=data[s:]
y_val=label[s:]

#-----------------构建网络----------------------
#占位符
x=tf.placeholder(tf.float32,shape=[None,w,h,c],name='x')
y_=tf.placeholder(tf.int32,shape=[None,],name='y_')

#第一个卷积层(100——>50)
conv1=tf.layers.conv2d(
  inputs=x,
  filters=32,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool1=tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

#第二个卷积层(50->25)
conv2=tf.layers.conv2d(
  inputs=pool1,
  filters=64,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool2=tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

#第三个卷积层(25->12)
conv3=tf.layers.conv2d(
  inputs=pool2,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool3=tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)

#第四个卷积层(12->6)
conv4=tf.layers.conv2d(
  inputs=pool3,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool4=tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=2)

re1 = tf.reshape(pool4, [-1, 6 * 6 * 128])

#全连接层
dense1 = tf.layers.dense(inputs=re1, 
      units=1024, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
dense2= tf.layers.dense(inputs=dense1, 
      units=512, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
logits= tf.layers.dense(inputs=dense2, 
      units=5, 
      activation=None,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
#---------------------------网络结束---------------------------

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) 
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


#定义一个函数,按批次取数据
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):
 assert len(inputs) == len(targets)
 if shuffle:
  indices = np.arange(len(inputs))
  np.random.shuffle(indices)
 for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
  if shuffle:
   excerpt = indices[start_idx:start_idx + batch_size]
  else:
   excerpt = slice(start_idx, start_idx + batch_size)
  yield inputs[excerpt], targets[excerpt]


#训练和测试数据,可将n_epoch设置更大一些

n_epoch=10
batch_size=64
sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
for epoch in range(n_epoch):
 start_time = time.time()
 
 #training
 train_loss, train_acc, n_batch = 0, 0, 0
 for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):
  _,err,ac=sess.run([train_op,loss,acc], feed_dict={x: x_train_a, y_: y_train_a})
  train_loss += err; train_acc += ac; n_batch += 1
 print(" train loss: %f" % (train_loss/ n_batch))
 print(" train acc: %f" % (train_acc/ n_batch))
 
 #validation
 val_loss, val_acc, n_batch = 0, 0, 0
 for x_val_a, y_val_a in minibatches(x_val, y_val, batch_size, shuffle=False):
  err, ac = sess.run([loss,acc], feed_dict={x: x_val_a, y_: y_val_a})
  val_loss += err; val_acc += ac; n_batch += 1
 print(" validation loss: %f" % (val_loss/ n_batch))
 print(" validation acc: %f" % (val_acc/ n_batch))

sess.close()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python脚本将Bing的每日图片作为桌面的教程
May 04 Python
Python找出9个连续的空闲端口
Feb 01 Python
使用简单工厂模式来进行Python的设计模式编程
Mar 01 Python
python使用logging模块发送邮件代码示例
Jan 18 Python
python修改list中所有元素类型的三种方法
Apr 09 Python
解决python删除文件的权限错误问题
Apr 24 Python
python连接mongodb密码认证实例
Oct 16 Python
检测python爬虫时是否代理ip伪装成功的方法
Jul 12 Python
pycharm中显示CSS提示的知识点总结
Jul 29 Python
python opencv实现证件照换底功能
Aug 19 Python
Python发送邮件的实例代码讲解
Oct 16 Python
使用 Python 读取电子表格中的数据实例详解
Apr 17 Python
tensorflow学习笔记之mnist的卷积神经网络实例
Apr 15 #Python
tensorflow学习笔记之简单的神经网络训练和测试
Apr 15 #Python
Pytorch入门之mnist分类实例
Apr 14 #Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
You might like
php遍历目录viewDir函数
2009/12/15 PHP
编译php 5.2.14+fpm+memcached(具体操作详解)
2013/06/18 PHP
JS获取屏幕,浏览器窗口大小,网页高度宽度(实现代码)
2013/12/17 Javascript
js判断ie版本号的简单实现代码
2014/03/05 Javascript
js判断浏览器类型为ie6时不执行
2014/06/15 Javascript
jQuery处理json数据返回数组和输出的方法
2015/03/11 Javascript
原生js页面滚动延迟加载图片
2015/12/20 Javascript
浅析JS异步加载进度条
2016/05/05 Javascript
使用jQuery.Qrcode插件在客户端动态生成二维码并添加自定义Logo
2016/09/01 Javascript
最原始的jQuery注册验证方式
2016/10/11 Javascript
jQuery实现一个简单的轮播图
2017/02/19 Javascript
移动web开发之touch事件实例详解
2018/01/17 Javascript
JavaScript 正则命名分组【推荐】
2018/06/07 Javascript
详解微信小程序开发之formId使用(模板消息)
2019/08/27 Javascript
微信头像地址失效踩坑记附带解决方案
2019/09/23 Javascript
vue+render+jsx实现可编辑动态多级表头table的实例代码
2020/04/01 Javascript
基于JavaScript实现简单抽奖功能代码实例
2020/10/20 Javascript
初步解析Python中的yield函数的用法
2015/04/03 Python
在Django中管理Users和Permissions以及Groups的方法
2015/07/23 Python
详解Python的Flask框架中生成SECRET_KEY密钥的方法
2016/06/07 Python
python提取图像的名字*.jpg到txt文本的方法
2018/05/10 Python
python+openCV利用摄像头实现人员活动检测
2019/06/22 Python
python MultipartEncoder传输zip文件实例
2020/04/07 Python
Python数据正态性检验实现过程
2020/04/18 Python
使用CSS3的box-sizing属性解决div宽高被内边距撑开的问题
2016/06/28 HTML / CSS
美国复古街头服饰精品店:Need Supply Co.
2017/02/22 全球购物
GWebs公司笔试题
2012/05/04 面试题
先进事迹演讲稿
2014/09/01 职场文书
2014老师三严三实对照检查材料思想汇报
2014/09/18 职场文书
三提三创主题教育活动查摆整改措施
2014/10/25 职场文书
大学生求职自荐信
2015/03/24 职场文书
企业管理不到位检讨书
2019/06/27 职场文书
vue-cli4.5.x快速搭建项目
2021/05/30 Vue.js
SpringBoot集成Redis的思路详解
2021/10/16 Redis
关于mysql中时间日期类型和字符串类型的选择
2021/11/27 MySQL
SpringBoot 集成短信和邮件 以阿里云短信服务为例
2022/04/22 Java/Android