使用tensorflow实现VGG网络,训练mnist数据集方式


Posted in Python onMay 26, 2020

VGG作为流行的几个模型之一,训练图形数据效果不错,在mnist数据集是常用的入门集数据,VGG层数非常多,如果严格按照规范来实现,并用来训练mnist数据集,会出现各种问题,如,经过16层卷积后,28*28*1的图片几乎无法进行。

先介绍下VGG

ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman实现的卷积神经网络,现在称其为VGGNet。它主要的贡献是展示出网络的深度是算法优良性能的关键部分。

他们最好的网络包含了16个卷积/全连接层。网络的结构非常一致,从头到尾全部使用的是3x3的卷积和2x2的汇聚。他们的预训练模型是可以在网络上获得并在Caffe中使用的。

VGGNet不好的一点是它耗费更多计算资源,并且使用了更多的参数,导致更多的内存占用(140M)。其中绝大多数的参数都是来自于第一个全连接层。

模型结构:

使用tensorflow实现VGG网络,训练mnist数据集方式

本文在实现时候,尽量保存VGG原来模型结构,核心代码如下:

weights ={
  'wc1':tf.Variable(tf.random_normal([3,3,1,64])),
  'wc2':tf.Variable(tf.random_normal([3,3,64,64])),
  'wc3':tf.Variable(tf.random_normal([3,3,64,128])),
  'wc4':tf.Variable(tf.random_normal([3,3,128,128])),
  
  'wc5':tf.Variable(tf.random_normal([3,3,128,256])),
  'wc6':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc7':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc8':tf.Variable(tf.random_normal([3,3,256,256])),
  
  'wc9':tf.Variable(tf.random_normal([3,3,256,512])),
  'wc10':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc11':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc12':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc13':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc14':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc15':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc16':tf.Variable(tf.random_normal([3,3,512,256])),
  
  'wd1':tf.Variable(tf.random_normal([4096,4096])),
  'wd2':tf.Variable(tf.random_normal([4096,4096])),
  'out':tf.Variable(tf.random_normal([4096,nn_classes])),
}
 
biases ={
  'bc1':tf.Variable(tf.zeros([64])),
  'bc2':tf.Variable(tf.zeros([64])),
  'bc3':tf.Variable(tf.zeros([128])),
  'bc4':tf.Variable(tf.zeros([128])),
  'bc5':tf.Variable(tf.zeros([256])),
  'bc6':tf.Variable(tf.zeros([256])),
  'bc7':tf.Variable(tf.zeros([256])),
  'bc8':tf.Variable(tf.zeros([256])),
  'bc9':tf.Variable(tf.zeros([512])),
  'bc10':tf.Variable(tf.zeros([512])),
  'bc11':tf.Variable(tf.zeros([512])),
  'bc12':tf.Variable(tf.zeros([512])),
  'bc13':tf.Variable(tf.zeros([512])),
  'bc14':tf.Variable(tf.zeros([512])),
  'bc15':tf.Variable(tf.zeros([512])),
  'bc16':tf.Variable(tf.zeros([256])),
  
  
  'bd1':tf.Variable(tf.zeros([4096])),
  'bd2':tf.Variable(tf.zeros([4096])),
  'out':tf.Variable(tf.zeros([nn_classes])),
}

卷积实现:

def convLevel(i,input,type):
  num = i
  out = conv2D('conv'+str(num),input,weights['wc'+str(num)],biases['bc'+str(num)])
  if type=='p':
    out = maxPool2D('pool'+str(num),out, k=2) 
    out = norm('norm'+str(num),out, lsize=4)
  return out 
 
def VGG(x,weights,biases,dropout):
  x = tf.reshape(x,shape=[-1,28,28,1])
 
  input = x
 
  for i in range(16):
    i += 1
    if(i==2) or (i==4) or (i==12) : # 根据模型定义还需要更多的POOL化,但mnist图片大小不允许。
      input = convLevel(i,input,'p')
    else:
      input = convLevel(i,input,'c')

训练:

pred = VGG(x, weights, biases, keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy_ = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
 
init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  step = 1
  while step*batch_size < train_iters:
    batch_x,batch_y = mnist.train.next_batch(batch_size)
    sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
    print(step*batch_size)
    if step % display_step == 0 :
      #loss,acc = sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob=1.0})
      acc = sess.run(accuracy_, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      # 计算损失值
      
      loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      print("iter: "+str(step*batch_size)+"mini batch Loss="+"{:.6f}".format(loss)+",acc="+"{:6f}".format(acc))
 
    step += 1 
   
  print("training end!")

最终效果:

训练10000次后:结果如下:

iter: 12288 mini batch Loss=5088409.500000,acc=0.578125

iter: 12800 mini batch Loss=4514274.000000,acc=0.601562

iter: 13312 mini batch Loss=4483454.500000,acc=0.648438

这种深度的模型可以考虑循环10万次以上。目前效果还不错,本人没有GPU,心痛笔记本的CPU,100%的CPU利用率,听到风扇响就不忍心再训练,本文也借鉴了alex网络实现,当然我也实现了这个网络模型。在MNIST数据上,ALEX由于层数较少,收敛更快,当然MNIST,用CNN足够了。

以上这篇使用tensorflow实现VGG网络,训练mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在主机商的共享服务器上部署Django站点的方法
Jul 22 Python
Python中在脚本中引用其他文件函数的实现方法
Jun 23 Python
Python中运算符&quot;==&quot;和&quot;is&quot;的详解
Oct 08 Python
轻松掌握python设计模式之策略模式
Nov 18 Python
python opencv实现旋转矩形框裁减功能
Jul 25 Python
python绘制中国大陆人口热力图
Nov 07 Python
图文详解python安装Scrapy框架步骤
May 20 Python
python向字符串中添加元素的实例方法
Jun 28 Python
详解如何减少python内存的消耗
Aug 09 Python
解决pyqt5异常退出无提示信息的问题
Apr 08 Python
Python分类测试代码实例汇总
Jul 23 Python
python实现一个简单RPC框架的示例
Oct 28 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 #Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
You might like
蝙蝠侠:侠影之谜
2020/03/04 欧美动漫
在Windows版的PHP中使用ADO
2006/10/09 PHP
php实现根据url自动生成缩略图的方法
2014/09/23 PHP
php中$_POST与php://input的区别实例分析
2015/01/07 PHP
带你了解PHP7 性能翻倍的关键
2015/11/19 PHP
js中将多个语句写成一个语句的两种方法小结
2007/12/08 Javascript
一个小型js框架myJSFrame附API使用帮助
2008/06/28 Javascript
js判断变量是否空值的代码
2008/10/26 Javascript
a标签的href与onclick事件的区别详解
2014/11/12 Javascript
JavaScript框架是什么?怎样才能叫做框架?
2015/07/01 Javascript
bootstrap中的 form表单属性role=&quot;form&quot;的作用详解
2017/01/20 Javascript
在javascript中,null>=0 为真,null==0却为假,null的值详解
2017/02/22 Javascript
利用Angular+Angular-Ui实现分页(代码加简单)
2017/03/10 Javascript
原生js实现验证码功能
2017/03/16 Javascript
Bootstrap table学习笔记(2) 前后端分页模糊查询
2017/05/18 Javascript
Ionic2开发环境搭建教程
2020/08/20 Javascript
解决linux下node.js全局模块找不到的问题
2018/05/15 Javascript
对vue中v-on绑定自定事件的实例讲解
2018/09/06 Javascript
详解react-refetch的使用小例子
2019/02/15 Javascript
Postman环境变量全局变量使用方法详解
2020/08/13 Javascript
python与caffe改变通道顺序的方法
2018/08/04 Python
Django 实现admin后台显示图片缩略图的例子
2019/07/28 Python
python 使用while循环输出*组成的菱形实例
2020/04/12 Python
python tkinter实现连连看游戏
2020/11/16 Python
python爬取youtube视频的示例代码
2021/03/03 Python
详解HTML5中的manifest缓存使用
2015/09/09 HTML / CSS
英国在线药房和在线医生:LloydsPharmacy
2019/10/21 全球购物
商务会议邀请函
2014/01/09 职场文书
二年级体育教学反思
2014/01/15 职场文书
医生进修自我鉴定
2014/01/19 职场文书
品牌转让协议书
2014/08/20 职场文书
2014年仓库工作总结
2014/11/20 职场文书
撤诉申请书法院范本
2015/05/18 职场文书
2015年教务主任工作总结
2015/07/22 职场文书
Python趣味挑战之教你用pygame画进度条
2021/05/31 Python
nginx静态资源的服务器配置方法
2022/07/07 Servers