Tensorflow卷积神经网络实例进阶


Posted in Python onMay 24, 2018

在Tensorflow卷积神经网络实例这篇博客中,我们实现了一个简单的卷积神经网络,没有复杂的Trick。接下来,我们将使用CIFAR-10数据集进行训练。

CIFAR-10是一个经典的数据集,包含60000张32*32的彩色图像,其中训练集50000张,测试集10000张。CIFAR-10如同其名字,一共标注为10类,每一类图片6000张。

本文实现了进阶的卷积神经网络来解决CIFAR-10分类问题,我们使用了一些新的技巧:

  1. 对weights进行了L2的正则化
  2. 对图片进行了翻转、随机剪切等数据增强,制造了更多样本
  3. 在每个卷积-最大池化层后面使用了LRN(局部响应归一化层),增强了模型的泛化能力

首先需要下载Tensorflow models Tensorflow models,以便使用其中的CIFAR-10数据的类.进入目录models/tutorials/image/cifar10目录,执行以下代码

import cifar10
import cifar10_input
import tensorflow as tf
import numpy as np
import time

# 定义batch_size, 训练轮数max_steps, 以及下载CIFAR-10数据的默认路径
max_steps = 3000
batch_size = 128
data_dir = 'E:\\tmp\cifar10_data\cifar-10-batches-bin'

# 定义初始化weight的函数,定义的同时,对weight加一个L2 loss,放在集'losses'中
def variable_with_weight_loss(shape, stddev, w1):
  var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
  if w1 is not None:
    weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')
    tf.add_to_collection('losses', weight_loss)
  return var

# 使用cifar10类下载数据集,并解压、展开到其默认位置
#cifar10.maybe_download_and_extract()

# 在使用cifar10_input类中的distorted_inputs函数产生训练需要使用的数据。需要注意的是,返回的是已经封装好的tensor,
# 且对数据进行了Data Augmentation(水平翻转、随机剪切、设置随机亮度和对比度、对数据进行标准化)
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=batch_size)

# 再使用cifar10_input.inputs函数生成测试数据,这里不需要进行太多处理
images_test, labels_test = cifar10_input.inputs(eval_data=True,
                        data_dir=data_dir,
                        batch_size=batch_size)

# 创建数据的placeholder
image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
label_holder = tf.placeholder(tf.int32, [batch_size])

# 创建第一个卷积层
weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2,
                  w1=0.0)
kernel1 = tf.nn.conv2d(image_holder, weight1, strides=[1, 1, 1, 1], padding='SAME')
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))
conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bias1))
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')
# LRN层对ReLU会比较有用,但不适合Sigmoid这种有固定边界并且能抑制过大值的激活函数
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

# 创建第二个卷积层
weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2,
                  w1=0.0)
kernel2 = tf.nn.conv2d(norm1, weight2, strides=[1, 1, 1, 1], padding='SAME')
bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bias2))
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')

# 使用一个全连接层
reshape = tf.reshape(pool2, [batch_size, -1])
dim = reshape.get_shape()[1].value
weight3 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, w1=0.004)
bias3 = tf.Variable(tf.constant(0.1, shape=[384]))
local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bias3)

# 再使用一个全连接层,隐含节点数下降了一半,只有192个,其他的超参数保持不变
weight4 = variable_with_weight_loss(shape=[384, 192], stddev=0.04, w1=0.004)
bias4 = tf.Variable(tf.constant(0.1, shape=[192]))
local4 = tf.nn.relu(tf.matmul(local3, weight4) + bias4)

# 最后一层,将softmax放在了计算loss部分
weight5 = variable_with_weight_loss(shape=[192, 10], stddev=1 / 192.0, w1=0.0)
bias5 = tf.Variable(tf.constant(0.0, shape=[10]))
logits = tf.add(tf.matmul(local4, weight5), bias5)

# 定义loss
def loss(logits, labels):
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels,
                                  name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

# 获取最终的loss
loss = loss(logits, label_holder)

# 优化器
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

# 使用tf.nn.in_top_k函数求输出结果中top k的准确率,默认使用top 1,也就是输出分数最高的那一类的准确率
top_k_op = tf.nn.in_top_k(logits, label_holder, 1)

# 使用tf.InteractiveSession创建默认的session,接着初始化全部模型参数
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# 启动图片数据增强线程
tf.train.start_queue_runners()

# 正式开始训练
for step in range(max_steps):
  start_time = time.time()
  image_batch, label_batch = sess.run([images_train, labels_train])
  _, loss_value = sess.run([train_op, loss], feed_dict={image_holder: image_batch, label_holder: label_batch})
  duration = time.time() - start_time
  if step % 10 == 0:
    example_per_sec = batch_size / duration
    sec_per_batch = float(duration)
    format_str = 'step %d, loss=%.2f ,%.1f examples/sec, %.3f sec/batch'
    print(format_str % (step, loss_value, example_per_sec, sec_per_batch))

num_examples = 10000
import math
num_iter = int(math.ceil(num_examples / batch_size))
true_count = 0
total_sample_count = num_iter * batch_size
step = 0
while step < num_iter:
  image_batch, label_batch = sess.run([images_test, labels_test])
  predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch, label_holder: label_holder})
  true_count += np.sum(predictions)
  step += 1

precision = true_count / total_sample_count
print('precision @ 1 = %.3f'%precision)

运行结果:

Tensorflow卷积神经网络实例进阶

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中用于返回绝对值的abs()方法
May 14 Python
Python遍历目录中的所有文件的方法
Jul 08 Python
python daemon守护进程实现
Aug 27 Python
python多进程和多线程究竟谁更快(详解)
May 29 Python
Python基础教程之利用期物处理并发
Mar 29 Python
python制作mysql数据迁移脚本
Jan 01 Python
Python Tkinter 简单登录界面的实现
Jun 14 Python
Python实现K折交叉验证法的方法步骤
Jul 11 Python
Python 批量刷博客园访问量脚本过程解析
Aug 30 Python
python函数map()和partial()的知识点总结
May 26 Python
详解python中的异常和文件读写
Jan 03 Python
python树莓派通过队列实现进程交互的程序分析
Jul 04 Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
You might like
php 在线打包_支持子目录
2008/06/28 PHP
php缩小png图片不损失透明色的解决方法
2013/12/25 PHP
php多个文件及图片上传实例详解
2014/11/10 PHP
ThinkPHP实现动态包含文件的方法
2014/11/29 PHP
PHP编程求最大公约数与最小公倍数的方法示例
2017/05/29 PHP
PHP设计模式之工厂模式定义与用法详解
2018/04/03 PHP
PHP实现的权重算法示例【可用于游戏根据权限来随机物品】
2019/02/15 PHP
firefox浏览器下javascript 拖动层效果与原理分析代码
2007/12/04 Javascript
QUnit jQuery的TDD框架
2010/11/04 Javascript
jquery删除ID为sNews的tr元素的内容
2014/04/10 Javascript
Javascript仿PHP $_GET获取URL中的参数
2014/05/12 Javascript
javascript实现英文首字母大写
2015/04/23 Javascript
jquery操作select元素和option的实例代码
2016/02/03 Javascript
Bootstrap popover用法详解
2016/12/22 Javascript
Angular 4.x+Ionic3踩坑之Ionic 3.x界面传值详解
2018/03/13 Javascript
使用 Jest 和 Supertest 进行接口端点测试实例详解
2020/04/25 Javascript
python网络编程示例(客户端与服务端)
2014/04/24 Python
python将文本中的空格替换为换行的方法
2018/03/19 Python
numpy.where() 用法详解
2019/05/27 Python
python爬取王者荣耀全皮肤的简单实现代码
2020/01/31 Python
Python读写操作csv和excle文件代码实例
2020/03/16 Python
python中re模块知识点总结
2021/01/17 Python
美国女性服饰销售网站:Nasty Gal(坏女孩)
2016/07/26 全球购物
美国台面电器和厨具品牌:KitchenAid
2019/04/12 全球购物
vue项目实现分页效果
2021/03/24 Vue.js
自荐信的格式
2014/03/10 职场文书
高中同学会活动方案
2014/08/14 职场文书
义务教育学校标准化建设汇报材料
2014/08/16 职场文书
作风整顿剖析材料
2014/09/30 职场文书
2014年政务公开工作总结
2014/12/09 职场文书
兵马俑导游词
2015/02/02 职场文书
监察建议书
2015/02/04 职场文书
简历自荐信范文
2015/03/09 职场文书
物业工程部主管岗位职责
2015/04/16 职场文书
餐厅开业活动方案
2019/07/08 职场文书
Spring boot admin 服务监控利器详解
2022/08/05 Java/Android