Tensorflow实现AlexNet卷积神经网络及运算时间评测


Posted in Python onMay 24, 2018

本文实例为大家分享了Tensorflow实现AlexNet卷积神经网络的具体实现代码,供大家参考,具体内容如下

之前已经介绍过了AlexNet的网络构建了,这次主要不是为了训练数据,而是为了对每个batch的前馈(Forward)和反馈(backward)的平均耗时进行计算。在设计网络的过程中,分类的结果很重要,但是运算速率也相当重要。尤其是在跟踪(Tracking)的任务中,如果使用的网络太深,那么也会导致实时性不好。

from datetime import datetime
import math
import time
import tensorflow as tf

batch_size = 32
num_batches = 100

def print_activations(t):
 print(t.op.name, '', t.get_shape().as_list())

def inference(images):
 parameters = []

 with tf.name_scope('conv1') as scope:
  kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [64], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv1 = tf.nn.relu(bias, name = scope)
  print_activations(conv1)
  parameters += [kernel, biases]

  lrn1 = tf.nn.lrn(conv1, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn1')
  pool1 = tf.nn.max_pool(lrn1, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool1')
  print_activations(pool1)

 with tf.name_scope('conv2') as scope:
  kernel = tf.Variable(tf.truncated_normal([5, 5, 64, 192], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [192], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv2 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv2)

  lrn2 = tf.nn.lrn(conv2, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn2')
  pool2 = tf.nn.max_pool(lrn2, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool2')
  print_activations(pool2)

 with tf.name_scope('conv3') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 192, 384], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [384], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv3 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv3)

 with tf.name_scope('conv4') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 384, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv4 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv4)

 with tf.name_scope('conv5') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv5 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv5)

  pool5 = tf.nn.max_pool(conv5, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool5')
  print_activations(pool5)

  return pool5, parameters

def time_tensorflow_run(session, target, info_string):
 num_steps_burn_in = 10
 total_duration = 0.0
 total_duration_squared = 0.0

 for i in range(num_batches + num_steps_burn_in):
  start_time = time.time()
  _ = session.run(target)
  duration = time.time() - start_time
  if i >= num_steps_burn_in:
   if not i % 10:
    print('%s: step %d, duration = %.3f' %(datetime.now(), i - num_steps_burn_in, duration))
   total_duration += duration
   total_duration_squared += duration * duration

 mn = total_duration / num_batches
 vr = total_duration_squared / num_batches - mn * mn
 sd = math.sqrt(vr)
 print('%s: %s across %d steps, %.3f +/- %.3f sec / batch' %(datetime.now(), info_string, num_batches, mn, sd))

def run_benchmark():
 with tf.Graph().as_default():
  image_size = 224
  images = tf.Variable(tf.random_normal([batch_size, image_size, image_size, 3], dtype = tf.float32, stddev = 1e-1))
  pool5, parameters = inference(images)

  init = tf.global_variables_initializer()
  sess = tf.Session()
  sess.run(init)

  time_tensorflow_run(sess, pool5, "Forward")

  objective = tf.nn.l2_loss(pool5)
  grad = tf.gradients(objective, parameters)
  time_tensorflow_run(sess, grad, "Forward-backward")


run_benchmark()

这里的代码都是之前讲过的,只是加了一个计算时间和现实网络的卷积核的函数,应该很容易就看懂了,就不多赘述了。我在GTX TITAN X上前馈大概需要0.024s, 反馈大概需要0.079s。哈哈,自己动手试一试哦。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python支持断点续传的多线程下载示例
Jan 16 Python
python实现simhash算法实例
Apr 25 Python
Python实现的简单万年历例子分享
Apr 25 Python
Python+django实现文件下载
Jan 17 Python
Python的标准模块包json详解
Mar 13 Python
对python周期性定时器的示例详解
Feb 19 Python
通过cmd进入python的实例操作
Jun 26 Python
Python中print函数简单使用总结
Aug 05 Python
python matplotlib拟合直线的实现
Nov 19 Python
对Matlab中共轭、转置和共轭装置的区别说明
May 11 Python
Python实现发票自动校核微信机器人的方法
May 22 Python
Python基于template实现字符串替换
Nov 27 Python
Tensorflow卷积神经网络实例进阶
May 24 #Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
You might like
PHP实现设计模式中的抽象工厂模式详解
2014/10/11 PHP
php输出金字塔的2种实现方法
2014/12/16 PHP
PHP中使用Imagick操作PSD文件实例
2015/01/26 PHP
PHP+Ajax 检测网络是否正常实例详解
2016/12/16 PHP
浅谈ThinkPHP5.0版本和ThinkPHP3.2版本的区别
2017/06/17 PHP
PHP实现mysqli批量执行多条语句的方法示例
2017/07/22 PHP
javascript delete 使用示例代码
2010/03/29 Javascript
js模拟类继承小例子
2010/07/17 Javascript
javascrpt绑定事件之匿名函数无法解除绑定问题
2012/12/06 Javascript
jqeury-easyui-layout问题解决方法
2014/03/24 Javascript
js+html5实现半透明遮罩层弹框效果
2020/08/24 Javascript
vue+vue-validator 表单验证功能的实现代码
2017/11/13 Javascript
vue2手机APP项目添加开屏广告或者闪屏广告
2017/11/28 Javascript
layer.confirm取消按钮绑定事件的方法
2018/08/17 Javascript
利用chrome浏览器进行js调试并找出元素绑定的点击事件详解
2021/01/30 Javascript
微信小程序实现页面跳转传递参数(实体,对象)
2019/08/12 Javascript
解决layui中onchange失效以及form动态渲染失效的问题
2019/09/27 Javascript
浅析vue cli3 封装Svgicon组件正确姿势(推荐)
2020/04/27 Javascript
ssm+vue前后端分离框架整合实现(附源码)
2020/07/08 Javascript
解决echarts 一条柱状图显示两个值,类似进度条的问题
2020/07/20 Javascript
[02:43]中国五虎出征TI3视频
2013/08/02 DOTA
python处理PHP数组文本文件实例
2014/09/18 Python
Python上传package到Pypi(代码简单)
2016/02/06 Python
Python面向对象编程基础解析(一)
2017/10/26 Python
Pycharm设置界面全黑的方法
2018/05/23 Python
使用Python实现从各个子文件夹中复制指定文件的方法
2018/10/25 Python
基于python实现自动化办公学习笔记(CSV、word、Excel、PPT)
2019/08/06 Python
pytorch实现用Resnet提取特征并保存为txt文件的方法
2019/08/20 Python
美国在线旅行社:Crystal Travel
2018/09/11 全球购物
JSP&Servlet技术面试题
2015/05/21 面试题
求职推荐信
2013/10/28 职场文书
乔迁宴答谢词
2014/01/21 职场文书
办公室主任岗位承诺书
2014/05/29 职场文书
2015年个人审计工作总结
2015/04/07 职场文书
小学学习委员竞选稿
2015/11/20 职场文书
班干部竞选演讲稿(精选5篇)
2019/09/24 职场文书