tensorflow 1.0用CNN进行图像分类


Posted in Python onApril 15, 2018

tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化。

任务:花卉分类

版本:tensorflow 1.0

数据:flower-photos

花总共有五类,分别放在5个文件夹下。

闲话不多说,直接上代码,希望大家能看懂:)

复制代码

# -*- coding: utf-8 -*-

from skimage import io,transform
import glob
import os
import tensorflow as tf
import numpy as np
import time

path='e:/flower/'

#将所有的图片resize成100*100
w=100
h=100
c=3


#读取图片
def read_img(path):
 cate=[path+x for x in os.listdir(path) if os.path.isdir(path+x)]
 imgs=[]
 labels=[]
 for idx,folder in enumerate(cate):
  for im in glob.glob(folder+'/*.jpg'):
   print('reading the images:%s'%(im))
   img=io.imread(im)
   img=transform.resize(img,(w,h))
   imgs.append(img)
   labels.append(idx)
 return np.asarray(imgs,np.float32),np.asarray(labels,np.int32)
data,label=read_img(path)


#打乱顺序
num_example=data.shape[0]
arr=np.arange(num_example)
np.random.shuffle(arr)
data=data[arr]
label=label[arr]


#将所有数据分为训练集和验证集
ratio=0.8
s=np.int(num_example*ratio)
x_train=data[:s]
y_train=label[:s]
x_val=data[s:]
y_val=label[s:]

#-----------------构建网络----------------------
#占位符
x=tf.placeholder(tf.float32,shape=[None,w,h,c],name='x')
y_=tf.placeholder(tf.int32,shape=[None,],name='y_')

#第一个卷积层(100——>50)
conv1=tf.layers.conv2d(
  inputs=x,
  filters=32,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool1=tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

#第二个卷积层(50->25)
conv2=tf.layers.conv2d(
  inputs=pool1,
  filters=64,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool2=tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

#第三个卷积层(25->12)
conv3=tf.layers.conv2d(
  inputs=pool2,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool3=tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)

#第四个卷积层(12->6)
conv4=tf.layers.conv2d(
  inputs=pool3,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool4=tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=2)

re1 = tf.reshape(pool4, [-1, 6 * 6 * 128])

#全连接层
dense1 = tf.layers.dense(inputs=re1, 
      units=1024, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
dense2= tf.layers.dense(inputs=dense1, 
      units=512, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
logits= tf.layers.dense(inputs=dense2, 
      units=5, 
      activation=None,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
#---------------------------网络结束---------------------------

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) 
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


#定义一个函数,按批次取数据
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):
 assert len(inputs) == len(targets)
 if shuffle:
  indices = np.arange(len(inputs))
  np.random.shuffle(indices)
 for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
  if shuffle:
   excerpt = indices[start_idx:start_idx + batch_size]
  else:
   excerpt = slice(start_idx, start_idx + batch_size)
  yield inputs[excerpt], targets[excerpt]


#训练和测试数据,可将n_epoch设置更大一些

n_epoch=10
batch_size=64
sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
for epoch in range(n_epoch):
 start_time = time.time()
 
 #training
 train_loss, train_acc, n_batch = 0, 0, 0
 for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):
  _,err,ac=sess.run([train_op,loss,acc], feed_dict={x: x_train_a, y_: y_train_a})
  train_loss += err; train_acc += ac; n_batch += 1
 print(" train loss: %f" % (train_loss/ n_batch))
 print(" train acc: %f" % (train_acc/ n_batch))
 
 #validation
 val_loss, val_acc, n_batch = 0, 0, 0
 for x_val_a, y_val_a in minibatches(x_val, y_val, batch_size, shuffle=False):
  err, ac = sess.run([loss,acc], feed_dict={x: x_val_a, y_: y_val_a})
  val_loss += err; val_acc += ac; n_batch += 1
 print(" validation loss: %f" % (val_loss/ n_batch))
 print(" validation acc: %f" % (val_acc/ n_batch))

sess.close()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python将字符串转换成数组的方法
Apr 29 Python
详解Django中的form库的使用
Jul 18 Python
Scrapy爬虫实例讲解_校花网
Oct 23 Python
python实现壁纸批量下载代码实例
Jan 25 Python
Python 实现「食行生鲜」签到领积分功能
Sep 26 Python
适合Python初学者的一些编程技巧
Feb 12 Python
logging level级别介绍
Feb 21 Python
Python Numpy中数据的常用保存与读取方法
Apr 01 Python
Python生成随机验证码代码实例解析
Jun 09 Python
Django如何批量创建Model
Sep 01 Python
为2021年的第一场雪锦上添花:用matplotlib绘制雪花和雪景
Jan 05 Python
pytorch常用数据类型所占字节数对照表一览
May 17 Python
tensorflow学习笔记之mnist的卷积神经网络实例
Apr 15 #Python
tensorflow学习笔记之简单的神经网络训练和测试
Apr 15 #Python
Pytorch入门之mnist分类实例
Apr 14 #Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
You might like
php获取mysql字段名称和其它信息的例子
2014/04/14 PHP
php实现倒计时效果
2015/12/19 PHP
CentOS下搭建PHP环境与WordPress博客程序的全流程总结
2016/05/07 PHP
thinkPHP分组后模板无法加载问题解决方法
2016/07/12 PHP
事件委托与阻止冒泡阻止其父元素事件触发
2014/09/02 Javascript
jQuery实现冻结表头的方法
2015/03/09 Javascript
JavaScript中var关键字的使用详解
2015/08/14 Javascript
angularjs ui-router中路由的二级嵌套
2017/03/10 Javascript
详解用Node.js实现Restful风格webservice
2017/09/29 Javascript
Vue2.0用户权限控制解决方案
2017/11/29 Javascript
JavaScript学习总结(一) ECMAScript、BOM、DOM(核心、浏览器对象模型与文档对象模型)
2018/01/07 Javascript
Vue.js 2.x之组件的定义和注册图文详解
2018/06/19 Javascript
深入理解js A*寻路算法原理与具体实现过程
2018/12/13 Javascript
深入理解基于vue-cli的webpack打包优化实践及探索
2019/10/14 Javascript
JQuery事件委托(适用于给动态生成的脚本元素添加事件)
2020/02/01 jQuery
[03:09]2014DOTA2国际邀请赛 赛场上的美丽风景线 中国Coser也爱DOTA2
2014/07/20 DOTA
Python求两个list的差集、交集与并集的方法
2014/11/01 Python
Python安装第三方库及常见问题处理方法汇总
2016/09/13 Python
利用 Monkey 命令操作屏幕快速滑动
2016/12/07 Python
python3.4用函数操作mysql5.7数据库
2017/06/23 Python
python 限制函数调用次数的实例讲解
2018/04/21 Python
python后端接收前端回传的文件方法
2019/01/02 Python
基于python实现百度语音识别和图灵对话
2020/11/02 Python
python tkinter的消息框模块(messagebox,simpledialog)
2020/11/07 Python
html5构建触屏网站之网站尺寸探讨
2013/01/07 HTML / CSS
SmartBuyGlasses美国官网:太阳眼镜和眼镜
2017/08/20 全球购物
银行柜员应聘推荐信范文
2013/11/24 职场文书
工厂厂长的职责
2013/12/12 职场文书
2014年党的群众路线整改措施思想汇报
2014/10/12 职场文书
2015年党风廉政建设责任书
2015/01/29 职场文书
全陪导游词
2015/02/04 职场文书
房屋转让协议书(标准范本)
2016/03/21 职场文书
2019事业单位个人工作总结范文
2019/08/26 职场文书
手把手教你使用TensorFlow2实现RNN
2021/07/15 Python
在Spring-Boot中如何使用@Value注解注入集合类
2021/08/02 Java/Android
NASA 机智号火星直升机拍到了毅力号设备碎片
2022/04/29 数码科技