使用tensorflow实现VGG网络,训练mnist数据集方式


Posted in Python onMay 26, 2020

VGG作为流行的几个模型之一,训练图形数据效果不错,在mnist数据集是常用的入门集数据,VGG层数非常多,如果严格按照规范来实现,并用来训练mnist数据集,会出现各种问题,如,经过16层卷积后,28*28*1的图片几乎无法进行。

先介绍下VGG

ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman实现的卷积神经网络,现在称其为VGGNet。它主要的贡献是展示出网络的深度是算法优良性能的关键部分。

他们最好的网络包含了16个卷积/全连接层。网络的结构非常一致,从头到尾全部使用的是3x3的卷积和2x2的汇聚。他们的预训练模型是可以在网络上获得并在Caffe中使用的。

VGGNet不好的一点是它耗费更多计算资源,并且使用了更多的参数,导致更多的内存占用(140M)。其中绝大多数的参数都是来自于第一个全连接层。

模型结构:

使用tensorflow实现VGG网络,训练mnist数据集方式

本文在实现时候,尽量保存VGG原来模型结构,核心代码如下:

weights ={
  'wc1':tf.Variable(tf.random_normal([3,3,1,64])),
  'wc2':tf.Variable(tf.random_normal([3,3,64,64])),
  'wc3':tf.Variable(tf.random_normal([3,3,64,128])),
  'wc4':tf.Variable(tf.random_normal([3,3,128,128])),
  
  'wc5':tf.Variable(tf.random_normal([3,3,128,256])),
  'wc6':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc7':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc8':tf.Variable(tf.random_normal([3,3,256,256])),
  
  'wc9':tf.Variable(tf.random_normal([3,3,256,512])),
  'wc10':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc11':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc12':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc13':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc14':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc15':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc16':tf.Variable(tf.random_normal([3,3,512,256])),
  
  'wd1':tf.Variable(tf.random_normal([4096,4096])),
  'wd2':tf.Variable(tf.random_normal([4096,4096])),
  'out':tf.Variable(tf.random_normal([4096,nn_classes])),
}
 
biases ={
  'bc1':tf.Variable(tf.zeros([64])),
  'bc2':tf.Variable(tf.zeros([64])),
  'bc3':tf.Variable(tf.zeros([128])),
  'bc4':tf.Variable(tf.zeros([128])),
  'bc5':tf.Variable(tf.zeros([256])),
  'bc6':tf.Variable(tf.zeros([256])),
  'bc7':tf.Variable(tf.zeros([256])),
  'bc8':tf.Variable(tf.zeros([256])),
  'bc9':tf.Variable(tf.zeros([512])),
  'bc10':tf.Variable(tf.zeros([512])),
  'bc11':tf.Variable(tf.zeros([512])),
  'bc12':tf.Variable(tf.zeros([512])),
  'bc13':tf.Variable(tf.zeros([512])),
  'bc14':tf.Variable(tf.zeros([512])),
  'bc15':tf.Variable(tf.zeros([512])),
  'bc16':tf.Variable(tf.zeros([256])),
  
  
  'bd1':tf.Variable(tf.zeros([4096])),
  'bd2':tf.Variable(tf.zeros([4096])),
  'out':tf.Variable(tf.zeros([nn_classes])),
}

卷积实现:

def convLevel(i,input,type):
  num = i
  out = conv2D('conv'+str(num),input,weights['wc'+str(num)],biases['bc'+str(num)])
  if type=='p':
    out = maxPool2D('pool'+str(num),out, k=2) 
    out = norm('norm'+str(num),out, lsize=4)
  return out 
 
def VGG(x,weights,biases,dropout):
  x = tf.reshape(x,shape=[-1,28,28,1])
 
  input = x
 
  for i in range(16):
    i += 1
    if(i==2) or (i==4) or (i==12) : # 根据模型定义还需要更多的POOL化,但mnist图片大小不允许。
      input = convLevel(i,input,'p')
    else:
      input = convLevel(i,input,'c')

训练:

pred = VGG(x, weights, biases, keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy_ = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
 
init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  step = 1
  while step*batch_size < train_iters:
    batch_x,batch_y = mnist.train.next_batch(batch_size)
    sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
    print(step*batch_size)
    if step % display_step == 0 :
      #loss,acc = sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob=1.0})
      acc = sess.run(accuracy_, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      # 计算损失值
      
      loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      print("iter: "+str(step*batch_size)+"mini batch Loss="+"{:.6f}".format(loss)+",acc="+"{:6f}".format(acc))
 
    step += 1 
   
  print("training end!")

最终效果:

训练10000次后:结果如下:

iter: 12288 mini batch Loss=5088409.500000,acc=0.578125

iter: 12800 mini batch Loss=4514274.000000,acc=0.601562

iter: 13312 mini batch Loss=4483454.500000,acc=0.648438

这种深度的模型可以考虑循环10万次以上。目前效果还不错,本人没有GPU,心痛笔记本的CPU,100%的CPU利用率,听到风扇响就不忍心再训练,本文也借鉴了alex网络实现,当然我也实现了这个网络模型。在MNIST数据上,ALEX由于层数较少,收敛更快,当然MNIST,用CNN足够了。

以上这篇使用tensorflow实现VGG网络,训练mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用心得之获得github代码库列表
Jun 25 Python
Python中获取网页状态码的两个方法
Nov 03 Python
Python读写zip压缩文件的方法
Aug 29 Python
python 字典有序并写入json文件过程解析
Sep 30 Python
Pandas 缺失数据处理的实现
Nov 04 Python
Python面向对象原理与基础语法详解
Jan 02 Python
python使用docx模块读写docx文件的方法与docx模块常用方法详解
Feb 17 Python
解决Keras自带数据集与预训练model下载太慢问题
Jun 12 Python
python用Tkinter做自己的中文代码编辑器
Sep 07 Python
python如何利用paramiko执行服务器命令
Nov 07 Python
pycharm 实现复制一行的快捷键
Jan 15 Python
pycharm配置安装autopep8自动规范代码的实现
Mar 02 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 #Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
You might like
第六章 php目录与文件操作
2011/12/30 PHP
PHP对象Object的概念 介绍
2012/06/14 PHP
php上传文件中文文件名乱码的解决方法
2013/11/01 PHP
浅析PHP微信支付通知的处理方式
2014/05/25 PHP
php判断是否连接上网络的方法实例详解
2016/12/14 PHP
php检测mysql表是否存在的方法小结
2017/07/20 PHP
CakePHP框架Model函数定义方法示例
2017/08/04 PHP
PHP PDO数据库操作预处理与注意事项
2019/03/16 PHP
js或css文件后面跟参数的原因说明
2010/01/09 Javascript
修复ie8&amp;chrome下window的resize事件多次执行
2011/10/20 Javascript
js导出table数据到excel即导出为EXCEL文档的方法
2013/10/10 Javascript
javascript 中的 delete及delete运算符
2015/11/15 Javascript
使用three.js 画渐变的直线
2016/06/05 Javascript
深入理解requestAnimationFrame的动画循环
2016/09/20 Javascript
jQuery实现圣诞节礼物传送(花式轮播)
2016/12/25 Javascript
bootstrap table 数据表格行内修改的实现代码
2017/02/13 Javascript
JS中mouseup事件丢失的原因与解决办法
2017/06/14 Javascript
十大 Node.js 的 Web 框架(快速提升工作效率)
2017/06/30 Javascript
jQuery Form插件使用详解_动力节点Java学院整理
2017/07/17 jQuery
vue.js提交按钮时进行简单的if判断表达式详解
2018/08/08 Javascript
vue 导航菜单刷新状态不消失,显示对应的路由界面操作
2020/08/06 Javascript
Python通过正则表达式选取callback的方法
2015/07/18 Python
python爬虫入门教程--正则表达式完全指南(五)
2017/05/25 Python
解决Python2.7中IDLE启动没有反应的问题
2018/11/30 Python
python3 selenium自动化测试 强大的CSS定位方法
2019/08/23 Python
np.dot()函数的用法详解
2020/01/17 Python
django模型类中,null=True,blank=True用法说明
2020/07/09 Python
Django自带的用户验证系统实现
2020/12/18 Python
Django解决frame拒绝问题的方法
2020/12/18 Python
学生个人求职自荐信格式
2013/09/23 职场文书
《阳光》教学反思
2014/02/23 职场文书
村班子对照检查材料
2014/08/18 职场文书
志愿者工作心得体会
2016/01/15 职场文书
《我是什么》教学反思
2016/02/16 职场文书
Python代码风格与编程习惯重要吗?
2021/06/03 Python
世界各国短波电台对东亚播送时间频率表(SW)
2021/06/28 无线电