使用tensorflow实现VGG网络,训练mnist数据集方式


Posted in Python onMay 26, 2020

VGG作为流行的几个模型之一,训练图形数据效果不错,在mnist数据集是常用的入门集数据,VGG层数非常多,如果严格按照规范来实现,并用来训练mnist数据集,会出现各种问题,如,经过16层卷积后,28*28*1的图片几乎无法进行。

先介绍下VGG

ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman实现的卷积神经网络,现在称其为VGGNet。它主要的贡献是展示出网络的深度是算法优良性能的关键部分。

他们最好的网络包含了16个卷积/全连接层。网络的结构非常一致,从头到尾全部使用的是3x3的卷积和2x2的汇聚。他们的预训练模型是可以在网络上获得并在Caffe中使用的。

VGGNet不好的一点是它耗费更多计算资源,并且使用了更多的参数,导致更多的内存占用(140M)。其中绝大多数的参数都是来自于第一个全连接层。

模型结构:

使用tensorflow实现VGG网络,训练mnist数据集方式

本文在实现时候,尽量保存VGG原来模型结构,核心代码如下:

weights ={
  'wc1':tf.Variable(tf.random_normal([3,3,1,64])),
  'wc2':tf.Variable(tf.random_normal([3,3,64,64])),
  'wc3':tf.Variable(tf.random_normal([3,3,64,128])),
  'wc4':tf.Variable(tf.random_normal([3,3,128,128])),
  
  'wc5':tf.Variable(tf.random_normal([3,3,128,256])),
  'wc6':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc7':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc8':tf.Variable(tf.random_normal([3,3,256,256])),
  
  'wc9':tf.Variable(tf.random_normal([3,3,256,512])),
  'wc10':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc11':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc12':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc13':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc14':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc15':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc16':tf.Variable(tf.random_normal([3,3,512,256])),
  
  'wd1':tf.Variable(tf.random_normal([4096,4096])),
  'wd2':tf.Variable(tf.random_normal([4096,4096])),
  'out':tf.Variable(tf.random_normal([4096,nn_classes])),
}
 
biases ={
  'bc1':tf.Variable(tf.zeros([64])),
  'bc2':tf.Variable(tf.zeros([64])),
  'bc3':tf.Variable(tf.zeros([128])),
  'bc4':tf.Variable(tf.zeros([128])),
  'bc5':tf.Variable(tf.zeros([256])),
  'bc6':tf.Variable(tf.zeros([256])),
  'bc7':tf.Variable(tf.zeros([256])),
  'bc8':tf.Variable(tf.zeros([256])),
  'bc9':tf.Variable(tf.zeros([512])),
  'bc10':tf.Variable(tf.zeros([512])),
  'bc11':tf.Variable(tf.zeros([512])),
  'bc12':tf.Variable(tf.zeros([512])),
  'bc13':tf.Variable(tf.zeros([512])),
  'bc14':tf.Variable(tf.zeros([512])),
  'bc15':tf.Variable(tf.zeros([512])),
  'bc16':tf.Variable(tf.zeros([256])),
  
  
  'bd1':tf.Variable(tf.zeros([4096])),
  'bd2':tf.Variable(tf.zeros([4096])),
  'out':tf.Variable(tf.zeros([nn_classes])),
}

卷积实现:

def convLevel(i,input,type):
  num = i
  out = conv2D('conv'+str(num),input,weights['wc'+str(num)],biases['bc'+str(num)])
  if type=='p':
    out = maxPool2D('pool'+str(num),out, k=2) 
    out = norm('norm'+str(num),out, lsize=4)
  return out 
 
def VGG(x,weights,biases,dropout):
  x = tf.reshape(x,shape=[-1,28,28,1])
 
  input = x
 
  for i in range(16):
    i += 1
    if(i==2) or (i==4) or (i==12) : # 根据模型定义还需要更多的POOL化,但mnist图片大小不允许。
      input = convLevel(i,input,'p')
    else:
      input = convLevel(i,input,'c')

训练:

pred = VGG(x, weights, biases, keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy_ = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
 
init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  step = 1
  while step*batch_size < train_iters:
    batch_x,batch_y = mnist.train.next_batch(batch_size)
    sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
    print(step*batch_size)
    if step % display_step == 0 :
      #loss,acc = sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob=1.0})
      acc = sess.run(accuracy_, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      # 计算损失值
      
      loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      print("iter: "+str(step*batch_size)+"mini batch Loss="+"{:.6f}".format(loss)+",acc="+"{:6f}".format(acc))
 
    step += 1 
   
  print("training end!")

最终效果:

训练10000次后:结果如下:

iter: 12288 mini batch Loss=5088409.500000,acc=0.578125

iter: 12800 mini batch Loss=4514274.000000,acc=0.601562

iter: 13312 mini batch Loss=4483454.500000,acc=0.648438

这种深度的模型可以考虑循环10万次以上。目前效果还不错,本人没有GPU,心痛笔记本的CPU,100%的CPU利用率,听到风扇响就不忍心再训练,本文也借鉴了alex网络实现,当然我也实现了这个网络模型。在MNIST数据上,ALEX由于层数较少,收敛更快,当然MNIST,用CNN足够了。

以上这篇使用tensorflow实现VGG网络,训练mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于python元祖与字典与集合的粗浅认识
Aug 23 Python
python实现泊松图像融合
Jul 26 Python
Centos下实现安装Python3.6和Python2共存
Aug 15 Python
Python实现压缩文件夹与解压缩zip文件的方法
Sep 01 Python
python中字符串内置函数的用法总结
Sep 13 Python
python web自制框架之接受url传递过来的参数实例
Dec 17 Python
python利用tkinter实现屏保
Jul 30 Python
使用python实现滑动验证码功能
Aug 05 Python
解决Python在导入文件时的FileNotFoundError问题
Apr 10 Python
Tensorflow安装问题: Could not find a version that satisfies the requirement tensorflow
Apr 20 Python
Python3.7安装pyaudio教程解析
Jul 24 Python
python基于tkinter制作无损音乐下载工具
Mar 29 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 #Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
You might like
网络资源
2006/10/09 PHP
用PHP实现维护文件代码
2007/06/14 PHP
php通过会话控制实现身份验证实例
2016/10/18 PHP
jQuery EasyUI API 中文文档 - NumberSpinner数值微调器使用介绍
2011/10/21 Javascript
jquery用get实现ajax在ie里面刷新不进入后台解决方法
2013/08/12 Javascript
js取消单选按钮选中示例代码
2013/11/14 Javascript
jQuery中$this和$(this)的区别介绍(一看就懂)
2015/07/06 Javascript
BootStrap初学者对弹出框和进度条的使用感觉
2016/06/27 Javascript
Dropzone.js实现文件拖拽上传功能(附源码下载)
2016/11/22 Javascript
原生js二级联动效果
2017/06/20 Javascript
浅谈Angular4中常用管道
2017/09/27 Javascript
小程序视频或音频自定义可拖拽进度条的示例代码
2018/09/30 Javascript
JavaScript实现简单贪吃蛇效果
2020/03/09 Javascript
JS call()及apply()方法使用实例汇总
2020/07/11 Javascript
js实现删除json中指定的元素
2020/09/22 Javascript
[04:12]第二届DOTA2亚洲邀请赛选手传记-Newbee.Sccc
2017/04/03 DOTA
Python中的特殊语法:filter、map、reduce、lambda介绍
2015/04/14 Python
在Python的Django框架下使用django-tagging的教程
2015/05/30 Python
Python文本相似性计算之编辑距离详解
2016/11/28 Python
Python实现字符串匹配算法代码示例
2017/12/05 Python
关于Django显示时间你应该知道的一些问题
2017/12/25 Python
深入浅析Python中的yield关键字
2018/01/24 Python
django解决跨域请求的问题
2018/11/11 Python
Python当中的array数组对象实例详解
2019/06/12 Python
HTML5注册页面示例代码
2014/03/27 HTML / CSS
HTML 5.1来了 9月份正式发布 更新内容预览
2016/04/26 HTML / CSS
英国最大的电脑零售连锁店集团:PC World
2016/10/10 全球购物
乌克兰珠宝大卖场:Zlato.ua
2020/09/27 全球购物
宣传活动总结范文
2014/07/01 职场文书
旅游与酒店管理专业求职信
2014/07/21 职场文书
八一建军节慰问信
2015/02/14 职场文书
焦点访谈观后感
2015/06/11 职场文书
2015年公路路政个人工作总结
2015/07/24 职场文书
HTML5简单实现添加背景音乐的几种方法
2021/05/12 HTML / CSS
Java实现多文件上传功能
2021/06/30 Java/Android
使用CSS实现六边形的图片效果
2022/08/05 HTML / CSS