Tensorflow卷积神经网络实例进阶


Posted in Python onMay 24, 2018

在Tensorflow卷积神经网络实例这篇博客中,我们实现了一个简单的卷积神经网络,没有复杂的Trick。接下来,我们将使用CIFAR-10数据集进行训练。

CIFAR-10是一个经典的数据集,包含60000张32*32的彩色图像,其中训练集50000张,测试集10000张。CIFAR-10如同其名字,一共标注为10类,每一类图片6000张。

本文实现了进阶的卷积神经网络来解决CIFAR-10分类问题,我们使用了一些新的技巧:

  1. 对weights进行了L2的正则化
  2. 对图片进行了翻转、随机剪切等数据增强,制造了更多样本
  3. 在每个卷积-最大池化层后面使用了LRN(局部响应归一化层),增强了模型的泛化能力

首先需要下载Tensorflow models Tensorflow models,以便使用其中的CIFAR-10数据的类.进入目录models/tutorials/image/cifar10目录,执行以下代码

import cifar10
import cifar10_input
import tensorflow as tf
import numpy as np
import time

# 定义batch_size, 训练轮数max_steps, 以及下载CIFAR-10数据的默认路径
max_steps = 3000
batch_size = 128
data_dir = 'E:\\tmp\cifar10_data\cifar-10-batches-bin'

# 定义初始化weight的函数,定义的同时,对weight加一个L2 loss,放在集'losses'中
def variable_with_weight_loss(shape, stddev, w1):
  var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
  if w1 is not None:
    weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')
    tf.add_to_collection('losses', weight_loss)
  return var

# 使用cifar10类下载数据集,并解压、展开到其默认位置
#cifar10.maybe_download_and_extract()

# 在使用cifar10_input类中的distorted_inputs函数产生训练需要使用的数据。需要注意的是,返回的是已经封装好的tensor,
# 且对数据进行了Data Augmentation(水平翻转、随机剪切、设置随机亮度和对比度、对数据进行标准化)
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=batch_size)

# 再使用cifar10_input.inputs函数生成测试数据,这里不需要进行太多处理
images_test, labels_test = cifar10_input.inputs(eval_data=True,
                        data_dir=data_dir,
                        batch_size=batch_size)

# 创建数据的placeholder
image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
label_holder = tf.placeholder(tf.int32, [batch_size])

# 创建第一个卷积层
weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2,
                  w1=0.0)
kernel1 = tf.nn.conv2d(image_holder, weight1, strides=[1, 1, 1, 1], padding='SAME')
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))
conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bias1))
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')
# LRN层对ReLU会比较有用,但不适合Sigmoid这种有固定边界并且能抑制过大值的激活函数
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

# 创建第二个卷积层
weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2,
                  w1=0.0)
kernel2 = tf.nn.conv2d(norm1, weight2, strides=[1, 1, 1, 1], padding='SAME')
bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bias2))
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')

# 使用一个全连接层
reshape = tf.reshape(pool2, [batch_size, -1])
dim = reshape.get_shape()[1].value
weight3 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, w1=0.004)
bias3 = tf.Variable(tf.constant(0.1, shape=[384]))
local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bias3)

# 再使用一个全连接层,隐含节点数下降了一半,只有192个,其他的超参数保持不变
weight4 = variable_with_weight_loss(shape=[384, 192], stddev=0.04, w1=0.004)
bias4 = tf.Variable(tf.constant(0.1, shape=[192]))
local4 = tf.nn.relu(tf.matmul(local3, weight4) + bias4)

# 最后一层,将softmax放在了计算loss部分
weight5 = variable_with_weight_loss(shape=[192, 10], stddev=1 / 192.0, w1=0.0)
bias5 = tf.Variable(tf.constant(0.0, shape=[10]))
logits = tf.add(tf.matmul(local4, weight5), bias5)

# 定义loss
def loss(logits, labels):
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels,
                                  name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

# 获取最终的loss
loss = loss(logits, label_holder)

# 优化器
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

# 使用tf.nn.in_top_k函数求输出结果中top k的准确率,默认使用top 1,也就是输出分数最高的那一类的准确率
top_k_op = tf.nn.in_top_k(logits, label_holder, 1)

# 使用tf.InteractiveSession创建默认的session,接着初始化全部模型参数
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# 启动图片数据增强线程
tf.train.start_queue_runners()

# 正式开始训练
for step in range(max_steps):
  start_time = time.time()
  image_batch, label_batch = sess.run([images_train, labels_train])
  _, loss_value = sess.run([train_op, loss], feed_dict={image_holder: image_batch, label_holder: label_batch})
  duration = time.time() - start_time
  if step % 10 == 0:
    example_per_sec = batch_size / duration
    sec_per_batch = float(duration)
    format_str = 'step %d, loss=%.2f ,%.1f examples/sec, %.3f sec/batch'
    print(format_str % (step, loss_value, example_per_sec, sec_per_batch))

num_examples = 10000
import math
num_iter = int(math.ceil(num_examples / batch_size))
true_count = 0
total_sample_count = num_iter * batch_size
step = 0
while step < num_iter:
  image_batch, label_batch = sess.run([images_test, labels_test])
  predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch, label_holder: label_holder})
  true_count += np.sum(predictions)
  step += 1

precision = true_count / total_sample_count
print('precision @ 1 = %.3f'%precision)

运行结果:

Tensorflow卷积神经网络实例进阶

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的Django中将文件上传至七牛云存储的代码分享
Jun 03 Python
Python 25行代码实现的RSA算法详解
Apr 10 Python
一个简单的python爬虫程序 爬取豆瓣热度Top100以内的电影信息
Apr 17 Python
python3.6.3安装图文教程 TensorFlow安装配置方法
Jun 24 Python
pycharm使用matplotlib.pyplot不显示图形的解决方法
Oct 28 Python
Python实现简单查找最长子串功能示例
Feb 26 Python
Django项目使用ckeditor详解(不使用admin)
Dec 17 Python
双向RNN:bidirectional_dynamic_rnn()函数的使用详解
Jan 20 Python
Python字符串格式化常用手段及注意事项
Jun 17 Python
Python xlrd/xlwt 创建excel文件及常用操作
Sep 24 Python
利用python调用摄像头的实例分析
Jun 07 Python
利用python做数据拟合详情
Nov 17 Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
You might like
PHP删除数组中的特定元素的代码
2012/06/28 PHP
配置eAccelerator和XCache扩展来加速PHP程序的执行
2015/12/22 PHP
PHP设计模式之简单投诉页面实例
2016/02/24 PHP
PHP实现根据密码长度显示安全条
2017/07/04 PHP
PHP面向对象五大原则之开放-封闭原则(OCP)详解
2018/04/04 PHP
关于div自适应高度/左右高度自适应一致的js代码
2013/03/22 Javascript
javascript模拟地球旋转效果代码实例
2013/12/02 Javascript
javascript创建createXmlHttpRequest对象示例代码
2014/02/10 Javascript
微信小程序 两种为对象属性赋值的方式详解
2017/02/23 Javascript
jQuery实现table中两列CheckBox只能选中一个的示例
2017/09/22 jQuery
vue2.0设置proxyTable使用axios进行跨域请求的方法
2017/10/19 Javascript
iconfont的三种使用方式详解
2018/08/05 Javascript
详解如何模拟实现node中的Events模块(通俗易懂版)
2019/04/15 Javascript
Vue中的组件及路由使用实例代码详解
2019/05/22 Javascript
JS document对象简单用法完整示例
2020/01/14 Javascript
JavaScript代理模式原理与用法实例详解
2020/03/10 Javascript
详细分析vue响应式原理
2020/06/22 Javascript
[01:04:01]2014 DOTA2国际邀请赛中国区预选赛 5 23 CIS VS DT第一场
2014/05/24 DOTA
举例讲解如何在Python编程中进行迭代和遍历
2016/01/19 Python
Python 实现数据库更新脚本的生成方法
2017/07/09 Python
scrapy爬虫完整实例
2018/01/25 Python
Python中单例模式总结
2018/02/20 Python
HTML5响应式(自适应)网页设计的实现
2017/11/17 HTML / CSS
马来西亚网上购物平台:ezbuy
2018/02/13 全球购物
俄罗斯三星品牌商店:Samsungstore
2020/04/05 全球购物
银行门卫岗位职责
2013/12/29 职场文书
物流专业大学生职业生涯规划书范文
2014/01/15 职场文书
《菜园里》教学反思
2014/04/17 职场文书
英语演讲稿3分钟
2014/04/29 职场文书
好书伴我成长演讲稿
2014/05/14 职场文书
奥巴马英文演讲稿
2014/05/15 职场文书
运动会开幕式新闻稿
2015/07/17 职场文书
Go语言使用select{}阻塞main函数介绍
2021/04/25 Golang
Python函数中的不定长参数相关知识总结
2021/06/24 Python
Java多条件判断场景中规则执行器的设计
2021/06/26 Java/Android
javascript遍历对象的五种方式实例代码
2021/10/24 Javascript