使用tensorflow实现VGG网络,训练mnist数据集方式


Posted in Python onMay 26, 2020

VGG作为流行的几个模型之一,训练图形数据效果不错,在mnist数据集是常用的入门集数据,VGG层数非常多,如果严格按照规范来实现,并用来训练mnist数据集,会出现各种问题,如,经过16层卷积后,28*28*1的图片几乎无法进行。

先介绍下VGG

ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman实现的卷积神经网络,现在称其为VGGNet。它主要的贡献是展示出网络的深度是算法优良性能的关键部分。

他们最好的网络包含了16个卷积/全连接层。网络的结构非常一致,从头到尾全部使用的是3x3的卷积和2x2的汇聚。他们的预训练模型是可以在网络上获得并在Caffe中使用的。

VGGNet不好的一点是它耗费更多计算资源,并且使用了更多的参数,导致更多的内存占用(140M)。其中绝大多数的参数都是来自于第一个全连接层。

模型结构:

使用tensorflow实现VGG网络,训练mnist数据集方式

本文在实现时候,尽量保存VGG原来模型结构,核心代码如下:

weights ={
  'wc1':tf.Variable(tf.random_normal([3,3,1,64])),
  'wc2':tf.Variable(tf.random_normal([3,3,64,64])),
  'wc3':tf.Variable(tf.random_normal([3,3,64,128])),
  'wc4':tf.Variable(tf.random_normal([3,3,128,128])),
  
  'wc5':tf.Variable(tf.random_normal([3,3,128,256])),
  'wc6':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc7':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc8':tf.Variable(tf.random_normal([3,3,256,256])),
  
  'wc9':tf.Variable(tf.random_normal([3,3,256,512])),
  'wc10':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc11':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc12':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc13':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc14':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc15':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc16':tf.Variable(tf.random_normal([3,3,512,256])),
  
  'wd1':tf.Variable(tf.random_normal([4096,4096])),
  'wd2':tf.Variable(tf.random_normal([4096,4096])),
  'out':tf.Variable(tf.random_normal([4096,nn_classes])),
}
 
biases ={
  'bc1':tf.Variable(tf.zeros([64])),
  'bc2':tf.Variable(tf.zeros([64])),
  'bc3':tf.Variable(tf.zeros([128])),
  'bc4':tf.Variable(tf.zeros([128])),
  'bc5':tf.Variable(tf.zeros([256])),
  'bc6':tf.Variable(tf.zeros([256])),
  'bc7':tf.Variable(tf.zeros([256])),
  'bc8':tf.Variable(tf.zeros([256])),
  'bc9':tf.Variable(tf.zeros([512])),
  'bc10':tf.Variable(tf.zeros([512])),
  'bc11':tf.Variable(tf.zeros([512])),
  'bc12':tf.Variable(tf.zeros([512])),
  'bc13':tf.Variable(tf.zeros([512])),
  'bc14':tf.Variable(tf.zeros([512])),
  'bc15':tf.Variable(tf.zeros([512])),
  'bc16':tf.Variable(tf.zeros([256])),
  
  
  'bd1':tf.Variable(tf.zeros([4096])),
  'bd2':tf.Variable(tf.zeros([4096])),
  'out':tf.Variable(tf.zeros([nn_classes])),
}

卷积实现:

def convLevel(i,input,type):
  num = i
  out = conv2D('conv'+str(num),input,weights['wc'+str(num)],biases['bc'+str(num)])
  if type=='p':
    out = maxPool2D('pool'+str(num),out, k=2) 
    out = norm('norm'+str(num),out, lsize=4)
  return out 
 
def VGG(x,weights,biases,dropout):
  x = tf.reshape(x,shape=[-1,28,28,1])
 
  input = x
 
  for i in range(16):
    i += 1
    if(i==2) or (i==4) or (i==12) : # 根据模型定义还需要更多的POOL化,但mnist图片大小不允许。
      input = convLevel(i,input,'p')
    else:
      input = convLevel(i,input,'c')

训练:

pred = VGG(x, weights, biases, keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy_ = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
 
init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  step = 1
  while step*batch_size < train_iters:
    batch_x,batch_y = mnist.train.next_batch(batch_size)
    sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
    print(step*batch_size)
    if step % display_step == 0 :
      #loss,acc = sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob=1.0})
      acc = sess.run(accuracy_, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      # 计算损失值
      
      loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      print("iter: "+str(step*batch_size)+"mini batch Loss="+"{:.6f}".format(loss)+",acc="+"{:6f}".format(acc))
 
    step += 1 
   
  print("training end!")

最终效果:

训练10000次后:结果如下:

iter: 12288 mini batch Loss=5088409.500000,acc=0.578125

iter: 12800 mini batch Loss=4514274.000000,acc=0.601562

iter: 13312 mini batch Loss=4483454.500000,acc=0.648438

这种深度的模型可以考虑循环10万次以上。目前效果还不错,本人没有GPU,心痛笔记本的CPU,100%的CPU利用率,听到风扇响就不忍心再训练,本文也借鉴了alex网络实现,当然我也实现了这个网络模型。在MNIST数据上,ALEX由于层数较少,收敛更快,当然MNIST,用CNN足够了。

以上这篇使用tensorflow实现VGG网络,训练mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中pycurl库的用法实例
Sep 30 Python
浅析Python 中整型对象存储的位置
May 16 Python
Python爬取网易云音乐上评论火爆的歌曲
Jan 19 Python
Python解决N阶台阶走法问题的方法分析
Dec 28 Python
Python实现微信消息防撤回功能的实例代码
Apr 29 Python
python 一个figure上显示多个图像的实例
Jul 08 Python
对python中 math模块下 atan 和 atan2的区别详解
Jan 17 Python
详解基于Jupyter notebooks采用sklearn库实现多元回归方程编程
Mar 25 Python
pytorch 常用函数 max ,eq说明
Jun 28 Python
Python如何定义接口和抽象类
Jul 28 Python
python删除文件、清空目录的实现方法
Sep 23 Python
编写python代码实现简单抽奖器
Oct 20 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 #Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
You might like
解决file_get_contents无法请求https连接的方法
2013/12/17 PHP
PHP编程之设置apache虚拟目录
2016/07/08 PHP
PHP调用QQ互联接口实现QQ登录网站功能示例
2019/10/24 PHP
Flash对联广告的关闭按钮讨论
2007/01/30 Javascript
在js中单选框和复选框获取值的方式
2009/11/06 Javascript
根据IP的地址,区分不同的地区,查看不同的网站页面的js代码
2013/02/26 Javascript
js实现收缩菜单效果实例代码
2013/10/30 Javascript
js 上下左右键控制焦点(示例代码)
2013/12/14 Javascript
ie9 提示'console' 未定义问题的解决方法
2014/03/20 Javascript
jquery动态调整div大小使其宽度始终为浏览器宽度
2014/06/06 Javascript
Javascript原型链和原型的一个误区
2014/10/22 Javascript
仿JQuery输写高效JSLite代码的一些技巧
2015/01/13 Javascript
JavaScript 轮播图和自定义滚动条配合鼠标滚轮分享代码贴
2016/10/28 Javascript
JS验证字符串功能
2017/02/22 Javascript
js实现简单数字变动效果
2017/11/06 Javascript
浅谈JS和jQuery的区别
2019/03/27 jQuery
js函数柯里化的方法和作用实例分析
2020/04/11 Javascript
微信小程序组件生命周期的踩坑记录
2021/03/03 Javascript
实例讲解Python中SocketServer模块处理网络请求的用法
2016/06/28 Python
Python 12306抢火车票脚本 Python京东抢手机脚本
2018/02/06 Python
python3+PyQt5重新实现自定义数据拖放处理
2018/04/19 Python
pytorch中的上采样以及各种反操作,求逆操作详解
2020/01/03 Python
TensorFlow查看输入节点和输出节点名称方式
2020/01/04 Python
Django Admin设置应用程序及模型顺序方法详解
2020/04/01 Python
详解python爬取弹幕与数据分析
2020/11/14 Python
css3背景_动力节点Java学院整理
2017/07/11 HTML / CSS
CSS3打造百度贴吧的3D翻牌效果示例
2017/01/04 HTML / CSS
网页中的电话号码如何实现一键直呼效果_附示例
2016/03/15 HTML / CSS
教师自我剖析材料(群众路线)
2014/09/29 职场文书
2014年助理政工师工作总结
2014/12/19 职场文书
大学生自我评价范文
2015/03/03 职场文书
家长通知书家长意见
2015/06/03 职场文书
法律意见书范本
2015/06/04 职场文书
城南旧事读书笔记
2015/06/29 职场文书
Nginx同一个域名配置多个项目的实现方法
2021/03/31 Servers
Mysql数据库中datetime、bigint、timestamp来表示时间选择,谁来存储时间效率最高
2021/08/23 MySQL