Tensorflow卷积神经网络实例进阶


Posted in Python onMay 24, 2018

在Tensorflow卷积神经网络实例这篇博客中,我们实现了一个简单的卷积神经网络,没有复杂的Trick。接下来,我们将使用CIFAR-10数据集进行训练。

CIFAR-10是一个经典的数据集,包含60000张32*32的彩色图像,其中训练集50000张,测试集10000张。CIFAR-10如同其名字,一共标注为10类,每一类图片6000张。

本文实现了进阶的卷积神经网络来解决CIFAR-10分类问题,我们使用了一些新的技巧:

  1. 对weights进行了L2的正则化
  2. 对图片进行了翻转、随机剪切等数据增强,制造了更多样本
  3. 在每个卷积-最大池化层后面使用了LRN(局部响应归一化层),增强了模型的泛化能力

首先需要下载Tensorflow models Tensorflow models,以便使用其中的CIFAR-10数据的类.进入目录models/tutorials/image/cifar10目录,执行以下代码

import cifar10
import cifar10_input
import tensorflow as tf
import numpy as np
import time

# 定义batch_size, 训练轮数max_steps, 以及下载CIFAR-10数据的默认路径
max_steps = 3000
batch_size = 128
data_dir = 'E:\\tmp\cifar10_data\cifar-10-batches-bin'

# 定义初始化weight的函数,定义的同时,对weight加一个L2 loss,放在集'losses'中
def variable_with_weight_loss(shape, stddev, w1):
  var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
  if w1 is not None:
    weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')
    tf.add_to_collection('losses', weight_loss)
  return var

# 使用cifar10类下载数据集,并解压、展开到其默认位置
#cifar10.maybe_download_and_extract()

# 在使用cifar10_input类中的distorted_inputs函数产生训练需要使用的数据。需要注意的是,返回的是已经封装好的tensor,
# 且对数据进行了Data Augmentation(水平翻转、随机剪切、设置随机亮度和对比度、对数据进行标准化)
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=batch_size)

# 再使用cifar10_input.inputs函数生成测试数据,这里不需要进行太多处理
images_test, labels_test = cifar10_input.inputs(eval_data=True,
                        data_dir=data_dir,
                        batch_size=batch_size)

# 创建数据的placeholder
image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
label_holder = tf.placeholder(tf.int32, [batch_size])

# 创建第一个卷积层
weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2,
                  w1=0.0)
kernel1 = tf.nn.conv2d(image_holder, weight1, strides=[1, 1, 1, 1], padding='SAME')
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))
conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bias1))
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')
# LRN层对ReLU会比较有用,但不适合Sigmoid这种有固定边界并且能抑制过大值的激活函数
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

# 创建第二个卷积层
weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2,
                  w1=0.0)
kernel2 = tf.nn.conv2d(norm1, weight2, strides=[1, 1, 1, 1], padding='SAME')
bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bias2))
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')

# 使用一个全连接层
reshape = tf.reshape(pool2, [batch_size, -1])
dim = reshape.get_shape()[1].value
weight3 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, w1=0.004)
bias3 = tf.Variable(tf.constant(0.1, shape=[384]))
local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bias3)

# 再使用一个全连接层,隐含节点数下降了一半,只有192个,其他的超参数保持不变
weight4 = variable_with_weight_loss(shape=[384, 192], stddev=0.04, w1=0.004)
bias4 = tf.Variable(tf.constant(0.1, shape=[192]))
local4 = tf.nn.relu(tf.matmul(local3, weight4) + bias4)

# 最后一层,将softmax放在了计算loss部分
weight5 = variable_with_weight_loss(shape=[192, 10], stddev=1 / 192.0, w1=0.0)
bias5 = tf.Variable(tf.constant(0.0, shape=[10]))
logits = tf.add(tf.matmul(local4, weight5), bias5)

# 定义loss
def loss(logits, labels):
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels,
                                  name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

# 获取最终的loss
loss = loss(logits, label_holder)

# 优化器
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

# 使用tf.nn.in_top_k函数求输出结果中top k的准确率,默认使用top 1,也就是输出分数最高的那一类的准确率
top_k_op = tf.nn.in_top_k(logits, label_holder, 1)

# 使用tf.InteractiveSession创建默认的session,接着初始化全部模型参数
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# 启动图片数据增强线程
tf.train.start_queue_runners()

# 正式开始训练
for step in range(max_steps):
  start_time = time.time()
  image_batch, label_batch = sess.run([images_train, labels_train])
  _, loss_value = sess.run([train_op, loss], feed_dict={image_holder: image_batch, label_holder: label_batch})
  duration = time.time() - start_time
  if step % 10 == 0:
    example_per_sec = batch_size / duration
    sec_per_batch = float(duration)
    format_str = 'step %d, loss=%.2f ,%.1f examples/sec, %.3f sec/batch'
    print(format_str % (step, loss_value, example_per_sec, sec_per_batch))

num_examples = 10000
import math
num_iter = int(math.ceil(num_examples / batch_size))
true_count = 0
total_sample_count = num_iter * batch_size
step = 0
while step < num_iter:
  image_batch, label_batch = sess.run([images_test, labels_test])
  predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch, label_holder: label_holder})
  true_count += np.sum(predictions)
  step += 1

precision = true_count / total_sample_count
print('precision @ 1 = %.3f'%precision)

运行结果:

Tensorflow卷积神经网络实例进阶

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Trie树实现字典排序
Mar 28 Python
整理Python最基本的操作字典的方法
Apr 24 Python
python简单文本处理的方法
Jul 10 Python
Android应用开发中Action bar编写的入门教程
Feb 26 Python
Python构建网页爬虫原理分析
Dec 19 Python
解决python读取几千万行的大表内存问题
Jun 26 Python
python使用suds调用webservice接口的方法
Jan 03 Python
Python中的类与类型示例详解
Jul 10 Python
Python创建一个元素都为0的列表实例
Nov 28 Python
PyCharm+Pipenv虚拟环境开发和依赖管理的教程详解
Apr 16 Python
Python实现加密接口测试方法步骤详解
Jun 05 Python
Python3爬虫mitmproxy的安装步骤
Jul 29 Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
You might like
PHPMyAdmin 快速配置方法
2009/05/11 PHP
php和javascript之间变量的传递实现代码
2012/12/19 PHP
php 注册时输入信息验证器的实现详解
2013/07/05 PHP
那些年我们错过的魔术方法(Magic Methods)
2014/01/14 PHP
从零开始学YII2框架(二)通过 Composer 安装扩展插件
2014/08/20 PHP
php新浪微博登录接口用法实例
2014/12/23 PHP
WordPress中获取页面链接和标题的相关PHP函数用法解析
2015/12/17 PHP
Yii2中datetime类的使用
2016/12/17 PHP
PHP中危险的file_put_contents函数详解
2017/11/04 PHP
laravel 解决paginate查询多个字段报错的问题
2019/10/22 PHP
一个背景云变换js特效 鼠标移动背景云变化
2012/12/28 Javascript
获取客户端网卡MAC地址和IP地址实现JS代码
2013/03/17 Javascript
JavaScript将数据转换成整数的方法
2014/01/04 Javascript
jQuery的load()方法及其回调函数用法实例
2015/03/25 Javascript
jQuery控制Div拖拽效果完整实例分析
2015/04/15 Javascript
jQuery使用animate实现ul列表项相互飘动效果示例
2016/09/16 Javascript
jQuery中页面返回顶部的方法总结
2016/12/30 Javascript
jQuery pjax 应用简单示例
2018/09/20 jQuery
详解Vue改变数组中对象的属性不重新渲染View的解决方案
2018/09/21 Javascript
使用Javascript简单计算器
2018/11/17 Javascript
vue实现移动端悬浮窗效果
2018/12/01 Javascript
微信小程序自定义导航教程(兼容各种手机)
2018/12/12 Javascript
Vue起步(无cli)的啊教程详解
2019/04/11 Javascript
vue实现手机号码的校验实例代码(防抖函数的应用场景)
2019/09/05 Javascript
vue 组件内获取actions的response方式
2019/11/08 Javascript
JS实现简单日历特效
2020/01/03 Javascript
nuxt 服务器渲染动态设置 title和seo关键字的操作
2020/11/05 Javascript
C#笔试题集合
2013/06/21 面试题
班主任经验交流会主持词
2014/04/01 职场文书
超越自我演讲稿
2014/05/21 职场文书
2014统计局民主生活会对照检查材料思想汇报
2014/10/02 职场文书
学生违反校规检讨书
2014/10/28 职场文书
2015年监理个人工作总结
2015/05/23 职场文书
安全教育观后感
2015/06/17 职场文书
python绘制简单直方图(质量分布图)的方法
2022/04/21 Python
oracle delete误删除表数据后如何恢复
2022/06/28 Oracle