Tensorflow卷积神经网络实例进阶


Posted in Python onMay 24, 2018

在Tensorflow卷积神经网络实例这篇博客中,我们实现了一个简单的卷积神经网络,没有复杂的Trick。接下来,我们将使用CIFAR-10数据集进行训练。

CIFAR-10是一个经典的数据集,包含60000张32*32的彩色图像,其中训练集50000张,测试集10000张。CIFAR-10如同其名字,一共标注为10类,每一类图片6000张。

本文实现了进阶的卷积神经网络来解决CIFAR-10分类问题,我们使用了一些新的技巧:

  1. 对weights进行了L2的正则化
  2. 对图片进行了翻转、随机剪切等数据增强,制造了更多样本
  3. 在每个卷积-最大池化层后面使用了LRN(局部响应归一化层),增强了模型的泛化能力

首先需要下载Tensorflow models Tensorflow models,以便使用其中的CIFAR-10数据的类.进入目录models/tutorials/image/cifar10目录,执行以下代码

import cifar10
import cifar10_input
import tensorflow as tf
import numpy as np
import time

# 定义batch_size, 训练轮数max_steps, 以及下载CIFAR-10数据的默认路径
max_steps = 3000
batch_size = 128
data_dir = 'E:\\tmp\cifar10_data\cifar-10-batches-bin'

# 定义初始化weight的函数,定义的同时,对weight加一个L2 loss,放在集'losses'中
def variable_with_weight_loss(shape, stddev, w1):
  var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
  if w1 is not None:
    weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')
    tf.add_to_collection('losses', weight_loss)
  return var

# 使用cifar10类下载数据集,并解压、展开到其默认位置
#cifar10.maybe_download_and_extract()

# 在使用cifar10_input类中的distorted_inputs函数产生训练需要使用的数据。需要注意的是,返回的是已经封装好的tensor,
# 且对数据进行了Data Augmentation(水平翻转、随机剪切、设置随机亮度和对比度、对数据进行标准化)
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=batch_size)

# 再使用cifar10_input.inputs函数生成测试数据,这里不需要进行太多处理
images_test, labels_test = cifar10_input.inputs(eval_data=True,
                        data_dir=data_dir,
                        batch_size=batch_size)

# 创建数据的placeholder
image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
label_holder = tf.placeholder(tf.int32, [batch_size])

# 创建第一个卷积层
weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2,
                  w1=0.0)
kernel1 = tf.nn.conv2d(image_holder, weight1, strides=[1, 1, 1, 1], padding='SAME')
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))
conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bias1))
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')
# LRN层对ReLU会比较有用,但不适合Sigmoid这种有固定边界并且能抑制过大值的激活函数
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

# 创建第二个卷积层
weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2,
                  w1=0.0)
kernel2 = tf.nn.conv2d(norm1, weight2, strides=[1, 1, 1, 1], padding='SAME')
bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bias2))
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')

# 使用一个全连接层
reshape = tf.reshape(pool2, [batch_size, -1])
dim = reshape.get_shape()[1].value
weight3 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, w1=0.004)
bias3 = tf.Variable(tf.constant(0.1, shape=[384]))
local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bias3)

# 再使用一个全连接层,隐含节点数下降了一半,只有192个,其他的超参数保持不变
weight4 = variable_with_weight_loss(shape=[384, 192], stddev=0.04, w1=0.004)
bias4 = tf.Variable(tf.constant(0.1, shape=[192]))
local4 = tf.nn.relu(tf.matmul(local3, weight4) + bias4)

# 最后一层,将softmax放在了计算loss部分
weight5 = variable_with_weight_loss(shape=[192, 10], stddev=1 / 192.0, w1=0.0)
bias5 = tf.Variable(tf.constant(0.0, shape=[10]))
logits = tf.add(tf.matmul(local4, weight5), bias5)

# 定义loss
def loss(logits, labels):
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels,
                                  name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

# 获取最终的loss
loss = loss(logits, label_holder)

# 优化器
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

# 使用tf.nn.in_top_k函数求输出结果中top k的准确率,默认使用top 1,也就是输出分数最高的那一类的准确率
top_k_op = tf.nn.in_top_k(logits, label_holder, 1)

# 使用tf.InteractiveSession创建默认的session,接着初始化全部模型参数
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# 启动图片数据增强线程
tf.train.start_queue_runners()

# 正式开始训练
for step in range(max_steps):
  start_time = time.time()
  image_batch, label_batch = sess.run([images_train, labels_train])
  _, loss_value = sess.run([train_op, loss], feed_dict={image_holder: image_batch, label_holder: label_batch})
  duration = time.time() - start_time
  if step % 10 == 0:
    example_per_sec = batch_size / duration
    sec_per_batch = float(duration)
    format_str = 'step %d, loss=%.2f ,%.1f examples/sec, %.3f sec/batch'
    print(format_str % (step, loss_value, example_per_sec, sec_per_batch))

num_examples = 10000
import math
num_iter = int(math.ceil(num_examples / batch_size))
true_count = 0
total_sample_count = num_iter * batch_size
step = 0
while step < num_iter:
  image_batch, label_batch = sess.run([images_test, labels_test])
  predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch, label_holder: label_holder})
  true_count += np.sum(predictions)
  step += 1

precision = true_count / total_sample_count
print('precision @ 1 = %.3f'%precision)

运行结果:

Tensorflow卷积神经网络实例进阶

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python求算数平方根和约数的方法汇总
Mar 09 Python
Python 中的 else详解
Apr 23 Python
浅谈django中的认证与登录
Oct 31 Python
python字典多键值及重复键值的使用方法(详解)
Oct 31 Python
查看django执行的sql语句及消耗时间的两种方法
May 29 Python
python导入模块交叉引用的方法
Jan 19 Python
用Q-learning算法实现自动走迷宫机器人的方法示例
Jun 03 Python
Python 私有化操作实例分析
Nov 21 Python
Python使用Chrome插件实现爬虫过程图解
Jun 09 Python
基于python实现判断字符串是否数字算法
Jul 10 Python
mac安装python3后使用pip和pip3的区别说明
Sep 01 Python
Django实现在线无水印抖音视频下载(附源码及地址)
May 06 Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
You might like
PHP数组对比函数,存在交集则返回真,否则返回假
2011/02/03 PHP
php递归创建目录的方法
2015/02/02 PHP
PHP中基本HTTP认证技巧分析
2015/03/16 PHP
laravel框架使用FormRequest进行表单验证,验证异常返回JSON操作示例
2020/02/18 PHP
JQuery 学习笔记 选择器之四
2009/07/23 Javascript
jQuery入门知识简介
2010/03/04 Javascript
JQuery设置文本框和密码框得到焦点时的样式
2013/08/30 Javascript
精彩的Bootstrap案例分享 重点在注释!(选项卡、栅格布局)
2016/07/01 Javascript
Vue自定义指令拖拽功能示例
2017/02/17 Javascript
layui前段框架日期控件使用方法详解
2017/05/19 Javascript
angular基于ng-alain定义自己的select组件示例
2018/02/23 Javascript
[原创]jquery判断元素内容是否为空的方法
2018/05/04 jQuery
AngularJS 前台分页实现的示例代码
2018/06/07 Javascript
「中高级前端面试」JavaScript手写代码无敌秘籍(推荐)
2019/04/08 Javascript
layer 关闭指定弹出层的例子
2019/09/25 Javascript
Vue的click事件防抖和节流处理详解
2019/11/13 Javascript
JS实现普通轮播图特效
2020/01/01 Javascript
flexible.js实现移动端rem适配方案
2020/04/07 Javascript
Python结巴中文分词工具使用过程中遇到的问题及解决方法
2017/04/15 Python
python实现简单淘宝秒杀功能
2018/05/03 Python
Python3.4学习笔记之类型判断,异常处理,终止程序操作小结
2019/03/01 Python
浅谈pytorch grad_fn以及权重梯度不更新的问题
2019/08/20 Python
python使用配置文件过程详解
2019/12/28 Python
Python面向对象程序设计之继承、多态原理与用法详解
2020/03/23 Python
Python实现读取并写入Excel文件过程解析
2020/05/27 Python
Python编写单元测试代码实例
2020/09/10 Python
Aerosoles爱柔仕官网:美国舒软女鞋品牌
2017/07/17 全球购物
英国拳击装备购物网站:RDX Sports
2018/01/23 全球购物
意大利在线药房:shop-farmacia.it
2019/03/12 全球购物
品恩科技软件测试面试题
2014/10/26 面试题
大学社团计划书
2014/05/01 职场文书
项目负责人任命书
2014/06/04 职场文书
小区门卫的岗位职责
2014/09/26 职场文书
员工辞职信范文
2015/03/02 职场文书
表扬信范文
2015/05/04 职场文书
Python面试不修改数组找出重复的数字
2022/05/20 Python