Tensorflow实现AlexNet卷积神经网络及运算时间评测


Posted in Python onMay 24, 2018

本文实例为大家分享了Tensorflow实现AlexNet卷积神经网络的具体实现代码,供大家参考,具体内容如下

之前已经介绍过了AlexNet的网络构建了,这次主要不是为了训练数据,而是为了对每个batch的前馈(Forward)和反馈(backward)的平均耗时进行计算。在设计网络的过程中,分类的结果很重要,但是运算速率也相当重要。尤其是在跟踪(Tracking)的任务中,如果使用的网络太深,那么也会导致实时性不好。

from datetime import datetime
import math
import time
import tensorflow as tf

batch_size = 32
num_batches = 100

def print_activations(t):
 print(t.op.name, '', t.get_shape().as_list())

def inference(images):
 parameters = []

 with tf.name_scope('conv1') as scope:
  kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [64], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv1 = tf.nn.relu(bias, name = scope)
  print_activations(conv1)
  parameters += [kernel, biases]

  lrn1 = tf.nn.lrn(conv1, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn1')
  pool1 = tf.nn.max_pool(lrn1, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool1')
  print_activations(pool1)

 with tf.name_scope('conv2') as scope:
  kernel = tf.Variable(tf.truncated_normal([5, 5, 64, 192], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [192], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv2 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv2)

  lrn2 = tf.nn.lrn(conv2, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn2')
  pool2 = tf.nn.max_pool(lrn2, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool2')
  print_activations(pool2)

 with tf.name_scope('conv3') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 192, 384], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [384], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv3 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv3)

 with tf.name_scope('conv4') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 384, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv4 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv4)

 with tf.name_scope('conv5') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv5 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv5)

  pool5 = tf.nn.max_pool(conv5, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool5')
  print_activations(pool5)

  return pool5, parameters

def time_tensorflow_run(session, target, info_string):
 num_steps_burn_in = 10
 total_duration = 0.0
 total_duration_squared = 0.0

 for i in range(num_batches + num_steps_burn_in):
  start_time = time.time()
  _ = session.run(target)
  duration = time.time() - start_time
  if i >= num_steps_burn_in:
   if not i % 10:
    print('%s: step %d, duration = %.3f' %(datetime.now(), i - num_steps_burn_in, duration))
   total_duration += duration
   total_duration_squared += duration * duration

 mn = total_duration / num_batches
 vr = total_duration_squared / num_batches - mn * mn
 sd = math.sqrt(vr)
 print('%s: %s across %d steps, %.3f +/- %.3f sec / batch' %(datetime.now(), info_string, num_batches, mn, sd))

def run_benchmark():
 with tf.Graph().as_default():
  image_size = 224
  images = tf.Variable(tf.random_normal([batch_size, image_size, image_size, 3], dtype = tf.float32, stddev = 1e-1))
  pool5, parameters = inference(images)

  init = tf.global_variables_initializer()
  sess = tf.Session()
  sess.run(init)

  time_tensorflow_run(sess, pool5, "Forward")

  objective = tf.nn.l2_loss(pool5)
  grad = tf.gradients(objective, parameters)
  time_tensorflow_run(sess, grad, "Forward-backward")


run_benchmark()

这里的代码都是之前讲过的,只是加了一个计算时间和现实网络的卷积核的函数,应该很容易就看懂了,就不多赘述了。我在GTX TITAN X上前馈大概需要0.024s, 反馈大概需要0.079s。哈哈,自己动手试一试哦。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用urllib2获取网络资源实例讲解
Dec 02 Python
浅谈Scrapy框架普通反爬虫机制的应对策略
Dec 28 Python
浅谈Python3 numpy.ptp()最大值与最小值的差
Aug 24 Python
Python2与Python3的区别详解
Feb 09 Python
python 计算概率密度、累计分布、逆函数的例子
Feb 25 Python
基于Python数据分析之pandas统计分析
Mar 03 Python
Selenium使用Chrome模拟手机浏览器方法解析
Apr 10 Python
基于python实现音乐播放器代码实例
Jul 01 Python
python如何删除列为空的行
Jul 17 Python
python实现企业微信定时发送文本消息的示例代码
Nov 24 Python
python获取对象信息的实例详解
Jul 07 Python
Tensorflow卷积神经网络实例进阶
May 24 #Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
You might like
咖啡与牛奶
2021/03/03 冲泡冲煮
php 日期时间处理函数小结
2009/12/18 PHP
PHP中使用json数据格式定义字面量对象的方法
2014/08/20 PHP
php实现二叉树中和为某一值的路径方法
2018/10/14 PHP
Gambit vs CL BO3 第一场 2.13
2021/03/10 DOTA
jquery 利用show和hidden实现级联菜单示例代码
2013/08/09 Javascript
谷歌浏览器不支持showModalDialog模态对话框的解决方法
2014/09/22 Javascript
jQuery实现带分组数据的Table表头排序实例分析
2015/11/24 Javascript
第一次接触神奇的Bootstrap基础排版
2016/07/26 Javascript
javascript中setAttribute兼容性用法分析
2016/12/12 Javascript
iscroll动态加载数据完美解决方法
2017/07/18 Javascript
使用FormData实现上传多个文件
2018/12/04 Javascript
Vue.js项目实战之多语种网站的功能实现(租车)
2019/08/07 Javascript
mpvue微信小程序的接口请求fly全局拦截代码实例
2019/11/13 Javascript
[06:16]DOTA2守卫传承者——职业选手谈心路历程
2015/02/26 DOTA
Python中字典(dict)和列表(list)的排序方法实例
2014/06/16 Python
python3编码问题汇总
2016/09/06 Python
python实现电子产品商店
2019/02/26 Python
python 协程中的迭代器,生成器原理及应用实例详解
2019/10/28 Python
pytorch 状态字典:state_dict使用详解
2020/01/17 Python
使用Keras实现简单线性回归模型操作
2020/06/12 Python
python 获取字典特定值对应的键的实现
2020/09/29 Python
Pycharm快捷键配置详细整理
2020/10/13 Python
浅谈关于html5中图片抛物线运动的一些心得
2018/01/09 HTML / CSS
英国大码女性时装零售商:Evans
2018/08/29 全球购物
给医务人员表扬信
2014/01/12 职场文书
毕业生简历自我评价范文
2014/04/09 职场文书
同居协议书范本
2014/04/23 职场文书
文明演讲稿范文
2014/05/12 职场文书
党员查摆四风问题思想汇报
2014/10/25 职场文书
2015年五一劳动节活动总结
2015/02/09 职场文书
客房服务员岗位职责
2015/02/09 职场文书
2015年学校教科室工作总结
2015/07/20 职场文书
简单实现一个手持弹幕功能+文字抖动特效
2021/03/31 HTML / CSS
JavaScript的function函数详细介绍
2021/11/20 Javascript
Springboot集成kafka高级应用实战分享
2022/08/14 Java/Android