Tensorflow实现AlexNet卷积神经网络及运算时间评测


Posted in Python onMay 24, 2018

本文实例为大家分享了Tensorflow实现AlexNet卷积神经网络的具体实现代码,供大家参考,具体内容如下

之前已经介绍过了AlexNet的网络构建了,这次主要不是为了训练数据,而是为了对每个batch的前馈(Forward)和反馈(backward)的平均耗时进行计算。在设计网络的过程中,分类的结果很重要,但是运算速率也相当重要。尤其是在跟踪(Tracking)的任务中,如果使用的网络太深,那么也会导致实时性不好。

from datetime import datetime
import math
import time
import tensorflow as tf

batch_size = 32
num_batches = 100

def print_activations(t):
 print(t.op.name, '', t.get_shape().as_list())

def inference(images):
 parameters = []

 with tf.name_scope('conv1') as scope:
  kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [64], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv1 = tf.nn.relu(bias, name = scope)
  print_activations(conv1)
  parameters += [kernel, biases]

  lrn1 = tf.nn.lrn(conv1, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn1')
  pool1 = tf.nn.max_pool(lrn1, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool1')
  print_activations(pool1)

 with tf.name_scope('conv2') as scope:
  kernel = tf.Variable(tf.truncated_normal([5, 5, 64, 192], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [192], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv2 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv2)

  lrn2 = tf.nn.lrn(conv2, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn2')
  pool2 = tf.nn.max_pool(lrn2, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool2')
  print_activations(pool2)

 with tf.name_scope('conv3') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 192, 384], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [384], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv3 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv3)

 with tf.name_scope('conv4') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 384, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv4 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv4)

 with tf.name_scope('conv5') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv5 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv5)

  pool5 = tf.nn.max_pool(conv5, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool5')
  print_activations(pool5)

  return pool5, parameters

def time_tensorflow_run(session, target, info_string):
 num_steps_burn_in = 10
 total_duration = 0.0
 total_duration_squared = 0.0

 for i in range(num_batches + num_steps_burn_in):
  start_time = time.time()
  _ = session.run(target)
  duration = time.time() - start_time
  if i >= num_steps_burn_in:
   if not i % 10:
    print('%s: step %d, duration = %.3f' %(datetime.now(), i - num_steps_burn_in, duration))
   total_duration += duration
   total_duration_squared += duration * duration

 mn = total_duration / num_batches
 vr = total_duration_squared / num_batches - mn * mn
 sd = math.sqrt(vr)
 print('%s: %s across %d steps, %.3f +/- %.3f sec / batch' %(datetime.now(), info_string, num_batches, mn, sd))

def run_benchmark():
 with tf.Graph().as_default():
  image_size = 224
  images = tf.Variable(tf.random_normal([batch_size, image_size, image_size, 3], dtype = tf.float32, stddev = 1e-1))
  pool5, parameters = inference(images)

  init = tf.global_variables_initializer()
  sess = tf.Session()
  sess.run(init)

  time_tensorflow_run(sess, pool5, "Forward")

  objective = tf.nn.l2_loss(pool5)
  grad = tf.gradients(objective, parameters)
  time_tensorflow_run(sess, grad, "Forward-backward")


run_benchmark()

这里的代码都是之前讲过的,只是加了一个计算时间和现实网络的卷积核的函数,应该很容易就看懂了,就不多赘述了。我在GTX TITAN X上前馈大概需要0.024s, 反馈大概需要0.079s。哈哈,自己动手试一试哦。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 自动提交和抓取网页
Jul 13 Python
python访问纯真IP数据库的代码
May 19 Python
python调用Moxa PCOMM Lite通过串口Ymodem协议实现发送文件
Aug 15 Python
详解Swift中属性的声明与作用
Jun 30 Python
Python线程指南详细介绍
Jan 05 Python
Python使用flask框架操作sqlite3的两种方式
Jan 31 Python
Python登录注册验证功能实现
Jun 18 Python
通过python的matplotlib包将Tensorflow数据进行可视化的方法
Jan 09 Python
python join方法使用详解
Jul 30 Python
TensorFlow实现指数衰减学习率的方法
Feb 05 Python
Python 实现平台类游戏添加跳跃功能
Mar 27 Python
python 实现socket服务端并发的四种方式
Dec 14 Python
Tensorflow卷积神经网络实例进阶
May 24 #Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
You might like
PHP ajax+jQuery 实现批量删除功能实例代码小结
2018/12/06 PHP
关于PHP5.6+版本“No input file specified”问题的解决
2019/12/11 PHP
ASP.NET中使用后端代码注册脚本 生成JQUERY-EASYUI的界面错位的解决方法
2010/06/12 Javascript
6款经典实用的jQuery小插件及源码(对话框/提示工具等等)
2013/02/04 Javascript
Jquery 实现grid绑定模板
2015/01/28 Javascript
JavaScript表格常用操作方法汇总
2015/04/15 Javascript
javascript图片延迟加载实现方法及思路
2015/12/31 Javascript
jQuery表格插件datatables用法汇总
2016/03/29 Javascript
js+flash实现的5图变换效果广告代码(附演示与demo源码下载)
2016/04/01 Javascript
全面解析Bootstrap中nav、collapse的使用方法
2016/05/22 Javascript
使用jQuery给input标签设置默认值
2016/06/20 Javascript
jQuery+pjax简单示例汇总
2017/04/21 jQuery
layui use 定义js外部引用函数的方法
2019/09/26 Javascript
JavaScript数组排序的六种常见算法总结
2020/08/18 Javascript
基于jquery实现彩色投票进度条代码解析
2020/08/26 jQuery
解决Python print输出不换行没空格的问题
2018/11/14 Python
Django实现一对多表模型的跨表查询方法
2018/12/18 Python
解决python字典对值(值为列表)赋值出现重复的问题
2019/01/20 Python
Python3+OpenCV2实现图像的几何变换(平移、镜像、缩放、旋转、仿射)
2019/05/13 Python
如何更改 pandas dataframe 中两列的位置
2019/12/27 Python
django表单中的按钮获取数据的实例分析
2020/07/31 Python
python实现发送邮件
2021/03/02 Python
CSS3线性渐变简单实现以及该属性在浏览器中的不同
2012/12/12 HTML / CSS
域名注册、建站工具、网页主机、SSL证书:Dynadot
2017/01/06 全球购物
Sunglasses Shop丹麦:欧洲第一的太阳镜在线销售网站
2017/10/22 全球购物
印尼披萨外送专家:Domino’s Pizza印尼
2017/12/28 全球购物
2019年.net常见面试问题
2012/02/12 面试题
农村改厕实施方案
2014/03/22 职场文书
党员对十八届四中全会的期盼思想汇报范文
2014/10/17 职场文书
公司仓库管理制度
2015/08/04 职场文书
2015年度学校应急管理工作总结
2015/10/22 职场文书
中国梦党课学习心得体会
2016/01/05 职场文书
幼儿园心得体会范文
2016/01/21 职场文书
党风廉政教育心得体会2016
2016/01/22 职场文书
python简单验证码识别的实现过程
2021/06/20 Python
索尼ICF-36收音机评测
2022/04/30 无线电