Tensorflow实现AlexNet卷积神经网络及运算时间评测


Posted in Python onMay 24, 2018

本文实例为大家分享了Tensorflow实现AlexNet卷积神经网络的具体实现代码,供大家参考,具体内容如下

之前已经介绍过了AlexNet的网络构建了,这次主要不是为了训练数据,而是为了对每个batch的前馈(Forward)和反馈(backward)的平均耗时进行计算。在设计网络的过程中,分类的结果很重要,但是运算速率也相当重要。尤其是在跟踪(Tracking)的任务中,如果使用的网络太深,那么也会导致实时性不好。

from datetime import datetime
import math
import time
import tensorflow as tf

batch_size = 32
num_batches = 100

def print_activations(t):
 print(t.op.name, '', t.get_shape().as_list())

def inference(images):
 parameters = []

 with tf.name_scope('conv1') as scope:
  kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [64], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv1 = tf.nn.relu(bias, name = scope)
  print_activations(conv1)
  parameters += [kernel, biases]

  lrn1 = tf.nn.lrn(conv1, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn1')
  pool1 = tf.nn.max_pool(lrn1, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool1')
  print_activations(pool1)

 with tf.name_scope('conv2') as scope:
  kernel = tf.Variable(tf.truncated_normal([5, 5, 64, 192], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [192], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv2 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv2)

  lrn2 = tf.nn.lrn(conv2, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn2')
  pool2 = tf.nn.max_pool(lrn2, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool2')
  print_activations(pool2)

 with tf.name_scope('conv3') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 192, 384], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [384], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv3 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv3)

 with tf.name_scope('conv4') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 384, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv4 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv4)

 with tf.name_scope('conv5') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv5 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv5)

  pool5 = tf.nn.max_pool(conv5, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool5')
  print_activations(pool5)

  return pool5, parameters

def time_tensorflow_run(session, target, info_string):
 num_steps_burn_in = 10
 total_duration = 0.0
 total_duration_squared = 0.0

 for i in range(num_batches + num_steps_burn_in):
  start_time = time.time()
  _ = session.run(target)
  duration = time.time() - start_time
  if i >= num_steps_burn_in:
   if not i % 10:
    print('%s: step %d, duration = %.3f' %(datetime.now(), i - num_steps_burn_in, duration))
   total_duration += duration
   total_duration_squared += duration * duration

 mn = total_duration / num_batches
 vr = total_duration_squared / num_batches - mn * mn
 sd = math.sqrt(vr)
 print('%s: %s across %d steps, %.3f +/- %.3f sec / batch' %(datetime.now(), info_string, num_batches, mn, sd))

def run_benchmark():
 with tf.Graph().as_default():
  image_size = 224
  images = tf.Variable(tf.random_normal([batch_size, image_size, image_size, 3], dtype = tf.float32, stddev = 1e-1))
  pool5, parameters = inference(images)

  init = tf.global_variables_initializer()
  sess = tf.Session()
  sess.run(init)

  time_tensorflow_run(sess, pool5, "Forward")

  objective = tf.nn.l2_loss(pool5)
  grad = tf.gradients(objective, parameters)
  time_tensorflow_run(sess, grad, "Forward-backward")


run_benchmark()

这里的代码都是之前讲过的,只是加了一个计算时间和现实网络的卷积核的函数,应该很容易就看懂了,就不多赘述了。我在GTX TITAN X上前馈大概需要0.024s, 反馈大概需要0.079s。哈哈,自己动手试一试哦。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python的线程来解决生产者消费问题的示例
Apr 02 Python
Python实现根据IP地址和子网掩码算出网段的方法
Jul 30 Python
一个基于flask的web应用诞生 组织结构调整(7)
Apr 11 Python
在django中使用自定义标签实现分页功能
Jul 04 Python
python中numpy的矩阵、多维数组的用法
Feb 05 Python
Python向excel中写入数据的方法
May 05 Python
python异步实现定时任务和周期任务的方法
Jun 29 Python
Python爬虫 scrapy框架爬取某招聘网存入mongodb解析
Jul 31 Python
python except异常处理之后不退出,解决异常继续执行的实现
Apr 25 Python
python 基于wx实现音乐播放
Nov 24 Python
Python爬虫自动化获取华图和粉笔网站的错题(推荐)
Jan 08 Python
Python 避免字典和元组的多重嵌套问题
Jul 15 Python
Tensorflow卷积神经网络实例进阶
May 24 #Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
You might like
PHP中开启gzip压缩的2种方法
2015/01/31 PHP
php+ajax 实现输入读取数据库显示匹配信息
2015/10/08 PHP
yii2.0整合阿里云oss上传单个文件的示例
2017/09/19 PHP
javascript的console.log()用法小结
2012/05/31 Javascript
jQuery弹性滑动导航菜单实现思路及代码
2013/05/02 Javascript
JQuery中form验证出错信息的查看方法
2013/10/08 Javascript
jquery简单实现滚动条下拉DIV固定在头部不动
2013/11/25 Javascript
window.open()详解及浏览器兼容性问题示例探讨
2014/05/29 Javascript
JS制作手机端自适应缩放显示
2015/06/11 Javascript
js获取iframe中的window对象的实现方法
2016/05/20 Javascript
如何提高Dom访问速度
2017/01/05 Javascript
详解js的异步编程技术的方法
2017/02/09 Javascript
关于Vue的路由权限管理的示例代码
2018/03/06 Javascript
vue生命周期实例小结
2018/08/15 Javascript
vue.js iview打包上线后字体图标不显示解决办法
2020/01/20 Javascript
webpack+vue-cil 中proxyTable配置接口地址代理操作
2020/07/18 Javascript
解决vuex刷新数据消失问题
2020/11/12 Javascript
python中函数默认值使用注意点详解
2016/06/01 Python
Python concurrent.futures模块使用实例
2019/12/24 Python
Python selenium爬取微博数据代码实例
2020/05/22 Python
Python Map 函数的使用
2020/08/28 Python
python自动打开浏览器下载zip并提取内容写入excel
2021/01/04 Python
深入解读CSS3中transform变换模型的渲染
2016/05/27 HTML / CSS
HTML5 Canvas实现360度全景图的示例代码
2018/01/29 HTML / CSS
BIFFI美国站:意大利BIFFI BOUTIQUES豪华多品牌时装零售公司
2020/02/11 全球购物
在DELPHI中调用存储过程和使用内嵌SQL哪种方式更好
2016/11/22 面试题
工程班组长岗位职责
2013/12/30 职场文书
中学生获奖感言
2014/02/04 职场文书
大学毕业感言一句话
2014/02/06 职场文书
博士生导师推荐信
2014/07/08 职场文书
家长评语怎么写
2014/12/30 职场文书
客房领班岗位职责
2015/02/11 职场文书
扩展多台相同的Web服务器
2021/04/01 Servers
SpringBoot2零基础到精通之数据库专项精讲
2022/03/22 Java/Android
Python之matplotlib绘制折线图
2022/04/13 Python
Golang gRPC HTTP协议转换示例
2022/06/16 Golang