Tensorflow实现AlexNet卷积神经网络及运算时间评测


Posted in Python onMay 24, 2018

本文实例为大家分享了Tensorflow实现AlexNet卷积神经网络的具体实现代码,供大家参考,具体内容如下

之前已经介绍过了AlexNet的网络构建了,这次主要不是为了训练数据,而是为了对每个batch的前馈(Forward)和反馈(backward)的平均耗时进行计算。在设计网络的过程中,分类的结果很重要,但是运算速率也相当重要。尤其是在跟踪(Tracking)的任务中,如果使用的网络太深,那么也会导致实时性不好。

from datetime import datetime
import math
import time
import tensorflow as tf

batch_size = 32
num_batches = 100

def print_activations(t):
 print(t.op.name, '', t.get_shape().as_list())

def inference(images):
 parameters = []

 with tf.name_scope('conv1') as scope:
  kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [64], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv1 = tf.nn.relu(bias, name = scope)
  print_activations(conv1)
  parameters += [kernel, biases]

  lrn1 = tf.nn.lrn(conv1, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn1')
  pool1 = tf.nn.max_pool(lrn1, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool1')
  print_activations(pool1)

 with tf.name_scope('conv2') as scope:
  kernel = tf.Variable(tf.truncated_normal([5, 5, 64, 192], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [192], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv2 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv2)

  lrn2 = tf.nn.lrn(conv2, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn2')
  pool2 = tf.nn.max_pool(lrn2, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool2')
  print_activations(pool2)

 with tf.name_scope('conv3') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 192, 384], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [384], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv3 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv3)

 with tf.name_scope('conv4') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 384, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv4 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv4)

 with tf.name_scope('conv5') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv5 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv5)

  pool5 = tf.nn.max_pool(conv5, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool5')
  print_activations(pool5)

  return pool5, parameters

def time_tensorflow_run(session, target, info_string):
 num_steps_burn_in = 10
 total_duration = 0.0
 total_duration_squared = 0.0

 for i in range(num_batches + num_steps_burn_in):
  start_time = time.time()
  _ = session.run(target)
  duration = time.time() - start_time
  if i >= num_steps_burn_in:
   if not i % 10:
    print('%s: step %d, duration = %.3f' %(datetime.now(), i - num_steps_burn_in, duration))
   total_duration += duration
   total_duration_squared += duration * duration

 mn = total_duration / num_batches
 vr = total_duration_squared / num_batches - mn * mn
 sd = math.sqrt(vr)
 print('%s: %s across %d steps, %.3f +/- %.3f sec / batch' %(datetime.now(), info_string, num_batches, mn, sd))

def run_benchmark():
 with tf.Graph().as_default():
  image_size = 224
  images = tf.Variable(tf.random_normal([batch_size, image_size, image_size, 3], dtype = tf.float32, stddev = 1e-1))
  pool5, parameters = inference(images)

  init = tf.global_variables_initializer()
  sess = tf.Session()
  sess.run(init)

  time_tensorflow_run(sess, pool5, "Forward")

  objective = tf.nn.l2_loss(pool5)
  grad = tf.gradients(objective, parameters)
  time_tensorflow_run(sess, grad, "Forward-backward")


run_benchmark()

这里的代码都是之前讲过的,只是加了一个计算时间和现实网络的卷积核的函数,应该很容易就看懂了,就不多赘述了。我在GTX TITAN X上前馈大概需要0.024s, 反馈大概需要0.079s。哈哈,自己动手试一试哦。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过索引遍历列表的方法
May 04 Python
python 添加用户设置密码并发邮件给root用户
Jul 25 Python
Selenium(Python web测试工具)基本用法详解
Aug 10 Python
Python找出微信上删除你好友的人脚本写法
Nov 01 Python
详解python执行shell脚本创建用户及相关操作
Apr 11 Python
PyQt5实现让QScrollArea支持鼠标拖动的操作方法
Jun 19 Python
django使用django-apscheduler 实现定时任务的例子
Jul 20 Python
Python3操作Excel文件(读写)的简单实例
Sep 02 Python
Python3 Click模块的使用方法详解
Feb 12 Python
python GUI库图形界面开发之PyQt5中QMainWindow, QWidget以及QDialog的区别和选择
Feb 26 Python
Pytorch之Tensor和Numpy之间的转换的实现方法
Sep 03 Python
Python使用openpyxl复制整张sheet
Mar 24 Python
Tensorflow卷积神经网络实例进阶
May 24 #Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
You might like
PHP 验证码的实现代码
2011/07/17 PHP
Yii 2.0在Grid中格式化时间方法示例
2017/06/06 PHP
CSDN轮换广告图片轮换效果
2007/03/27 Javascript
Iframe thickbox2.0使用的方法
2009/03/05 Javascript
jQuery EasyUI API 中文文档 - MenuButton菜单按钮使用介绍
2011/10/06 Javascript
使用jQuery和Bootstrap实现多层、自适应模态窗口
2014/12/22 Javascript
javascript元素动态创建实现方法
2015/05/13 Javascript
jQuery移动端日期(datedropper)和时间(timedropper)选择器附源码下载
2016/04/19 Javascript
JavaScript定义函数的三种实现方法
2017/09/23 Javascript
express默认日志组件morgan的方法
2018/04/05 Javascript
解决webpack dev-server不能匹配post请求的问题
2018/08/24 Javascript
[01:04:01]2014 DOTA2华西杯精英邀请赛5 24 DK VS VG
2014/05/25 DOTA
python 图片验证码代码
2008/12/07 Python
python访问类中docstring注释的实现方法
2015/05/04 Python
python计算方程式根的方法
2015/05/07 Python
使用Python3 编写简单信用卡管理程序
2016/12/21 Python
Python实现按中文排序的方法示例
2018/04/25 Python
Python多进程原理与用法分析
2018/08/21 Python
如何在 Django 模板中输出 "{{"
2020/01/24 Python
Python pytesseract验证码识别库用法解析
2020/06/29 Python
Python自动巡检H3C交换机实现过程解析
2020/08/14 Python
python 写一个性能测试工具(一)
2020/10/24 Python
匡威意大利官方商店 :Converse意大利
2018/11/27 全球购物
优秀英语专业毕业生求职信
2013/11/23 职场文书
宿舍使用违章电器检讨书
2014/01/12 职场文书
家居饰品店创业计划书
2014/01/31 职场文书
大学生创业策划书
2014/02/02 职场文书
人力资源部经理的岗位职责
2014/03/04 职场文书
消防标语大全
2014/06/07 职场文书
离婚协议书范本及离婚须知
2014/10/15 职场文书
2015年教师节演讲稿范文
2015/03/19 职场文书
2015秋季运动会通讯稿
2015/07/18 职场文书
2019开业庆典剪彩仪式主持词!
2019/07/22 职场文书
导游词之上海豫园
2019/10/24 职场文书
基于go interface{}==nil 的几种坑及原理分析
2021/04/24 Golang
Spring-cloud Config Server的3种配置方式
2021/09/25 Java/Android