Tensorflow卷积神经网络实例进阶


Posted in Python onMay 24, 2018

在Tensorflow卷积神经网络实例这篇博客中,我们实现了一个简单的卷积神经网络,没有复杂的Trick。接下来,我们将使用CIFAR-10数据集进行训练。

CIFAR-10是一个经典的数据集,包含60000张32*32的彩色图像,其中训练集50000张,测试集10000张。CIFAR-10如同其名字,一共标注为10类,每一类图片6000张。

本文实现了进阶的卷积神经网络来解决CIFAR-10分类问题,我们使用了一些新的技巧:

  1. 对weights进行了L2的正则化
  2. 对图片进行了翻转、随机剪切等数据增强,制造了更多样本
  3. 在每个卷积-最大池化层后面使用了LRN(局部响应归一化层),增强了模型的泛化能力

首先需要下载Tensorflow models Tensorflow models,以便使用其中的CIFAR-10数据的类.进入目录models/tutorials/image/cifar10目录,执行以下代码

import cifar10
import cifar10_input
import tensorflow as tf
import numpy as np
import time

# 定义batch_size, 训练轮数max_steps, 以及下载CIFAR-10数据的默认路径
max_steps = 3000
batch_size = 128
data_dir = 'E:\\tmp\cifar10_data\cifar-10-batches-bin'

# 定义初始化weight的函数,定义的同时,对weight加一个L2 loss,放在集'losses'中
def variable_with_weight_loss(shape, stddev, w1):
  var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
  if w1 is not None:
    weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')
    tf.add_to_collection('losses', weight_loss)
  return var

# 使用cifar10类下载数据集,并解压、展开到其默认位置
#cifar10.maybe_download_and_extract()

# 在使用cifar10_input类中的distorted_inputs函数产生训练需要使用的数据。需要注意的是,返回的是已经封装好的tensor,
# 且对数据进行了Data Augmentation(水平翻转、随机剪切、设置随机亮度和对比度、对数据进行标准化)
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=batch_size)

# 再使用cifar10_input.inputs函数生成测试数据,这里不需要进行太多处理
images_test, labels_test = cifar10_input.inputs(eval_data=True,
                        data_dir=data_dir,
                        batch_size=batch_size)

# 创建数据的placeholder
image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
label_holder = tf.placeholder(tf.int32, [batch_size])

# 创建第一个卷积层
weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2,
                  w1=0.0)
kernel1 = tf.nn.conv2d(image_holder, weight1, strides=[1, 1, 1, 1], padding='SAME')
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))
conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bias1))
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')
# LRN层对ReLU会比较有用,但不适合Sigmoid这种有固定边界并且能抑制过大值的激活函数
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

# 创建第二个卷积层
weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2,
                  w1=0.0)
kernel2 = tf.nn.conv2d(norm1, weight2, strides=[1, 1, 1, 1], padding='SAME')
bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bias2))
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')

# 使用一个全连接层
reshape = tf.reshape(pool2, [batch_size, -1])
dim = reshape.get_shape()[1].value
weight3 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, w1=0.004)
bias3 = tf.Variable(tf.constant(0.1, shape=[384]))
local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bias3)

# 再使用一个全连接层,隐含节点数下降了一半,只有192个,其他的超参数保持不变
weight4 = variable_with_weight_loss(shape=[384, 192], stddev=0.04, w1=0.004)
bias4 = tf.Variable(tf.constant(0.1, shape=[192]))
local4 = tf.nn.relu(tf.matmul(local3, weight4) + bias4)

# 最后一层,将softmax放在了计算loss部分
weight5 = variable_with_weight_loss(shape=[192, 10], stddev=1 / 192.0, w1=0.0)
bias5 = tf.Variable(tf.constant(0.0, shape=[10]))
logits = tf.add(tf.matmul(local4, weight5), bias5)

# 定义loss
def loss(logits, labels):
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels,
                                  name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

# 获取最终的loss
loss = loss(logits, label_holder)

# 优化器
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

# 使用tf.nn.in_top_k函数求输出结果中top k的准确率,默认使用top 1,也就是输出分数最高的那一类的准确率
top_k_op = tf.nn.in_top_k(logits, label_holder, 1)

# 使用tf.InteractiveSession创建默认的session,接着初始化全部模型参数
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# 启动图片数据增强线程
tf.train.start_queue_runners()

# 正式开始训练
for step in range(max_steps):
  start_time = time.time()
  image_batch, label_batch = sess.run([images_train, labels_train])
  _, loss_value = sess.run([train_op, loss], feed_dict={image_holder: image_batch, label_holder: label_batch})
  duration = time.time() - start_time
  if step % 10 == 0:
    example_per_sec = batch_size / duration
    sec_per_batch = float(duration)
    format_str = 'step %d, loss=%.2f ,%.1f examples/sec, %.3f sec/batch'
    print(format_str % (step, loss_value, example_per_sec, sec_per_batch))

num_examples = 10000
import math
num_iter = int(math.ceil(num_examples / batch_size))
true_count = 0
total_sample_count = num_iter * batch_size
step = 0
while step < num_iter:
  image_batch, label_batch = sess.run([images_test, labels_test])
  predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch, label_holder: label_holder})
  true_count += np.sum(predictions)
  step += 1

precision = true_count / total_sample_count
print('precision @ 1 = %.3f'%precision)

运行结果:

Tensorflow卷积神经网络实例进阶

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python多线程结合队列下载百度音乐的方法
Jul 27 Python
使用Python的Bottle框架写一个简单的服务接口的示例
Aug 25 Python
python中pandas.DataFrame对行与列求和及添加新行与列示例
Mar 12 Python
Python查询IP地址归属完整代码
Jun 21 Python
Python正则表达式非贪婪、多行匹配功能示例
Aug 08 Python
python如何创建TCP服务端和客户端
Aug 26 Python
python itchat给指定联系人发消息的方法
Jun 11 Python
基于Python实现人脸自动戴口罩系统
Feb 06 Python
Django模板标签{% for %}循环,获取制定条数据实例
May 14 Python
Python flask框架端口失效解决方案
Jun 04 Python
django中ImageField的使用详解
Dec 21 Python
Python利用socket模块开发简单的端口扫描工具的实现
Jan 27 Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
You might like
PHP字符串逆序排列实现方法小结【strrev函数,二分法,循环法,递归法】
2017/01/13 PHP
PHP Post获取不到非表单数据的问题解决办法
2018/02/27 PHP
thinkphp5框架实现的自定义扩展类操作示例
2019/05/16 PHP
php layui实现前端多图上传实例
2019/07/30 PHP
php扩展开发入门demo示例
2019/09/23 PHP
PHP封装请求类实例分析【基于Yii框架】
2019/10/17 PHP
基于MooTools的很有创意的滚动条时钟动画
2010/11/14 Javascript
关于jQuery参考实例 1.0 jQuery的哲学
2013/04/07 Javascript
用jQuery模拟select下拉框的简单示例代码
2014/01/26 Javascript
window.location 对象所包含的属性
2014/10/10 Javascript
jQuery简单实现日历的方法
2015/05/04 Javascript
JavaScript中调用函数的4种方式代码实例
2015/07/08 Javascript
jQuery文本框得到与失去焦点动态改变样式效果
2016/09/08 Javascript
angularjs中使用ng-bind-html和ng-include的实例
2017/04/28 Javascript
JavaScript实现时间表动态效果
2017/07/15 Javascript
JS与HTML结合实现流程进度展示条思路详解
2017/09/03 Javascript
node基于puppeteer模拟登录抓取页面的实现
2018/05/09 Javascript
Vue CLI3.0中使用jQuery和Bootstrap的方法
2019/02/28 jQuery
js作用域和作用域链及预解析
2019/04/11 Javascript
JS定义函数的几种常用方法小结
2019/05/23 Javascript
原生JavaScript实现日历功能代码实例(无引用Jq)
2019/09/23 Javascript
python实现将html表格转换成CSV文件的方法
2015/06/28 Python
python 把数据 json格式输出的实例代码
2016/10/31 Python
一个基于flask的web应用诞生 bootstrap框架美化(3)
2017/04/11 Python
Python 爬虫之超链接 url中含有中文出错及解决办法
2017/08/03 Python
在Python中使用Neo4j的方法
2019/03/14 Python
Python 70行代码实现简单算式计算器解析
2019/08/30 Python
关于python 跨域处理方式详解
2020/03/28 Python
django使用channels实现通信的示例
2020/10/19 Python
python程序实现BTC(比特币)挖矿的完整代码
2021/01/20 Python
材料化学应届生求职信
2013/10/09 职场文书
诚实守信道德模范事迹材料
2014/08/15 职场文书
消防隐患整改通知书
2015/04/22 职场文书
学校通报表扬范文
2015/05/04 职场文书
党校团干班培训心得体会
2016/01/06 职场文书
关于Redis的主从复制及哨兵问题
2022/06/16 Redis