使用tensorflow实现VGG网络,训练mnist数据集方式


Posted in Python onMay 26, 2020

VGG作为流行的几个模型之一,训练图形数据效果不错,在mnist数据集是常用的入门集数据,VGG层数非常多,如果严格按照规范来实现,并用来训练mnist数据集,会出现各种问题,如,经过16层卷积后,28*28*1的图片几乎无法进行。

先介绍下VGG

ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman实现的卷积神经网络,现在称其为VGGNet。它主要的贡献是展示出网络的深度是算法优良性能的关键部分。

他们最好的网络包含了16个卷积/全连接层。网络的结构非常一致,从头到尾全部使用的是3x3的卷积和2x2的汇聚。他们的预训练模型是可以在网络上获得并在Caffe中使用的。

VGGNet不好的一点是它耗费更多计算资源,并且使用了更多的参数,导致更多的内存占用(140M)。其中绝大多数的参数都是来自于第一个全连接层。

模型结构:

使用tensorflow实现VGG网络,训练mnist数据集方式

本文在实现时候,尽量保存VGG原来模型结构,核心代码如下:

weights ={
  'wc1':tf.Variable(tf.random_normal([3,3,1,64])),
  'wc2':tf.Variable(tf.random_normal([3,3,64,64])),
  'wc3':tf.Variable(tf.random_normal([3,3,64,128])),
  'wc4':tf.Variable(tf.random_normal([3,3,128,128])),
  
  'wc5':tf.Variable(tf.random_normal([3,3,128,256])),
  'wc6':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc7':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc8':tf.Variable(tf.random_normal([3,3,256,256])),
  
  'wc9':tf.Variable(tf.random_normal([3,3,256,512])),
  'wc10':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc11':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc12':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc13':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc14':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc15':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc16':tf.Variable(tf.random_normal([3,3,512,256])),
  
  'wd1':tf.Variable(tf.random_normal([4096,4096])),
  'wd2':tf.Variable(tf.random_normal([4096,4096])),
  'out':tf.Variable(tf.random_normal([4096,nn_classes])),
}
 
biases ={
  'bc1':tf.Variable(tf.zeros([64])),
  'bc2':tf.Variable(tf.zeros([64])),
  'bc3':tf.Variable(tf.zeros([128])),
  'bc4':tf.Variable(tf.zeros([128])),
  'bc5':tf.Variable(tf.zeros([256])),
  'bc6':tf.Variable(tf.zeros([256])),
  'bc7':tf.Variable(tf.zeros([256])),
  'bc8':tf.Variable(tf.zeros([256])),
  'bc9':tf.Variable(tf.zeros([512])),
  'bc10':tf.Variable(tf.zeros([512])),
  'bc11':tf.Variable(tf.zeros([512])),
  'bc12':tf.Variable(tf.zeros([512])),
  'bc13':tf.Variable(tf.zeros([512])),
  'bc14':tf.Variable(tf.zeros([512])),
  'bc15':tf.Variable(tf.zeros([512])),
  'bc16':tf.Variable(tf.zeros([256])),
  
  
  'bd1':tf.Variable(tf.zeros([4096])),
  'bd2':tf.Variable(tf.zeros([4096])),
  'out':tf.Variable(tf.zeros([nn_classes])),
}

卷积实现:

def convLevel(i,input,type):
  num = i
  out = conv2D('conv'+str(num),input,weights['wc'+str(num)],biases['bc'+str(num)])
  if type=='p':
    out = maxPool2D('pool'+str(num),out, k=2) 
    out = norm('norm'+str(num),out, lsize=4)
  return out 
 
def VGG(x,weights,biases,dropout):
  x = tf.reshape(x,shape=[-1,28,28,1])
 
  input = x
 
  for i in range(16):
    i += 1
    if(i==2) or (i==4) or (i==12) : # 根据模型定义还需要更多的POOL化,但mnist图片大小不允许。
      input = convLevel(i,input,'p')
    else:
      input = convLevel(i,input,'c')

训练:

pred = VGG(x, weights, biases, keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy_ = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
 
init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  step = 1
  while step*batch_size < train_iters:
    batch_x,batch_y = mnist.train.next_batch(batch_size)
    sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
    print(step*batch_size)
    if step % display_step == 0 :
      #loss,acc = sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob=1.0})
      acc = sess.run(accuracy_, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      # 计算损失值
      
      loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      print("iter: "+str(step*batch_size)+"mini batch Loss="+"{:.6f}".format(loss)+",acc="+"{:6f}".format(acc))
 
    step += 1 
   
  print("training end!")

最终效果:

训练10000次后:结果如下:

iter: 12288 mini batch Loss=5088409.500000,acc=0.578125

iter: 12800 mini batch Loss=4514274.000000,acc=0.601562

iter: 13312 mini batch Loss=4483454.500000,acc=0.648438

这种深度的模型可以考虑循环10万次以上。目前效果还不错,本人没有GPU,心痛笔记本的CPU,100%的CPU利用率,听到风扇响就不忍心再训练,本文也借鉴了alex网络实现,当然我也实现了这个网络模型。在MNIST数据上,ALEX由于层数较少,收敛更快,当然MNIST,用CNN足够了。

以上这篇使用tensorflow实现VGG网络,训练mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
web.py在SAE中的Session问题解决方法(使用mysql存储)
Jun 24 Python
使用Python的urllib和urllib2模块制作爬虫的实例教程
Jan 20 Python
Python线性方程组求解运算示例
Jan 17 Python
解决python os.mkdir创建目录失败的问题
Oct 16 Python
Python中安装easy_install的方法
Nov 18 Python
Python (Win)readline和tab补全的安装方法
Aug 27 Python
使用apiDoc实现python接口文档编写
Nov 19 Python
python 统计文件中的字符串数目示例
Dec 24 Python
Python读取JSON数据操作实例解析
May 18 Python
Django自带用户认证系统使用方法解析
Nov 12 Python
使用python爬取抖音app视频的实例代码
Dec 01 Python
Python基础之元编程知识总结
May 23 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 #Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
You might like
PHP的Socket通信之UDP通信实例
2015/07/02 PHP
10个超级有用的PHP代码片段果断收藏
2015/09/23 PHP
php mysql获取表字段名称和字段信息的三种方法
2016/11/13 PHP
php文件后缀不强制为.php的实操方法
2019/09/18 PHP
用javascript控制iframe滚动的代码
2007/04/10 Javascript
解析Javascript中中括号“[]”的多义性
2013/12/03 Javascript
快速解决jQuery与其他库冲突的方法介绍
2014/01/02 Javascript
JS中判断JSON数据是否存在某字段的方法
2014/03/07 Javascript
JavaScript DOM元素尺寸和位置
2015/04/13 Javascript
JavaScript对数组进行随机重排的方法
2015/07/22 Javascript
基于JavaScript实现定时跳转到指定页面
2016/01/01 Javascript
jQuery中的each()详细介绍(推荐)
2016/05/25 Javascript
实例分析浏览器中“JavaScript解析器”的工作原理
2016/12/12 Javascript
angularJS+requireJS实现controller及directive的按需加载示例
2017/02/20 Javascript
详解Vue.js搭建路由报错 router.map is not a function
2017/06/27 Javascript
Vue0.1的过滤代码如何添加到Vue2.0直接使用
2017/08/23 Javascript
js禁止浏览器页面后退功能的实例(推荐)
2017/09/01 Javascript
node的process以及child_process模块学习笔记
2018/03/06 Javascript
微信小程序如何像vue一样在动态绑定类名
2018/04/17 Javascript
详解vue几种主动刷新的方法总结
2019/02/19 Javascript
three.js利用卷积法如何实现物体描边效果
2019/11/27 Javascript
[04:53]DOTA2英雄基础教程 祈求者
2014/01/03 DOTA
[50:50]完美世界DOTA2联赛PWL S3 INK ICE vs DLG 第一场 12.20
2020/12/23 DOTA
python读出当前时间精度到秒的代码
2019/07/05 Python
基于matplotlib中ion()和ioff()的使用详解
2020/06/16 Python
Python爬虫小例子——爬取51job发布的工作职位
2020/07/10 Python
Python爬虫实例——爬取美团美食数据
2020/07/15 Python
HTML5 中新的全局属性(整理)
2013/07/31 HTML / CSS
测绘工程本科生求职信
2013/10/10 职场文书
财务部出纳岗位职责
2013/12/22 职场文书
致1500米运动员广播稿
2014/02/07 职场文书
活动总结格式范文
2014/04/26 职场文书
实验室标语
2014/06/21 职场文书
忠犬八公的故事观后感
2015/06/05 职场文书
幼儿园迎新生欢迎词
2015/09/30 职场文书
用javascript制作qq注册动态页面
2021/04/14 Javascript