使用tensorflow实现AlexNet


Posted in Python onNovember 20, 2017

AlexNet是2012年ImageNet比赛的冠军,虽然过去了很长时间,但是作为深度学习中的经典模型,AlexNet不但有助于我们理解其中所使用的很多技巧,而且非常有助于提升我们使用深度学习工具箱的熟练度。尤其是我刚入门深度学习,迫切需要一个能让自己熟悉tensorflow的小练习,于是就有了这个小玩意儿......

先放上我的代码:https://github.com/hjptriplebee/AlexNet_with_tensorflow

如果想运行代码,详细的配置要求都在上面链接的readme文件中了。本文建立在一定的tensorflow基础上,不会对太细的点进行说明。

模型结构

使用tensorflow实现AlexNet

关于模型结构网上的文献很多,我这里不赘述,一会儿都在代码里解释。

有一点需要注意,AlexNet将网络分成了上下两个部分,在论文中两部分结构完全相同,唯一不同的是他们放在不同GPU上训练,因为每一层的feature map之间都是独立的(除了全连接层),所以这相当于是提升训练速度的一种方法。很多AlexNet的复现都将上下两部分合并了,因为他们都是在单个GPU上运行的。虽然我也是在单个GPU上运行,但是我还是很想将最原始的网络结构还原出来,所以我的代码里也是分开的。

模型定义

def maxPoolLayer(x, kHeight, kWidth, strideX, strideY, name, padding = "SAME"): 
  """max-pooling""" 
  return tf.nn.max_pool(x, ksize = [1, kHeight, kWidth, 1], 
             strides = [1, strideX, strideY, 1], padding = padding, name = name) 
 
def dropout(x, keepPro, name = None): 
  """dropout""" 
  return tf.nn.dropout(x, keepPro, name) 
 
def LRN(x, R, alpha, beta, name = None, bias = 1.0): 
  """LRN""" 
  return tf.nn.local_response_normalization(x, depth_radius = R, alpha = alpha, 
                       beta = beta, bias = bias, name = name) 
 
def fcLayer(x, inputD, outputD, reluFlag, name): 
  """fully-connect""" 
  with tf.variable_scope(name) as scope: 
    w = tf.get_variable("w", shape = [inputD, outputD], dtype = "float") 
    b = tf.get_variable("b", [outputD], dtype = "float") 
    out = tf.nn.xw_plus_b(x, w, b, name = scope.name) 
    if reluFlag: 
      return tf.nn.relu(out) 
    else: 
      return out 
 
def convLayer(x, kHeight, kWidth, strideX, strideY, 
       featureNum, name, padding = "SAME", groups = 1):#group为2时等于AlexNet中分上下两部分 
  """convlutional""" 
  channel = int(x.get_shape()[-1])#获取channel 
  conv = lambda a, b: tf.nn.conv2d(a, b, strides = [1, strideY, strideX, 1], padding = padding)#定义卷积的匿名函数 
  with tf.variable_scope(name) as scope: 
    w = tf.get_variable("w", shape = [kHeight, kWidth, channel/groups, featureNum]) 
    b = tf.get_variable("b", shape = [featureNum]) 
 
    xNew = tf.split(value = x, num_or_size_splits = groups, axis = 3)#划分后的输入和权重 
    wNew = tf.split(value = w, num_or_size_splits = groups, axis = 3) 
 
    featureMap = [conv(t1, t2) for t1, t2 in zip(xNew, wNew)] #分别提取feature map 
    mergeFeatureMap = tf.concat(axis = 3, values = featureMap) #feature map整合 
    # print mergeFeatureMap.shape 
    out = tf.nn.bias_add(mergeFeatureMap, b) 
    return tf.nn.relu(tf.reshape(out, mergeFeatureMap.get_shape().as_list()), name = scope.name) #relu后的结果

定义了卷积、pooling、LRN、dropout、全连接五个模块,其中卷积模块因为将网络的上下两部分分开了,所以比较复杂。接下来定义AlexNet。

class alexNet(object): 
  """alexNet model""" 
  def __init__(self, x, keepPro, classNum, skip, modelPath = "bvlc_alexnet.npy"): 
    self.X = x 
    self.KEEPPRO = keepPro 
    self.CLASSNUM = classNum 
    self.SKIP = skip 
    self.MODELPATH = modelPath 
    #build CNN 
    self.buildCNN() 
 
  def buildCNN(self): 
    """build model""" 
    conv1 = convLayer(self.X, 11, 11, 4, 4, 96, "conv1", "VALID") 
    pool1 = maxPoolLayer(conv1, 3, 3, 2, 2, "pool1", "VALID") 
    lrn1 = LRN(pool1, 2, 2e-05, 0.75, "norm1") 
 
    conv2 = convLayer(lrn1, 5, 5, 1, 1, 256, "conv2", groups = 2) 
    pool2 = maxPoolLayer(conv2, 3, 3, 2, 2, "pool2", "VALID") 
    lrn2 = LRN(pool2, 2, 2e-05, 0.75, "lrn2") 
 
    conv3 = convLayer(lrn2, 3, 3, 1, 1, 384, "conv3") 
 
    conv4 = convLayer(conv3, 3, 3, 1, 1, 384, "conv4", groups = 2) 
 
    conv5 = convLayer(conv4, 3, 3, 1, 1, 256, "conv5", groups = 2) 
    pool5 = maxPoolLayer(conv5, 3, 3, 2, 2, "pool5", "VALID") 
 
    fcIn = tf.reshape(pool5, [-1, 256 * 6 * 6]) 
    fc1 = fcLayer(fcIn, 256 * 6 * 6, 4096, True, "fc6") 
    dropout1 = dropout(fc1, self.KEEPPRO) 
 
    fc2 = fcLayer(dropout1, 4096, 4096, True, "fc7") 
    dropout2 = dropout(fc2, self.KEEPPRO) 
 
    self.fc3 = fcLayer(dropout2, 4096, self.CLASSNUM, True, "fc8") 
 
  def loadModel(self, sess): 
    """load model""" 
    wDict = np.load(self.MODELPATH, encoding = "bytes").item() 
    #for layers in model 
    for name in wDict: 
      if name not in self.SKIP: 
        with tf.variable_scope(name, reuse = True): 
          for p in wDict[name]: 
            if len(p.shape) == 1:  
              #bias 只有一维 
              sess.run(tf.get_variable('b', trainable = False).assign(p)) 
            else: 
              #weights  
              sess.run(tf.get_variable('w', trainable = False).assign(p))

buildCNN函数完全按照alexnet的结构搭建网络。
loadModel函数从模型文件中读取参数,采用的模型文件见github上的readme说明。
至此,我们定义了完整的模型,下面开始测试模型。

模型测试

ImageNet训练的AlexNet有很多类,几乎包含所有常见的物体,因此我们随便从网上找几张图片测试。比如我直接用了之前做项目的渣土车图片:

使用tensorflow实现AlexNet

然后编写测试代码:

#some params 
dropoutPro = 1 
classNum = 1000 
skip = [] 
#get testImage 
testPath = "testModel" 
testImg = [] 
for f in os.listdir(testPath): 
  testImg.append(cv2.imread(testPath + "/" + f)) 
 
imgMean = np.array([104, 117, 124], np.float) 
x = tf.placeholder("float", [1, 227, 227, 3]) 
 
model = alexnet.alexNet(x, dropoutPro, classNum, skip) 
score = model.fc3 
softmax = tf.nn.softmax(score) 
 
with tf.Session() as sess: 
  sess.run(tf.global_variables_initializer()) 
  model.loadModel(sess) #加载模型 
 
  for i, img in enumerate(testImg): 
    #img preprocess 
    test = cv2.resize(img.astype(np.float), (227, 227)) #resize成网络输入大小 
    test -= imgMean #去均值 
    test = test.reshape((1, 227, 227, 3)) #拉成tensor 
    maxx = np.argmax(sess.run(softmax, feed_dict = {x: test})) 
    res = caffe_classes.class_names[maxx] #取概率最大类的下标 
    #print(res) 
    font = cv2.FONT_HERSHEY_SIMPLEX 
    cv2.putText(img, res, (int(img.shape[0]/3), int(img.shape[1]/3)), font, 1, (0, 255, 0), 2)#绘制类的名字 
    cv2.imshow("demo", img)  
    cv2.waitKey(5000) #显示5秒

如上代码所示,首先需要设置一些参数,然后读取指定路径下的测试图像,再对模型做一个初始化,最后是真正测试代码。测试结果如下:

使用tensorflow实现AlexNet

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简单介绍Python中的len()函数的使用
Apr 07 Python
Python批量按比例缩小图片脚本分享
May 21 Python
深入理解NumPy简明教程---数组1
Dec 17 Python
pip安装Python库时遇到的问题及解决方法
Nov 23 Python
pandas使用apply多列生成一列数据的实例
Nov 28 Python
python交易记录链的实现过程详解
Jul 03 Python
基于python修改srt字幕的时间轴
Feb 03 Python
Python语法垃圾回收机制原理解析
Mar 25 Python
Python实现Word表格转成Excel表格的示例代码
Apr 16 Python
基于python和flask实现http接口过程解析
Jun 15 Python
Python实现PS滤镜中的USM锐化效果
Dec 04 Python
python如何查找列表中元素的位置
May 30 Python
Django在win10下的安装并创建工程
Nov 20 #Python
Python2与python3中 for 循环语句基础与实例分析
Nov 20 #Python
Python3中类、模块、错误与异常、文件的简易教程
Nov 20 #Python
Python实现将HTML转换成doc格式文件的方法示例
Nov 20 #Python
python中学习K-Means和图片压缩
Nov 20 #Python
深入理解Python中的super()方法
Nov 20 #Python
python实现读取excel写入mysql的小工具详解
Nov 20 #Python
You might like
PHP概述.
2006/10/09 PHP
PHP中的use关键字概述
2014/07/23 PHP
php获得文件大小和文件创建时间的方法
2015/03/13 PHP
浅析Laravel5中队列的配置及使用
2016/08/04 PHP
PHP实现15位身份证号转18位的方法分析
2019/10/16 PHP
document.compatMode介绍
2009/05/21 Javascript
Jquery异步请求数据实例代码
2011/12/28 Javascript
javascript中验证大写字母、数字和中文
2014/01/15 Javascript
js和css写一个可以自动隐藏的悬浮框
2014/03/05 Javascript
JS表格组件神器bootstrap table详解(基础版)
2015/12/08 Javascript
JS点击缩略图整屏居中放大图片效果
2017/07/04 Javascript
JS中关于正则的巧妙操作
2017/08/31 Javascript
微信小程序实现轮播图效果
2017/09/07 Javascript
微信小程序如何获取openid及用户信息
2018/01/26 Javascript
jQuery实现导航样式布局操作示例【可自定义样式布局】
2018/07/24 jQuery
解决vue接口数据赋值给data没有反应的问题
2018/08/27 Javascript
Vue CLI3 开启gzip压缩文件的方式
2018/09/30 Javascript
详解无限滚动插件vue-infinite-scroll源码解析
2019/05/12 Javascript
nest.js 使用express需要提供多个静态目录的操作方法
2019/10/24 Javascript
js实现登录时记住密码的方法分析
2020/04/05 Javascript
javascript 数组(list)添加/删除的实现
2020/12/17 Javascript
[01:01:25]DOTA2上海特级锦标赛B组资格赛#2 Fnatic VS Spirit第三局
2016/02/27 DOTA
[01:18:35]DOTA2-DPC中国联赛 正赛 Elephant vs LBZS BO3 第一场 1月29日
2021/03/11 DOTA
Python中函数的参数传递与可变长参数介绍
2015/06/30 Python
TensorFlow查看输入节点和输出节点名称方式
2020/01/04 Python
详解Python中import机制
2020/09/11 Python
css3实现椭圆轨迹旋转的示例代码
2018/10/29 HTML / CSS
美国最大的万圣节服装网站:HalloweenCostumes.com
2017/10/12 全球购物
播音主持专业个人自我评价
2014/01/09 职场文书
农民工创业典型事迹
2014/01/25 职场文书
预备党员入党自我评价范文
2014/03/10 职场文书
幼儿教师师德师风演讲稿
2014/08/22 职场文书
学习党的群众路线剖析材料
2014/10/09 职场文书
幼儿园教师工作总结2015
2015/04/02 职场文书
安全教育主题班会总结
2015/08/14 职场文书
《酸的和甜的》教学反思
2016/02/18 职场文书