Tensorflow卷积神经网络实例进阶


Posted in Python onMay 24, 2018

在Tensorflow卷积神经网络实例这篇博客中,我们实现了一个简单的卷积神经网络,没有复杂的Trick。接下来,我们将使用CIFAR-10数据集进行训练。

CIFAR-10是一个经典的数据集,包含60000张32*32的彩色图像,其中训练集50000张,测试集10000张。CIFAR-10如同其名字,一共标注为10类,每一类图片6000张。

本文实现了进阶的卷积神经网络来解决CIFAR-10分类问题,我们使用了一些新的技巧:

  1. 对weights进行了L2的正则化
  2. 对图片进行了翻转、随机剪切等数据增强,制造了更多样本
  3. 在每个卷积-最大池化层后面使用了LRN(局部响应归一化层),增强了模型的泛化能力

首先需要下载Tensorflow models Tensorflow models,以便使用其中的CIFAR-10数据的类.进入目录models/tutorials/image/cifar10目录,执行以下代码

import cifar10
import cifar10_input
import tensorflow as tf
import numpy as np
import time

# 定义batch_size, 训练轮数max_steps, 以及下载CIFAR-10数据的默认路径
max_steps = 3000
batch_size = 128
data_dir = 'E:\\tmp\cifar10_data\cifar-10-batches-bin'

# 定义初始化weight的函数,定义的同时,对weight加一个L2 loss,放在集'losses'中
def variable_with_weight_loss(shape, stddev, w1):
  var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
  if w1 is not None:
    weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')
    tf.add_to_collection('losses', weight_loss)
  return var

# 使用cifar10类下载数据集,并解压、展开到其默认位置
#cifar10.maybe_download_and_extract()

# 在使用cifar10_input类中的distorted_inputs函数产生训练需要使用的数据。需要注意的是,返回的是已经封装好的tensor,
# 且对数据进行了Data Augmentation(水平翻转、随机剪切、设置随机亮度和对比度、对数据进行标准化)
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=batch_size)

# 再使用cifar10_input.inputs函数生成测试数据,这里不需要进行太多处理
images_test, labels_test = cifar10_input.inputs(eval_data=True,
                        data_dir=data_dir,
                        batch_size=batch_size)

# 创建数据的placeholder
image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
label_holder = tf.placeholder(tf.int32, [batch_size])

# 创建第一个卷积层
weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2,
                  w1=0.0)
kernel1 = tf.nn.conv2d(image_holder, weight1, strides=[1, 1, 1, 1], padding='SAME')
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))
conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bias1))
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')
# LRN层对ReLU会比较有用,但不适合Sigmoid这种有固定边界并且能抑制过大值的激活函数
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

# 创建第二个卷积层
weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2,
                  w1=0.0)
kernel2 = tf.nn.conv2d(norm1, weight2, strides=[1, 1, 1, 1], padding='SAME')
bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bias2))
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')

# 使用一个全连接层
reshape = tf.reshape(pool2, [batch_size, -1])
dim = reshape.get_shape()[1].value
weight3 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, w1=0.004)
bias3 = tf.Variable(tf.constant(0.1, shape=[384]))
local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bias3)

# 再使用一个全连接层,隐含节点数下降了一半,只有192个,其他的超参数保持不变
weight4 = variable_with_weight_loss(shape=[384, 192], stddev=0.04, w1=0.004)
bias4 = tf.Variable(tf.constant(0.1, shape=[192]))
local4 = tf.nn.relu(tf.matmul(local3, weight4) + bias4)

# 最后一层,将softmax放在了计算loss部分
weight5 = variable_with_weight_loss(shape=[192, 10], stddev=1 / 192.0, w1=0.0)
bias5 = tf.Variable(tf.constant(0.0, shape=[10]))
logits = tf.add(tf.matmul(local4, weight5), bias5)

# 定义loss
def loss(logits, labels):
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels,
                                  name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

# 获取最终的loss
loss = loss(logits, label_holder)

# 优化器
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

# 使用tf.nn.in_top_k函数求输出结果中top k的准确率,默认使用top 1,也就是输出分数最高的那一类的准确率
top_k_op = tf.nn.in_top_k(logits, label_holder, 1)

# 使用tf.InteractiveSession创建默认的session,接着初始化全部模型参数
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# 启动图片数据增强线程
tf.train.start_queue_runners()

# 正式开始训练
for step in range(max_steps):
  start_time = time.time()
  image_batch, label_batch = sess.run([images_train, labels_train])
  _, loss_value = sess.run([train_op, loss], feed_dict={image_holder: image_batch, label_holder: label_batch})
  duration = time.time() - start_time
  if step % 10 == 0:
    example_per_sec = batch_size / duration
    sec_per_batch = float(duration)
    format_str = 'step %d, loss=%.2f ,%.1f examples/sec, %.3f sec/batch'
    print(format_str % (step, loss_value, example_per_sec, sec_per_batch))

num_examples = 10000
import math
num_iter = int(math.ceil(num_examples / batch_size))
true_count = 0
total_sample_count = num_iter * batch_size
step = 0
while step < num_iter:
  image_batch, label_batch = sess.run([images_test, labels_test])
  predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch, label_holder: label_holder})
  true_count += np.sum(predictions)
  step += 1

precision = true_count / total_sample_count
print('precision @ 1 = %.3f'%precision)

运行结果:

Tensorflow卷积神经网络实例进阶

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python线程池的实现实例
Nov 18 Python
Python中使用MELIAE分析程序内存占用实例
Feb 18 Python
Python写的一个简单监控系统
Jun 19 Python
python字典操作实例详解
Nov 16 Python
Flask框架Flask-Login用法分析
Jul 23 Python
python一键去抖音视频水印工具
Sep 14 Python
Python高级特性切片(Slice)操作详解
Sep 27 Python
Python日期时间Time模块实例详解
Apr 15 Python
Python各种扩展名区别点整理
Feb 27 Python
Python IDE环境之 新版Pycharm安装详细教程
Mar 05 Python
pandas 数据类型转换的实现
Dec 29 Python
python实现商品进销存管理系统
May 30 Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
You might like
解析php取整的几种方式
2013/06/25 PHP
PHP函数nl2br()与自定义函数nl2p()换行用法分析
2016/04/02 PHP
php判断str字符串是否是xml格式数据的方法示例
2017/07/26 PHP
laravel框架使用极光推送消息操作示例
2020/02/15 PHP
JQuery Easyui Tree的oncheck事件实现代码
2010/05/28 Javascript
重写javascript中window.confirm的行为
2012/10/21 Javascript
用js来刷新当前页面保留参数的具体实现
2013/12/23 Javascript
js动画效果制件让图片组成动画代码分享
2014/01/14 Javascript
JavaScript中的公有、私有、特权和静态成员用法分析
2014/11/20 Javascript
angularJS中router的使用指南
2015/02/09 Javascript
使用Javascript写的2048小游戏
2015/11/25 Javascript
BootStrap 智能表单实战系列(十)自动完成组件的支持
2016/06/13 Javascript
jQuery弹出层插件popShow用法示例
2017/01/23 Javascript
vue轮播图插件vue-awesome-swiper的使用代码实例
2017/07/10 Javascript
ES6中字符串string常用的新增方法小结
2017/11/07 Javascript
JS中Promise函数then的奥秘探究
2018/07/30 Javascript
详解Angular中通过$location获取地址栏的参数
2018/08/02 Javascript
基于js实现逐步显示文字输出代码实例
2020/04/02 Javascript
进一步探究Python的装饰器的运用
2015/05/05 Python
Python 包含汉字的文件读写之每行末尾加上特定字符
2016/12/12 Python
详解Django中间件的5种自定义方法
2018/07/26 Python
如何利用python制作时间戳转换工具详解
2018/09/12 Python
Python django使用多进程连接mysql错误的解决方法
2018/10/08 Python
Python Tkinter图形工具使用方法及实例解析
2020/06/15 Python
Python爬取股票信息,并可视化数据的示例
2020/09/26 Python
美赞臣营养马来西亚旗舰店:Enfagrow马来西亚
2019/07/26 全球购物
中英双版中文教师求职信
2013/10/27 职场文书
岗位职责的定义
2013/11/10 职场文书
药学专业大专生的自我评价
2013/12/12 职场文书
电子技术专业中专生的自我评价
2013/12/17 职场文书
班长自荐书范文
2014/02/11 职场文书
质量负责人任命书
2014/06/06 职场文书
创先争优宣传标语
2014/10/08 职场文书
群众路线教育实践活动整改落实情况汇报
2014/10/28 职场文书
一篇文章带你搞懂Python类的相关知识
2021/05/20 Python
MySQL中B树索引和B+树索引的区别详解
2022/03/03 MySQL