Tensorflow实现AlexNet卷积神经网络及运算时间评测


Posted in Python onMay 24, 2018

本文实例为大家分享了Tensorflow实现AlexNet卷积神经网络的具体实现代码,供大家参考,具体内容如下

之前已经介绍过了AlexNet的网络构建了,这次主要不是为了训练数据,而是为了对每个batch的前馈(Forward)和反馈(backward)的平均耗时进行计算。在设计网络的过程中,分类的结果很重要,但是运算速率也相当重要。尤其是在跟踪(Tracking)的任务中,如果使用的网络太深,那么也会导致实时性不好。

from datetime import datetime
import math
import time
import tensorflow as tf

batch_size = 32
num_batches = 100

def print_activations(t):
 print(t.op.name, '', t.get_shape().as_list())

def inference(images):
 parameters = []

 with tf.name_scope('conv1') as scope:
  kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [64], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv1 = tf.nn.relu(bias, name = scope)
  print_activations(conv1)
  parameters += [kernel, biases]

  lrn1 = tf.nn.lrn(conv1, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn1')
  pool1 = tf.nn.max_pool(lrn1, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool1')
  print_activations(pool1)

 with tf.name_scope('conv2') as scope:
  kernel = tf.Variable(tf.truncated_normal([5, 5, 64, 192], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [192], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv2 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv2)

  lrn2 = tf.nn.lrn(conv2, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn2')
  pool2 = tf.nn.max_pool(lrn2, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool2')
  print_activations(pool2)

 with tf.name_scope('conv3') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 192, 384], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [384], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv3 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv3)

 with tf.name_scope('conv4') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 384, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv4 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv4)

 with tf.name_scope('conv5') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv5 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv5)

  pool5 = tf.nn.max_pool(conv5, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool5')
  print_activations(pool5)

  return pool5, parameters

def time_tensorflow_run(session, target, info_string):
 num_steps_burn_in = 10
 total_duration = 0.0
 total_duration_squared = 0.0

 for i in range(num_batches + num_steps_burn_in):
  start_time = time.time()
  _ = session.run(target)
  duration = time.time() - start_time
  if i >= num_steps_burn_in:
   if not i % 10:
    print('%s: step %d, duration = %.3f' %(datetime.now(), i - num_steps_burn_in, duration))
   total_duration += duration
   total_duration_squared += duration * duration

 mn = total_duration / num_batches
 vr = total_duration_squared / num_batches - mn * mn
 sd = math.sqrt(vr)
 print('%s: %s across %d steps, %.3f +/- %.3f sec / batch' %(datetime.now(), info_string, num_batches, mn, sd))

def run_benchmark():
 with tf.Graph().as_default():
  image_size = 224
  images = tf.Variable(tf.random_normal([batch_size, image_size, image_size, 3], dtype = tf.float32, stddev = 1e-1))
  pool5, parameters = inference(images)

  init = tf.global_variables_initializer()
  sess = tf.Session()
  sess.run(init)

  time_tensorflow_run(sess, pool5, "Forward")

  objective = tf.nn.l2_loss(pool5)
  grad = tf.gradients(objective, parameters)
  time_tensorflow_run(sess, grad, "Forward-backward")


run_benchmark()

这里的代码都是之前讲过的,只是加了一个计算时间和现实网络的卷积核的函数,应该很容易就看懂了,就不多赘述了。我在GTX TITAN X上前馈大概需要0.024s, 反馈大概需要0.079s。哈哈,自己动手试一试哦。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
对Python新手编程过程中如何规避一些常见问题的建议
Apr 01 Python
Python中处理时间的几种方法小结
Apr 09 Python
Python字符串匹配算法KMP实例
Jul 18 Python
详解用Python处理HTML转义字符的5种方式
Dec 27 Python
Python实现PS滤镜Fish lens图像扭曲效果示例
Jan 29 Python
python skimage 连通性区域检测方法
Jun 21 Python
python基础 range的用法解析
Aug 23 Python
Pandas 缺失数据处理的实现
Nov 04 Python
Python 实现自动获取种子磁力链接方式
Jan 16 Python
基于Tensorflow高阶读写教程
Feb 10 Python
查看keras各种网络结构各层的名字方式
Jun 11 Python
Numpy中的数组搜索中np.where方法详细介绍
Jan 08 Python
Tensorflow卷积神经网络实例进阶
May 24 #Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
You might like
是否存在第一台收音机的说法
2021/03/01 无线电
天津市收音机工业发展史
2021/03/04 无线电
thinkphp框架使用JWTtoken的方法详解
2019/10/10 PHP
自写的一个jQuery圆角插件
2010/10/26 Javascript
js、css、img等浏览器缓存问题的2种解决方案
2013/10/23 Javascript
jQuery插件scroll实现无缝滚动效果
2015/04/27 Javascript
解决js图片加载时出现404的问题
2020/11/30 Javascript
javascript实现五星评分功能
2015/11/10 Javascript
JS iFrame加载慢怎么解决
2016/05/13 Javascript
JS动态给对象添加事件的简单方法
2016/07/19 Javascript
Vue.js实现简单动态数据处理
2017/02/13 Javascript
完美实现js选项卡切换效果(一)
2017/03/08 Javascript
vue2.0 子组件改变props值,并向父组件传值的方法
2018/03/01 Javascript
layui插件表单验证提交触发提交的例子
2019/09/09 Javascript
Vue的生命周期操作示例
2019/09/17 Javascript
Flutter实现仿微信底部菜单栏功能
2019/09/18 Javascript
使用Vue调取接口,并渲染数据的示例代码
2019/10/28 Javascript
Vue 实现显示/隐藏层的思路(加全局点击事件)
2019/12/31 Javascript
JS typeof fn === 'function' && fn()详解
2020/08/22 Javascript
vue监听浏览器原生返回按钮,进行路由转跳操作
2020/09/09 Javascript
python如何在循环引用中管理内存
2018/03/20 Python
python实现名片管理系统项目
2019/04/26 Python
如何解决django-celery启动后迅速关闭
2019/10/16 Python
python保存log日志,实现用log日志画图
2019/12/24 Python
python实现图片转字符画的完整代码
2021/02/21 Python
使用css创建三角形 使用CSS3创建3d四面体原理及代码(html5实践)
2013/01/06 HTML / CSS
Mansur Gavriel官网:纽约市的一个设计品牌
2019/05/02 全球购物
GWebs公司笔试题
2012/05/04 面试题
如何开发一个JQuery插件
2016/07/28 面试题
委托证明的格式
2014/01/10 职场文书
数控专业毕业生自荐信范文
2014/03/04 职场文书
培训班开班主持词
2015/07/02 职场文书
2015年高三教学工作总结
2015/07/21 职场文书
小程序与后端Java接口交互实现HelloWorld入门
2021/07/09 Java/Android
Nginx工作模式及代理配置的使用细节
2022/03/21 Servers
Python使用MapReduce进行简单的销售统计
2022/04/22 Python