Tensorflow实现AlexNet卷积神经网络及运算时间评测


Posted in Python onMay 24, 2018

本文实例为大家分享了Tensorflow实现AlexNet卷积神经网络的具体实现代码,供大家参考,具体内容如下

之前已经介绍过了AlexNet的网络构建了,这次主要不是为了训练数据,而是为了对每个batch的前馈(Forward)和反馈(backward)的平均耗时进行计算。在设计网络的过程中,分类的结果很重要,但是运算速率也相当重要。尤其是在跟踪(Tracking)的任务中,如果使用的网络太深,那么也会导致实时性不好。

from datetime import datetime
import math
import time
import tensorflow as tf

batch_size = 32
num_batches = 100

def print_activations(t):
 print(t.op.name, '', t.get_shape().as_list())

def inference(images):
 parameters = []

 with tf.name_scope('conv1') as scope:
  kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [64], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv1 = tf.nn.relu(bias, name = scope)
  print_activations(conv1)
  parameters += [kernel, biases]

  lrn1 = tf.nn.lrn(conv1, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn1')
  pool1 = tf.nn.max_pool(lrn1, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool1')
  print_activations(pool1)

 with tf.name_scope('conv2') as scope:
  kernel = tf.Variable(tf.truncated_normal([5, 5, 64, 192], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [192], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv2 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv2)

  lrn2 = tf.nn.lrn(conv2, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn2')
  pool2 = tf.nn.max_pool(lrn2, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool2')
  print_activations(pool2)

 with tf.name_scope('conv3') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 192, 384], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [384], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv3 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv3)

 with tf.name_scope('conv4') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 384, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv4 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv4)

 with tf.name_scope('conv5') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv5 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv5)

  pool5 = tf.nn.max_pool(conv5, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool5')
  print_activations(pool5)

  return pool5, parameters

def time_tensorflow_run(session, target, info_string):
 num_steps_burn_in = 10
 total_duration = 0.0
 total_duration_squared = 0.0

 for i in range(num_batches + num_steps_burn_in):
  start_time = time.time()
  _ = session.run(target)
  duration = time.time() - start_time
  if i >= num_steps_burn_in:
   if not i % 10:
    print('%s: step %d, duration = %.3f' %(datetime.now(), i - num_steps_burn_in, duration))
   total_duration += duration
   total_duration_squared += duration * duration

 mn = total_duration / num_batches
 vr = total_duration_squared / num_batches - mn * mn
 sd = math.sqrt(vr)
 print('%s: %s across %d steps, %.3f +/- %.3f sec / batch' %(datetime.now(), info_string, num_batches, mn, sd))

def run_benchmark():
 with tf.Graph().as_default():
  image_size = 224
  images = tf.Variable(tf.random_normal([batch_size, image_size, image_size, 3], dtype = tf.float32, stddev = 1e-1))
  pool5, parameters = inference(images)

  init = tf.global_variables_initializer()
  sess = tf.Session()
  sess.run(init)

  time_tensorflow_run(sess, pool5, "Forward")

  objective = tf.nn.l2_loss(pool5)
  grad = tf.gradients(objective, parameters)
  time_tensorflow_run(sess, grad, "Forward-backward")


run_benchmark()

这里的代码都是之前讲过的,只是加了一个计算时间和现实网络的卷积核的函数,应该很容易就看懂了,就不多赘述了。我在GTX TITAN X上前馈大概需要0.024s, 反馈大概需要0.079s。哈哈,自己动手试一试哦。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现截屏的函数
Jul 25 Python
Python实现的微信公众号群发图片与文本消息功能实例详解
Jun 30 Python
python监控文件并且发送告警邮件
Jun 21 Python
python 将对象设置为可迭代的两种实现方法
Jan 21 Python
解决pycharm 远程调试 上传 helpers 卡住的问题
Jun 27 Python
python 基于TCP协议的套接字编程详解
Jun 29 Python
Python OpenCV调用摄像头检测人脸并截图
Aug 20 Python
Pytorch保存模型用于测试和用于继续训练的区别详解
Jan 10 Python
selenium 多窗口切换的实现(windows)
Jan 18 Python
python3实现往mysql中插入datetime类型的数据
Mar 02 Python
Python排序函数的使用方法详解
Dec 11 Python
pandas中DataFrame数据合并连接(merge、join、concat)
May 30 Python
Tensorflow卷积神经网络实例进阶
May 24 #Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
You might like
php 什么是PEAR?(第三篇)
2009/03/19 PHP
JavaScript初学者需要了解10个小技巧
2010/08/25 Javascript
JavaScript高级程序设计(第3版)学习笔记8 js函数(中)
2012/10/11 Javascript
jquery 表格的增行删行实现思路
2013/03/21 Javascript
图片延迟加载的实现代码(模仿懒惰)
2013/03/29 Javascript
无刷新预览所选择的图片示例代码
2014/04/02 Javascript
jQuery中click事件的定义和用法
2014/12/20 Javascript
使用javascript实现雪花飘落的效果
2015/01/13 Javascript
JavaScript使用push方法添加一个元素到数组末尾用法实例
2015/04/06 Javascript
JS更改select内option属性的方法
2015/10/14 Javascript
JS代码实现table数据分页效果
2016/05/26 Javascript
jQuery 3.0 的 setter和getter 模式详解
2016/07/11 Javascript
使用jQuery实现动态添加小广告
2017/07/11 jQuery
如何更好的编写js async函数
2018/05/13 Javascript
Vue.js中该如何自己维护路由跳转记录
2019/05/19 Javascript
vue实现后台管理权限系统及顶栏三级菜单显示功能
2019/06/19 Javascript
node获取客户端ip功能简单示例
2019/08/24 Javascript
Js视频播放器插件Video.js使用方法详解
2020/02/04 Javascript
antd vue 刷新保留当前页面路由,保留选中菜单,保留menu选中操作
2020/08/06 Javascript
详解在Python中处理异常的教程
2015/05/24 Python
详解Python3中的Sequence type的使用
2015/08/01 Python
详解pandas DataFrame的查询方法(loc,iloc,at,iat,ix的用法和区别)
2019/08/02 Python
python 根据网易云歌曲的ID 直接下载歌曲的实例
2019/08/24 Python
详解Django配置优化方法
2019/11/18 Python
使用matplotlib绘制图例标签中带有公式的图
2019/12/13 Python
详解python安装matplotlib库三种失败情况
2020/07/28 Python
OpenCV利用python来实现图像的直方图均衡化
2020/10/21 Python
Python学习之time模块的基本使用
2021/01/17 Python
用CSS3实现背景渐变的方法
2015/07/14 HTML / CSS
J2SDK1.5与J2SDK5.0有什么区别
2012/09/19 面试题
图书馆志愿者活动总结
2014/06/27 职场文书
个人总结与自我评价2015
2015/03/11 职场文书
基于Go Int转string几种方式性能测试
2021/04/28 Golang
使用tensorflow 实现反向传播求导
2021/05/26 Python
Python 绘制多因子柱状图
2022/05/11 Python
java.util.NoSuchElementException原因及两种解决方法
2022/06/28 Java/Android