Tensorflow卷积神经网络实例进阶


Posted in Python onMay 24, 2018

在Tensorflow卷积神经网络实例这篇博客中,我们实现了一个简单的卷积神经网络,没有复杂的Trick。接下来,我们将使用CIFAR-10数据集进行训练。

CIFAR-10是一个经典的数据集,包含60000张32*32的彩色图像,其中训练集50000张,测试集10000张。CIFAR-10如同其名字,一共标注为10类,每一类图片6000张。

本文实现了进阶的卷积神经网络来解决CIFAR-10分类问题,我们使用了一些新的技巧:

  1. 对weights进行了L2的正则化
  2. 对图片进行了翻转、随机剪切等数据增强,制造了更多样本
  3. 在每个卷积-最大池化层后面使用了LRN(局部响应归一化层),增强了模型的泛化能力

首先需要下载Tensorflow models Tensorflow models,以便使用其中的CIFAR-10数据的类.进入目录models/tutorials/image/cifar10目录,执行以下代码

import cifar10
import cifar10_input
import tensorflow as tf
import numpy as np
import time

# 定义batch_size, 训练轮数max_steps, 以及下载CIFAR-10数据的默认路径
max_steps = 3000
batch_size = 128
data_dir = 'E:\\tmp\cifar10_data\cifar-10-batches-bin'

# 定义初始化weight的函数,定义的同时,对weight加一个L2 loss,放在集'losses'中
def variable_with_weight_loss(shape, stddev, w1):
  var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
  if w1 is not None:
    weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')
    tf.add_to_collection('losses', weight_loss)
  return var

# 使用cifar10类下载数据集,并解压、展开到其默认位置
#cifar10.maybe_download_and_extract()

# 在使用cifar10_input类中的distorted_inputs函数产生训练需要使用的数据。需要注意的是,返回的是已经封装好的tensor,
# 且对数据进行了Data Augmentation(水平翻转、随机剪切、设置随机亮度和对比度、对数据进行标准化)
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=batch_size)

# 再使用cifar10_input.inputs函数生成测试数据,这里不需要进行太多处理
images_test, labels_test = cifar10_input.inputs(eval_data=True,
                        data_dir=data_dir,
                        batch_size=batch_size)

# 创建数据的placeholder
image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
label_holder = tf.placeholder(tf.int32, [batch_size])

# 创建第一个卷积层
weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2,
                  w1=0.0)
kernel1 = tf.nn.conv2d(image_holder, weight1, strides=[1, 1, 1, 1], padding='SAME')
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))
conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bias1))
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')
# LRN层对ReLU会比较有用,但不适合Sigmoid这种有固定边界并且能抑制过大值的激活函数
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

# 创建第二个卷积层
weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2,
                  w1=0.0)
kernel2 = tf.nn.conv2d(norm1, weight2, strides=[1, 1, 1, 1], padding='SAME')
bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bias2))
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')

# 使用一个全连接层
reshape = tf.reshape(pool2, [batch_size, -1])
dim = reshape.get_shape()[1].value
weight3 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, w1=0.004)
bias3 = tf.Variable(tf.constant(0.1, shape=[384]))
local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bias3)

# 再使用一个全连接层,隐含节点数下降了一半,只有192个,其他的超参数保持不变
weight4 = variable_with_weight_loss(shape=[384, 192], stddev=0.04, w1=0.004)
bias4 = tf.Variable(tf.constant(0.1, shape=[192]))
local4 = tf.nn.relu(tf.matmul(local3, weight4) + bias4)

# 最后一层,将softmax放在了计算loss部分
weight5 = variable_with_weight_loss(shape=[192, 10], stddev=1 / 192.0, w1=0.0)
bias5 = tf.Variable(tf.constant(0.0, shape=[10]))
logits = tf.add(tf.matmul(local4, weight5), bias5)

# 定义loss
def loss(logits, labels):
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels,
                                  name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

# 获取最终的loss
loss = loss(logits, label_holder)

# 优化器
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

# 使用tf.nn.in_top_k函数求输出结果中top k的准确率,默认使用top 1,也就是输出分数最高的那一类的准确率
top_k_op = tf.nn.in_top_k(logits, label_holder, 1)

# 使用tf.InteractiveSession创建默认的session,接着初始化全部模型参数
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# 启动图片数据增强线程
tf.train.start_queue_runners()

# 正式开始训练
for step in range(max_steps):
  start_time = time.time()
  image_batch, label_batch = sess.run([images_train, labels_train])
  _, loss_value = sess.run([train_op, loss], feed_dict={image_holder: image_batch, label_holder: label_batch})
  duration = time.time() - start_time
  if step % 10 == 0:
    example_per_sec = batch_size / duration
    sec_per_batch = float(duration)
    format_str = 'step %d, loss=%.2f ,%.1f examples/sec, %.3f sec/batch'
    print(format_str % (step, loss_value, example_per_sec, sec_per_batch))

num_examples = 10000
import math
num_iter = int(math.ceil(num_examples / batch_size))
true_count = 0
total_sample_count = num_iter * batch_size
step = 0
while step < num_iter:
  image_batch, label_batch = sess.run([images_test, labels_test])
  predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch, label_holder: label_holder})
  true_count += np.sum(predictions)
  step += 1

precision = true_count / total_sample_count
print('precision @ 1 = %.3f'%precision)

运行结果:

Tensorflow卷积神经网络实例进阶

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
进一步了解Python中的XML 工具
Apr 13 Python
解决Python传递中文参数的问题
Aug 04 Python
Python 模块EasyGui详细介绍
Feb 19 Python
Python实现GUI学生信息管理系统
Apr 05 Python
Python爬虫框架scrapy实现的文件下载功能示例
Aug 04 Python
python 同时运行多个程序的实例
Jan 07 Python
解决Pyinstaller 打包exe文件 取消dos窗口(黑框框)的问题
Jun 21 Python
python使用opencv对图像mask处理的方法
Jul 05 Python
pytest中文文档之编写断言
Sep 12 Python
postman传递当前时间戳实例详解
Sep 14 Python
Pycharm中Python环境配置常见问题解析
Jan 16 Python
Python合并pdf文件的工具
Jul 01 Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
You might like
TP5(thinkPHP5)框架使用ajax实现与后台数据交互的方法小结
2020/02/10 PHP
Jquery插件 easyUI属性汇总
2011/01/19 Javascript
Javascript 实现的数独解题算法网页实例
2013/10/15 Javascript
Jquery的each里用return true或false代替break或continue
2014/05/21 Javascript
Javascript动态创建表格及删除行列的方法
2015/05/15 Javascript
全面解析JS字符串和正则表达式中的match、replace、exec等函数
2016/07/01 Javascript
vue开发心得和技巧分享
2016/10/27 Javascript
通过button将form表单的数据提交到action层的实例
2017/09/08 Javascript
详解Vue 全局引入bass.scss 处理方案
2018/03/26 Javascript
实例解析Vue.js下载方式及基本概念
2018/05/11 Javascript
vue使用代理解决请求跨域问题详解
2019/07/24 Javascript
vue el-table实现行内编辑功能
2019/12/11 Javascript
js 实现碰撞检测的示例
2020/10/28 Javascript
详解python发送各类邮件的主要方法
2016/12/22 Python
python安装模块如何通过setup.py安装(超简单)
2018/05/05 Python
无法使用pip命令安装python第三方库的原因及解决方法
2018/06/12 Python
在Django中URL正则表达式匹配的方法
2018/12/20 Python
Anaconda 查看、创建、管理和使用python环境的方法
2019/12/03 Python
python 实现rolling和apply函数的向下取值操作
2020/06/08 Python
Python Matplotlib绘图基础知识代码解析
2020/08/31 Python
pycharm专业版远程登录服务器的详细教程
2020/09/15 Python
css3实现平移效果(transfrom:translate)的示例
2020/11/13 HTML / CSS
联想新加坡官方网站:Lenovo Singapore
2017/10/24 全球购物
Vans澳大利亚官网:购买鞋子、服装及配件
2019/09/05 全球购物
2014年元旦促销活动方案
2014/02/22 职场文书
公司管理建议书范文
2014/03/12 职场文书
蟋蟀的住宅教学反思
2014/04/26 职场文书
乡镇党员群众路线教育实践活动对照检查材料思想汇报
2014/10/05 职场文书
个人工作作风整改措施思想汇报
2014/10/13 职场文书
收费员岗位职责
2015/02/14 职场文书
幼师求职自荐信
2015/03/26 职场文书
不同意离婚代理词
2015/05/23 职场文书
红色电影观后感
2015/06/18 职场文书
使用pandas生成/读取csv文件的方法实例
2021/07/09 Python
JavaScript架构搭建前端监控如何采集异常数据
2022/06/25 Javascript
MySQL中TIMESTAMP类型返回日期时间数据中带有T的解决
2022/12/24 MySQL