使用tensorflow实现AlexNet


Posted in Python onNovember 20, 2017

AlexNet是2012年ImageNet比赛的冠军,虽然过去了很长时间,但是作为深度学习中的经典模型,AlexNet不但有助于我们理解其中所使用的很多技巧,而且非常有助于提升我们使用深度学习工具箱的熟练度。尤其是我刚入门深度学习,迫切需要一个能让自己熟悉tensorflow的小练习,于是就有了这个小玩意儿......

先放上我的代码:https://github.com/hjptriplebee/AlexNet_with_tensorflow

如果想运行代码,详细的配置要求都在上面链接的readme文件中了。本文建立在一定的tensorflow基础上,不会对太细的点进行说明。

模型结构

使用tensorflow实现AlexNet

关于模型结构网上的文献很多,我这里不赘述,一会儿都在代码里解释。

有一点需要注意,AlexNet将网络分成了上下两个部分,在论文中两部分结构完全相同,唯一不同的是他们放在不同GPU上训练,因为每一层的feature map之间都是独立的(除了全连接层),所以这相当于是提升训练速度的一种方法。很多AlexNet的复现都将上下两部分合并了,因为他们都是在单个GPU上运行的。虽然我也是在单个GPU上运行,但是我还是很想将最原始的网络结构还原出来,所以我的代码里也是分开的。

模型定义

def maxPoolLayer(x, kHeight, kWidth, strideX, strideY, name, padding = "SAME"): 
  """max-pooling""" 
  return tf.nn.max_pool(x, ksize = [1, kHeight, kWidth, 1], 
             strides = [1, strideX, strideY, 1], padding = padding, name = name) 
 
def dropout(x, keepPro, name = None): 
  """dropout""" 
  return tf.nn.dropout(x, keepPro, name) 
 
def LRN(x, R, alpha, beta, name = None, bias = 1.0): 
  """LRN""" 
  return tf.nn.local_response_normalization(x, depth_radius = R, alpha = alpha, 
                       beta = beta, bias = bias, name = name) 
 
def fcLayer(x, inputD, outputD, reluFlag, name): 
  """fully-connect""" 
  with tf.variable_scope(name) as scope: 
    w = tf.get_variable("w", shape = [inputD, outputD], dtype = "float") 
    b = tf.get_variable("b", [outputD], dtype = "float") 
    out = tf.nn.xw_plus_b(x, w, b, name = scope.name) 
    if reluFlag: 
      return tf.nn.relu(out) 
    else: 
      return out 
 
def convLayer(x, kHeight, kWidth, strideX, strideY, 
       featureNum, name, padding = "SAME", groups = 1):#group为2时等于AlexNet中分上下两部分 
  """convlutional""" 
  channel = int(x.get_shape()[-1])#获取channel 
  conv = lambda a, b: tf.nn.conv2d(a, b, strides = [1, strideY, strideX, 1], padding = padding)#定义卷积的匿名函数 
  with tf.variable_scope(name) as scope: 
    w = tf.get_variable("w", shape = [kHeight, kWidth, channel/groups, featureNum]) 
    b = tf.get_variable("b", shape = [featureNum]) 
 
    xNew = tf.split(value = x, num_or_size_splits = groups, axis = 3)#划分后的输入和权重 
    wNew = tf.split(value = w, num_or_size_splits = groups, axis = 3) 
 
    featureMap = [conv(t1, t2) for t1, t2 in zip(xNew, wNew)] #分别提取feature map 
    mergeFeatureMap = tf.concat(axis = 3, values = featureMap) #feature map整合 
    # print mergeFeatureMap.shape 
    out = tf.nn.bias_add(mergeFeatureMap, b) 
    return tf.nn.relu(tf.reshape(out, mergeFeatureMap.get_shape().as_list()), name = scope.name) #relu后的结果

定义了卷积、pooling、LRN、dropout、全连接五个模块,其中卷积模块因为将网络的上下两部分分开了,所以比较复杂。接下来定义AlexNet。

class alexNet(object): 
  """alexNet model""" 
  def __init__(self, x, keepPro, classNum, skip, modelPath = "bvlc_alexnet.npy"): 
    self.X = x 
    self.KEEPPRO = keepPro 
    self.CLASSNUM = classNum 
    self.SKIP = skip 
    self.MODELPATH = modelPath 
    #build CNN 
    self.buildCNN() 
 
  def buildCNN(self): 
    """build model""" 
    conv1 = convLayer(self.X, 11, 11, 4, 4, 96, "conv1", "VALID") 
    pool1 = maxPoolLayer(conv1, 3, 3, 2, 2, "pool1", "VALID") 
    lrn1 = LRN(pool1, 2, 2e-05, 0.75, "norm1") 
 
    conv2 = convLayer(lrn1, 5, 5, 1, 1, 256, "conv2", groups = 2) 
    pool2 = maxPoolLayer(conv2, 3, 3, 2, 2, "pool2", "VALID") 
    lrn2 = LRN(pool2, 2, 2e-05, 0.75, "lrn2") 
 
    conv3 = convLayer(lrn2, 3, 3, 1, 1, 384, "conv3") 
 
    conv4 = convLayer(conv3, 3, 3, 1, 1, 384, "conv4", groups = 2) 
 
    conv5 = convLayer(conv4, 3, 3, 1, 1, 256, "conv5", groups = 2) 
    pool5 = maxPoolLayer(conv5, 3, 3, 2, 2, "pool5", "VALID") 
 
    fcIn = tf.reshape(pool5, [-1, 256 * 6 * 6]) 
    fc1 = fcLayer(fcIn, 256 * 6 * 6, 4096, True, "fc6") 
    dropout1 = dropout(fc1, self.KEEPPRO) 
 
    fc2 = fcLayer(dropout1, 4096, 4096, True, "fc7") 
    dropout2 = dropout(fc2, self.KEEPPRO) 
 
    self.fc3 = fcLayer(dropout2, 4096, self.CLASSNUM, True, "fc8") 
 
  def loadModel(self, sess): 
    """load model""" 
    wDict = np.load(self.MODELPATH, encoding = "bytes").item() 
    #for layers in model 
    for name in wDict: 
      if name not in self.SKIP: 
        with tf.variable_scope(name, reuse = True): 
          for p in wDict[name]: 
            if len(p.shape) == 1:  
              #bias 只有一维 
              sess.run(tf.get_variable('b', trainable = False).assign(p)) 
            else: 
              #weights  
              sess.run(tf.get_variable('w', trainable = False).assign(p))

buildCNN函数完全按照alexnet的结构搭建网络。
loadModel函数从模型文件中读取参数,采用的模型文件见github上的readme说明。
至此,我们定义了完整的模型,下面开始测试模型。

模型测试

ImageNet训练的AlexNet有很多类,几乎包含所有常见的物体,因此我们随便从网上找几张图片测试。比如我直接用了之前做项目的渣土车图片:

使用tensorflow实现AlexNet

然后编写测试代码:

#some params 
dropoutPro = 1 
classNum = 1000 
skip = [] 
#get testImage 
testPath = "testModel" 
testImg = [] 
for f in os.listdir(testPath): 
  testImg.append(cv2.imread(testPath + "/" + f)) 
 
imgMean = np.array([104, 117, 124], np.float) 
x = tf.placeholder("float", [1, 227, 227, 3]) 
 
model = alexnet.alexNet(x, dropoutPro, classNum, skip) 
score = model.fc3 
softmax = tf.nn.softmax(score) 
 
with tf.Session() as sess: 
  sess.run(tf.global_variables_initializer()) 
  model.loadModel(sess) #加载模型 
 
  for i, img in enumerate(testImg): 
    #img preprocess 
    test = cv2.resize(img.astype(np.float), (227, 227)) #resize成网络输入大小 
    test -= imgMean #去均值 
    test = test.reshape((1, 227, 227, 3)) #拉成tensor 
    maxx = np.argmax(sess.run(softmax, feed_dict = {x: test})) 
    res = caffe_classes.class_names[maxx] #取概率最大类的下标 
    #print(res) 
    font = cv2.FONT_HERSHEY_SIMPLEX 
    cv2.putText(img, res, (int(img.shape[0]/3), int(img.shape[1]/3)), font, 1, (0, 255, 0), 2)#绘制类的名字 
    cv2.imshow("demo", img)  
    cv2.waitKey(5000) #显示5秒

如上代码所示,首先需要设置一些参数,然后读取指定路径下的测试图像,再对模型做一个初始化,最后是真正测试代码。测试结果如下:

使用tensorflow实现AlexNet

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python网络编程学习笔记(八):XML生成与解析(DOM、ElementTree)
Jun 09 Python
Python轻量级ORM框架Peewee访问sqlite数据库的方法详解
Jul 20 Python
Python 查看文件的编码格式方法
Dec 21 Python
python使用xpath中遇到:到底是什么?
Jan 04 Python
Python使用pyh生成HTML文档的方法示例
Mar 10 Python
一些Centos Python 生产环境的部署命令(推荐)
May 07 Python
详解如何用django实现redirect的几种方法总结
Nov 22 Python
详解Python3注释知识点
Feb 19 Python
在Sublime Editor中配置Python环境的详细教程
May 03 Python
Win10下配置tensorflow-gpu的详细教程(无VS2015/2017)
Jul 14 Python
Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
Mar 03 Python
python基础之爬虫入门
May 10 Python
Django在win10下的安装并创建工程
Nov 20 #Python
Python2与python3中 for 循环语句基础与实例分析
Nov 20 #Python
Python3中类、模块、错误与异常、文件的简易教程
Nov 20 #Python
Python实现将HTML转换成doc格式文件的方法示例
Nov 20 #Python
python中学习K-Means和图片压缩
Nov 20 #Python
深入理解Python中的super()方法
Nov 20 #Python
python实现读取excel写入mysql的小工具详解
Nov 20 #Python
You might like
PHP stream_context_create()作用和用法分析
2011/03/29 PHP
改写函数实现PHP二维/三维数组转字符串
2013/09/13 PHP
PHP不用递归实现无限分级的例子分享
2014/04/18 PHP
php实现中文转数字
2016/02/18 PHP
PHP编辑器PhpStrom运行缓慢问题
2017/02/21 PHP
PHPExcel实现的读取多工作表操作示例
2020/04/14 PHP
PJ Blog修改-禁止复制的代码和方法
2006/10/25 Javascript
js 操作符实例代码
2009/10/24 Javascript
元素未显示设置width/height时IE中使用currentStyle获取为auto
2014/05/04 Javascript
使用jquery清空、复位整个输入域
2015/04/02 Javascript
JavaScript汉诺塔问题解决方法
2015/04/21 Javascript
jQuery封装的tab选项卡插件分享
2015/06/16 Javascript
JS随机调用指定函数的方法
2015/07/01 Javascript
jQuery简单实现仿京东商城的左侧菜单效果代码
2015/09/09 Javascript
jQuery中实现prop()函数控制多选框(全选,反选)
2016/08/19 Javascript
Bootstrap导航条的使用和理解3
2016/12/14 Javascript
bootstrap提示标签、提示框实现代码
2016/12/28 Javascript
使用vue-cli编写vue插件的方法
2018/02/26 Javascript
微信小程序url传参写变量的方法
2018/08/09 Javascript
在Vue中用canvas实现二维码和图片合成海报的方法
2019/06/10 Javascript
jQuery - AJAX load() 实例用法详解
2019/08/27 jQuery
layui实现显示数据表格、搜索和修改功能示例
2020/06/03 Javascript
我所理解的JavaScript中的this指向
2020/09/04 Javascript
[01:30]2016国际邀请赛中国区预选赛神秘商店火爆开启
2016/06/26 DOTA
[01:05:41]EG vs Optic Supermajor 败者组 BO3 第二场 6.6
2018/06/07 DOTA
Anaconda 离线安装 python 包的操作方法
2018/06/11 Python
python 实现将字典dict、列表list中的中文正常显示方法
2018/07/06 Python
浅谈Pycharm调用同级目录下的py脚本bug
2018/12/03 Python
python 定义类时,实现内部方法的互相调用
2019/12/25 Python
印度在线购物网站:Paytmmall
2019/07/24 全球购物
实习生求职自荐信
2014/02/07 职场文书
商务日语专业毕业生自荐信
2014/03/27 职场文书
电影开国大典观后感
2015/06/04 职场文书
PHP控制循环操作的时间
2021/04/01 PHP
JavaScript分页组件使用方法详解
2021/07/26 Javascript
Appium中scroll和drag_and_drop根据元素位置滑动
2022/02/15 Python