使用tensorflow实现VGG网络,训练mnist数据集方式


Posted in Python onMay 26, 2020

VGG作为流行的几个模型之一,训练图形数据效果不错,在mnist数据集是常用的入门集数据,VGG层数非常多,如果严格按照规范来实现,并用来训练mnist数据集,会出现各种问题,如,经过16层卷积后,28*28*1的图片几乎无法进行。

先介绍下VGG

ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman实现的卷积神经网络,现在称其为VGGNet。它主要的贡献是展示出网络的深度是算法优良性能的关键部分。

他们最好的网络包含了16个卷积/全连接层。网络的结构非常一致,从头到尾全部使用的是3x3的卷积和2x2的汇聚。他们的预训练模型是可以在网络上获得并在Caffe中使用的。

VGGNet不好的一点是它耗费更多计算资源,并且使用了更多的参数,导致更多的内存占用(140M)。其中绝大多数的参数都是来自于第一个全连接层。

模型结构:

使用tensorflow实现VGG网络,训练mnist数据集方式

本文在实现时候,尽量保存VGG原来模型结构,核心代码如下:

weights ={
  'wc1':tf.Variable(tf.random_normal([3,3,1,64])),
  'wc2':tf.Variable(tf.random_normal([3,3,64,64])),
  'wc3':tf.Variable(tf.random_normal([3,3,64,128])),
  'wc4':tf.Variable(tf.random_normal([3,3,128,128])),
  
  'wc5':tf.Variable(tf.random_normal([3,3,128,256])),
  'wc6':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc7':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc8':tf.Variable(tf.random_normal([3,3,256,256])),
  
  'wc9':tf.Variable(tf.random_normal([3,3,256,512])),
  'wc10':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc11':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc12':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc13':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc14':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc15':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc16':tf.Variable(tf.random_normal([3,3,512,256])),
  
  'wd1':tf.Variable(tf.random_normal([4096,4096])),
  'wd2':tf.Variable(tf.random_normal([4096,4096])),
  'out':tf.Variable(tf.random_normal([4096,nn_classes])),
}
 
biases ={
  'bc1':tf.Variable(tf.zeros([64])),
  'bc2':tf.Variable(tf.zeros([64])),
  'bc3':tf.Variable(tf.zeros([128])),
  'bc4':tf.Variable(tf.zeros([128])),
  'bc5':tf.Variable(tf.zeros([256])),
  'bc6':tf.Variable(tf.zeros([256])),
  'bc7':tf.Variable(tf.zeros([256])),
  'bc8':tf.Variable(tf.zeros([256])),
  'bc9':tf.Variable(tf.zeros([512])),
  'bc10':tf.Variable(tf.zeros([512])),
  'bc11':tf.Variable(tf.zeros([512])),
  'bc12':tf.Variable(tf.zeros([512])),
  'bc13':tf.Variable(tf.zeros([512])),
  'bc14':tf.Variable(tf.zeros([512])),
  'bc15':tf.Variable(tf.zeros([512])),
  'bc16':tf.Variable(tf.zeros([256])),
  
  
  'bd1':tf.Variable(tf.zeros([4096])),
  'bd2':tf.Variable(tf.zeros([4096])),
  'out':tf.Variable(tf.zeros([nn_classes])),
}

卷积实现:

def convLevel(i,input,type):
  num = i
  out = conv2D('conv'+str(num),input,weights['wc'+str(num)],biases['bc'+str(num)])
  if type=='p':
    out = maxPool2D('pool'+str(num),out, k=2) 
    out = norm('norm'+str(num),out, lsize=4)
  return out 
 
def VGG(x,weights,biases,dropout):
  x = tf.reshape(x,shape=[-1,28,28,1])
 
  input = x
 
  for i in range(16):
    i += 1
    if(i==2) or (i==4) or (i==12) : # 根据模型定义还需要更多的POOL化,但mnist图片大小不允许。
      input = convLevel(i,input,'p')
    else:
      input = convLevel(i,input,'c')

训练:

pred = VGG(x, weights, biases, keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy_ = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
 
init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  step = 1
  while step*batch_size < train_iters:
    batch_x,batch_y = mnist.train.next_batch(batch_size)
    sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
    print(step*batch_size)
    if step % display_step == 0 :
      #loss,acc = sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob=1.0})
      acc = sess.run(accuracy_, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      # 计算损失值
      
      loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      print("iter: "+str(step*batch_size)+"mini batch Loss="+"{:.6f}".format(loss)+",acc="+"{:6f}".format(acc))
 
    step += 1 
   
  print("training end!")

最终效果:

训练10000次后:结果如下:

iter: 12288 mini batch Loss=5088409.500000,acc=0.578125

iter: 12800 mini batch Loss=4514274.000000,acc=0.601562

iter: 13312 mini batch Loss=4483454.500000,acc=0.648438

这种深度的模型可以考虑循环10万次以上。目前效果还不错,本人没有GPU,心痛笔记本的CPU,100%的CPU利用率,听到风扇响就不忍心再训练,本文也借鉴了alex网络实现,当然我也实现了这个网络模型。在MNIST数据上,ALEX由于层数较少,收敛更快,当然MNIST,用CNN足够了。

以上这篇使用tensorflow实现VGG网络,训练mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
实例讲解Python中SocketServer模块处理网络请求的用法
Jun 28 Python
在Python中执行系统命令的方法示例详解
Sep 14 Python
Python数据结构之哈夫曼树定义与使用方法示例
Apr 22 Python
python得到qq句柄,并显示在前台的方法
Oct 14 Python
python3 实现对图片进行局部切割的方法
Dec 05 Python
Python类的继承用法示例
Jan 31 Python
python模拟键盘输入 切换键盘布局过程解析
Aug 15 Python
selenium+python配置chrome浏览器的选项的实现
Mar 18 Python
基于Python快速处理PDF表格数据
Jun 03 Python
python mysql中in参数化说明
Jun 05 Python
Python 如何调试程序崩溃错误
Aug 03 Python
Python IO文件管理的具体使用
Mar 20 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 #Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
You might like
2014年最新推荐的10款 PHP 开发框架
2014/08/01 PHP
php绘制一个矩形的方法
2015/01/24 PHP
PHP抓取网页、解析HTML常用的方法总结
2015/07/01 PHP
Symfony2实现从数据库获取数据的方法小结
2016/03/18 PHP
详解PHP处理字符串类似indexof的方法函数
2017/06/11 PHP
jquery 图片截取工具jquery.imagecropper.js
2010/04/09 Javascript
40款非常棒的jQuery 插件和制作教程(系列一)
2011/10/26 Javascript
js 在定义的时候立即执行的函数表达式(function)写法
2013/01/16 Javascript
js左侧三级菜单导航实例代码
2013/09/13 Javascript
浅析jQuery(function(){})与(function(){})(jQuery)之间的区别
2014/01/09 Javascript
深入理解JavaScript中的浮点数
2016/05/18 Javascript
js获取腾讯视频ID的方法
2016/10/03 Javascript
让浏览器崩溃的12行JS代码(DoS攻击分析及防御)
2016/10/10 Javascript
vue插件tab选项卡使用小结
2016/10/27 Javascript
详解vue2.0监听属性的使用心得及搭配计算属性的使用
2018/07/18 Javascript
微信小程序使用for循环动态渲染页面操作示例
2018/12/25 Javascript
JavaScript的查询机制LHS和RHS解析
2019/08/16 Javascript
Vue中跨域及打包部署到nginx跨域设置方法
2019/08/26 Javascript
Python处理XML格式数据的方法详解
2017/03/21 Python
回调函数的意义以及python实现实例
2017/06/20 Python
深入浅析Python2.x和3.x版本的主要区别
2018/11/30 Python
python中的decimal类型转换实例详解
2019/06/26 Python
python 利用jinja2模板生成html代码实例
2019/10/10 Python
python定义具名元组实例操作
2021/02/28 Python
CSS实现雨滴动画效果的实例代码
2019/10/08 HTML / CSS
使用html5制作loading图的示例
2014/04/14 HTML / CSS
Nordgreen台湾官网:极简北欧设计手表
2019/08/21 全球购物
计算机个人求职信范例
2014/01/24 职场文书
校园公益广告语
2014/03/13 职场文书
低碳环保标语
2014/06/12 职场文书
纪律教育学习月活动总结
2014/08/27 职场文书
离职证明标准格式
2014/09/15 职场文书
干部作风整顿自我剖析材料和整改措施
2014/09/18 职场文书
初中优秀学生评语
2014/12/29 职场文书
HR必备:超全面的薪酬待遇管理方案!
2019/07/12 职场文书
Redis 常见使用场景
2021/08/30 Redis