使用tensorflow实现AlexNet


Posted in Python onNovember 20, 2017

AlexNet是2012年ImageNet比赛的冠军,虽然过去了很长时间,但是作为深度学习中的经典模型,AlexNet不但有助于我们理解其中所使用的很多技巧,而且非常有助于提升我们使用深度学习工具箱的熟练度。尤其是我刚入门深度学习,迫切需要一个能让自己熟悉tensorflow的小练习,于是就有了这个小玩意儿......

先放上我的代码:https://github.com/hjptriplebee/AlexNet_with_tensorflow

如果想运行代码,详细的配置要求都在上面链接的readme文件中了。本文建立在一定的tensorflow基础上,不会对太细的点进行说明。

模型结构

使用tensorflow实现AlexNet

关于模型结构网上的文献很多,我这里不赘述,一会儿都在代码里解释。

有一点需要注意,AlexNet将网络分成了上下两个部分,在论文中两部分结构完全相同,唯一不同的是他们放在不同GPU上训练,因为每一层的feature map之间都是独立的(除了全连接层),所以这相当于是提升训练速度的一种方法。很多AlexNet的复现都将上下两部分合并了,因为他们都是在单个GPU上运行的。虽然我也是在单个GPU上运行,但是我还是很想将最原始的网络结构还原出来,所以我的代码里也是分开的。

模型定义

def maxPoolLayer(x, kHeight, kWidth, strideX, strideY, name, padding = "SAME"): 
  """max-pooling""" 
  return tf.nn.max_pool(x, ksize = [1, kHeight, kWidth, 1], 
             strides = [1, strideX, strideY, 1], padding = padding, name = name) 
 
def dropout(x, keepPro, name = None): 
  """dropout""" 
  return tf.nn.dropout(x, keepPro, name) 
 
def LRN(x, R, alpha, beta, name = None, bias = 1.0): 
  """LRN""" 
  return tf.nn.local_response_normalization(x, depth_radius = R, alpha = alpha, 
                       beta = beta, bias = bias, name = name) 
 
def fcLayer(x, inputD, outputD, reluFlag, name): 
  """fully-connect""" 
  with tf.variable_scope(name) as scope: 
    w = tf.get_variable("w", shape = [inputD, outputD], dtype = "float") 
    b = tf.get_variable("b", [outputD], dtype = "float") 
    out = tf.nn.xw_plus_b(x, w, b, name = scope.name) 
    if reluFlag: 
      return tf.nn.relu(out) 
    else: 
      return out 
 
def convLayer(x, kHeight, kWidth, strideX, strideY, 
       featureNum, name, padding = "SAME", groups = 1):#group为2时等于AlexNet中分上下两部分 
  """convlutional""" 
  channel = int(x.get_shape()[-1])#获取channel 
  conv = lambda a, b: tf.nn.conv2d(a, b, strides = [1, strideY, strideX, 1], padding = padding)#定义卷积的匿名函数 
  with tf.variable_scope(name) as scope: 
    w = tf.get_variable("w", shape = [kHeight, kWidth, channel/groups, featureNum]) 
    b = tf.get_variable("b", shape = [featureNum]) 
 
    xNew = tf.split(value = x, num_or_size_splits = groups, axis = 3)#划分后的输入和权重 
    wNew = tf.split(value = w, num_or_size_splits = groups, axis = 3) 
 
    featureMap = [conv(t1, t2) for t1, t2 in zip(xNew, wNew)] #分别提取feature map 
    mergeFeatureMap = tf.concat(axis = 3, values = featureMap) #feature map整合 
    # print mergeFeatureMap.shape 
    out = tf.nn.bias_add(mergeFeatureMap, b) 
    return tf.nn.relu(tf.reshape(out, mergeFeatureMap.get_shape().as_list()), name = scope.name) #relu后的结果

定义了卷积、pooling、LRN、dropout、全连接五个模块,其中卷积模块因为将网络的上下两部分分开了,所以比较复杂。接下来定义AlexNet。

class alexNet(object): 
  """alexNet model""" 
  def __init__(self, x, keepPro, classNum, skip, modelPath = "bvlc_alexnet.npy"): 
    self.X = x 
    self.KEEPPRO = keepPro 
    self.CLASSNUM = classNum 
    self.SKIP = skip 
    self.MODELPATH = modelPath 
    #build CNN 
    self.buildCNN() 
 
  def buildCNN(self): 
    """build model""" 
    conv1 = convLayer(self.X, 11, 11, 4, 4, 96, "conv1", "VALID") 
    pool1 = maxPoolLayer(conv1, 3, 3, 2, 2, "pool1", "VALID") 
    lrn1 = LRN(pool1, 2, 2e-05, 0.75, "norm1") 
 
    conv2 = convLayer(lrn1, 5, 5, 1, 1, 256, "conv2", groups = 2) 
    pool2 = maxPoolLayer(conv2, 3, 3, 2, 2, "pool2", "VALID") 
    lrn2 = LRN(pool2, 2, 2e-05, 0.75, "lrn2") 
 
    conv3 = convLayer(lrn2, 3, 3, 1, 1, 384, "conv3") 
 
    conv4 = convLayer(conv3, 3, 3, 1, 1, 384, "conv4", groups = 2) 
 
    conv5 = convLayer(conv4, 3, 3, 1, 1, 256, "conv5", groups = 2) 
    pool5 = maxPoolLayer(conv5, 3, 3, 2, 2, "pool5", "VALID") 
 
    fcIn = tf.reshape(pool5, [-1, 256 * 6 * 6]) 
    fc1 = fcLayer(fcIn, 256 * 6 * 6, 4096, True, "fc6") 
    dropout1 = dropout(fc1, self.KEEPPRO) 
 
    fc2 = fcLayer(dropout1, 4096, 4096, True, "fc7") 
    dropout2 = dropout(fc2, self.KEEPPRO) 
 
    self.fc3 = fcLayer(dropout2, 4096, self.CLASSNUM, True, "fc8") 
 
  def loadModel(self, sess): 
    """load model""" 
    wDict = np.load(self.MODELPATH, encoding = "bytes").item() 
    #for layers in model 
    for name in wDict: 
      if name not in self.SKIP: 
        with tf.variable_scope(name, reuse = True): 
          for p in wDict[name]: 
            if len(p.shape) == 1:  
              #bias 只有一维 
              sess.run(tf.get_variable('b', trainable = False).assign(p)) 
            else: 
              #weights  
              sess.run(tf.get_variable('w', trainable = False).assign(p))

buildCNN函数完全按照alexnet的结构搭建网络。
loadModel函数从模型文件中读取参数,采用的模型文件见github上的readme说明。
至此,我们定义了完整的模型,下面开始测试模型。

模型测试

ImageNet训练的AlexNet有很多类,几乎包含所有常见的物体,因此我们随便从网上找几张图片测试。比如我直接用了之前做项目的渣土车图片:

使用tensorflow实现AlexNet

然后编写测试代码:

#some params 
dropoutPro = 1 
classNum = 1000 
skip = [] 
#get testImage 
testPath = "testModel" 
testImg = [] 
for f in os.listdir(testPath): 
  testImg.append(cv2.imread(testPath + "/" + f)) 
 
imgMean = np.array([104, 117, 124], np.float) 
x = tf.placeholder("float", [1, 227, 227, 3]) 
 
model = alexnet.alexNet(x, dropoutPro, classNum, skip) 
score = model.fc3 
softmax = tf.nn.softmax(score) 
 
with tf.Session() as sess: 
  sess.run(tf.global_variables_initializer()) 
  model.loadModel(sess) #加载模型 
 
  for i, img in enumerate(testImg): 
    #img preprocess 
    test = cv2.resize(img.astype(np.float), (227, 227)) #resize成网络输入大小 
    test -= imgMean #去均值 
    test = test.reshape((1, 227, 227, 3)) #拉成tensor 
    maxx = np.argmax(sess.run(softmax, feed_dict = {x: test})) 
    res = caffe_classes.class_names[maxx] #取概率最大类的下标 
    #print(res) 
    font = cv2.FONT_HERSHEY_SIMPLEX 
    cv2.putText(img, res, (int(img.shape[0]/3), int(img.shape[1]/3)), font, 1, (0, 255, 0), 2)#绘制类的名字 
    cv2.imshow("demo", img)  
    cv2.waitKey(5000) #显示5秒

如上代码所示,首先需要设置一些参数,然后读取指定路径下的测试图像,再对模型做一个初始化,最后是真正测试代码。测试结果如下:

使用tensorflow实现AlexNet

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
将图片文件嵌入到wxpython代码中的实现方法
Aug 11 Python
Python处理RSS、ATOM模块FEEDPARSER介绍
Feb 18 Python
简单总结Python中序列与字典的相同和不同之处
Jan 19 Python
python利用有道翻译实现"语言翻译器"的功能实例
Nov 14 Python
Python控制Firefox方法总结
Jun 03 Python
Python中asyncio模块的深入讲解
Jun 10 Python
详解Python可视化神器Yellowbrick使用
Nov 11 Python
基于python检查SSL证书到期情况代码实例
Apr 04 Python
tensorflow下的图片标准化函数per_image_standardization用法
Jun 30 Python
Numpy(Pandas)删除全为零的列的方法
Sep 11 Python
python scrapy简单模拟登录的代码分析
Jul 21 Python
Python可视化神器pyecharts之绘制地理图表练习
Jul 07 Python
Django在win10下的安装并创建工程
Nov 20 #Python
Python2与python3中 for 循环语句基础与实例分析
Nov 20 #Python
Python3中类、模块、错误与异常、文件的简易教程
Nov 20 #Python
Python实现将HTML转换成doc格式文件的方法示例
Nov 20 #Python
python中学习K-Means和图片压缩
Nov 20 #Python
深入理解Python中的super()方法
Nov 20 #Python
python实现读取excel写入mysql的小工具详解
Nov 20 #Python
You might like
深入解析PHP的引用计数机制
2013/06/14 PHP
PHP session_start()问题解疑(详细介绍)
2013/07/05 PHP
php递归函数中使用return的注意事项
2014/01/17 PHP
使用PHP函数scandir排除特定目录
2014/06/12 PHP
帝国CMS留言板回复后发送EMAIL通知客户
2015/07/06 PHP
PHP+shell脚本操作Memcached和Apache Status的实例分享
2016/03/11 PHP
CodeIgniter常用知识点小结
2016/05/26 PHP
关于PHP内置的字符串处理函数详解
2017/02/04 PHP
Laravel基础_关于view共享数据的示例讲解
2019/10/14 PHP
用Div仿showModalDialog模式菜单的效果的代码
2007/03/05 Javascript
JS正则中的RegExp对象对象
2012/11/07 Javascript
ajax页面无刷新 IE下遭遇Ajax缓存导致数据不更新的问题
2012/12/11 Javascript
利用js读取动态网站从服务器端返回的数据
2014/02/10 Javascript
Javascript加载速度慢的解决方案
2014/03/11 Javascript
JS下载文件|无刷新下载文件示例代码
2014/04/17 Javascript
js闭包所用的场合以及优缺点分析
2015/06/22 Javascript
用angular实现多选按钮的全选与反选实例代码
2017/05/23 Javascript
JavaScript实现带有子菜单和控件的slider轮播图效果
2017/11/01 Javascript
Vue+webpack+Element 兼容问题总结(小结)
2018/08/16 Javascript
浅析Vue 防抖与节流的使用
2019/11/14 Javascript
vue使用原生swiper代码实例
2020/02/05 Javascript
[42:34]VP vs VG 2018国际邀请赛小组赛BO2 第一场 8.19
2018/08/21 DOTA
使用python装饰器验证配置文件示例
2014/02/24 Python
python抓取网页中的图片示例
2014/02/28 Python
Python爬取成语接龙类网站
2018/10/19 Python
python通过实例讲解反射机制
2019/10/17 Python
通过实例学习Python Excel操作
2020/01/06 Python
在ipython notebook中使用argparse方式
2020/04/20 Python
Python存储读取HDF5文件代码解析
2020/11/25 Python
css3制作动态进度条以及附加jQuery百分比数字显示
2012/12/13 HTML / CSS
CSS3条纹背景制作的实战攻略
2016/05/31 HTML / CSS
Melissa鞋英国官方网站:Nonnon
2019/05/01 全球购物
服务生自我鉴定
2014/01/22 职场文书
竞选村长演讲稿
2014/04/28 职场文书
焦裕禄精神心得体会
2014/09/02 职场文书
Win11跳过联网界面创建本地管理账户的3种方法
2022/04/20 数码科技