使用tensorflow实现VGG网络,训练mnist数据集方式


Posted in Python onMay 26, 2020

VGG作为流行的几个模型之一,训练图形数据效果不错,在mnist数据集是常用的入门集数据,VGG层数非常多,如果严格按照规范来实现,并用来训练mnist数据集,会出现各种问题,如,经过16层卷积后,28*28*1的图片几乎无法进行。

先介绍下VGG

ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman实现的卷积神经网络,现在称其为VGGNet。它主要的贡献是展示出网络的深度是算法优良性能的关键部分。

他们最好的网络包含了16个卷积/全连接层。网络的结构非常一致,从头到尾全部使用的是3x3的卷积和2x2的汇聚。他们的预训练模型是可以在网络上获得并在Caffe中使用的。

VGGNet不好的一点是它耗费更多计算资源,并且使用了更多的参数,导致更多的内存占用(140M)。其中绝大多数的参数都是来自于第一个全连接层。

模型结构:

使用tensorflow实现VGG网络,训练mnist数据集方式

本文在实现时候,尽量保存VGG原来模型结构,核心代码如下:

weights ={
  'wc1':tf.Variable(tf.random_normal([3,3,1,64])),
  'wc2':tf.Variable(tf.random_normal([3,3,64,64])),
  'wc3':tf.Variable(tf.random_normal([3,3,64,128])),
  'wc4':tf.Variable(tf.random_normal([3,3,128,128])),
  
  'wc5':tf.Variable(tf.random_normal([3,3,128,256])),
  'wc6':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc7':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc8':tf.Variable(tf.random_normal([3,3,256,256])),
  
  'wc9':tf.Variable(tf.random_normal([3,3,256,512])),
  'wc10':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc11':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc12':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc13':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc14':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc15':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc16':tf.Variable(tf.random_normal([3,3,512,256])),
  
  'wd1':tf.Variable(tf.random_normal([4096,4096])),
  'wd2':tf.Variable(tf.random_normal([4096,4096])),
  'out':tf.Variable(tf.random_normal([4096,nn_classes])),
}
 
biases ={
  'bc1':tf.Variable(tf.zeros([64])),
  'bc2':tf.Variable(tf.zeros([64])),
  'bc3':tf.Variable(tf.zeros([128])),
  'bc4':tf.Variable(tf.zeros([128])),
  'bc5':tf.Variable(tf.zeros([256])),
  'bc6':tf.Variable(tf.zeros([256])),
  'bc7':tf.Variable(tf.zeros([256])),
  'bc8':tf.Variable(tf.zeros([256])),
  'bc9':tf.Variable(tf.zeros([512])),
  'bc10':tf.Variable(tf.zeros([512])),
  'bc11':tf.Variable(tf.zeros([512])),
  'bc12':tf.Variable(tf.zeros([512])),
  'bc13':tf.Variable(tf.zeros([512])),
  'bc14':tf.Variable(tf.zeros([512])),
  'bc15':tf.Variable(tf.zeros([512])),
  'bc16':tf.Variable(tf.zeros([256])),
  
  
  'bd1':tf.Variable(tf.zeros([4096])),
  'bd2':tf.Variable(tf.zeros([4096])),
  'out':tf.Variable(tf.zeros([nn_classes])),
}

卷积实现:

def convLevel(i,input,type):
  num = i
  out = conv2D('conv'+str(num),input,weights['wc'+str(num)],biases['bc'+str(num)])
  if type=='p':
    out = maxPool2D('pool'+str(num),out, k=2) 
    out = norm('norm'+str(num),out, lsize=4)
  return out 
 
def VGG(x,weights,biases,dropout):
  x = tf.reshape(x,shape=[-1,28,28,1])
 
  input = x
 
  for i in range(16):
    i += 1
    if(i==2) or (i==4) or (i==12) : # 根据模型定义还需要更多的POOL化,但mnist图片大小不允许。
      input = convLevel(i,input,'p')
    else:
      input = convLevel(i,input,'c')

训练:

pred = VGG(x, weights, biases, keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy_ = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
 
init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  step = 1
  while step*batch_size < train_iters:
    batch_x,batch_y = mnist.train.next_batch(batch_size)
    sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
    print(step*batch_size)
    if step % display_step == 0 :
      #loss,acc = sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob=1.0})
      acc = sess.run(accuracy_, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      # 计算损失值
      
      loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      print("iter: "+str(step*batch_size)+"mini batch Loss="+"{:.6f}".format(loss)+",acc="+"{:6f}".format(acc))
 
    step += 1 
   
  print("training end!")

最终效果:

训练10000次后:结果如下:

iter: 12288 mini batch Loss=5088409.500000,acc=0.578125

iter: 12800 mini batch Loss=4514274.000000,acc=0.601562

iter: 13312 mini batch Loss=4483454.500000,acc=0.648438

这种深度的模型可以考虑循环10万次以上。目前效果还不错,本人没有GPU,心痛笔记本的CPU,100%的CPU利用率,听到风扇响就不忍心再训练,本文也借鉴了alex网络实现,当然我也实现了这个网络模型。在MNIST数据上,ALEX由于层数较少,收敛更快,当然MNIST,用CNN足够了。

以上这篇使用tensorflow实现VGG网络,训练mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现猜数字游戏(无重复数字)示例分享
Mar 29 Python
Python查找相似单词的方法
Mar 05 Python
Go语言基于Socket编写服务器端与客户端通信的实例
Feb 19 Python
Python基础中所出现的异常报错总结
Nov 19 Python
python中print()函数的“,”与java中System.out.print()函数中的“+”功能详解
Nov 24 Python
对python 矩阵转置transpose的实例讲解
Apr 17 Python
Python3中_(下划线)和__(双下划线)的用途和区别
Apr 26 Python
Python中查看变量的类型内存地址所占字节的大小
Jun 26 Python
python分数表示方式和写法
Jun 26 Python
Pandas_cum累积计算和rolling滚动计算的用法详解
Jul 04 Python
Python使用xlrd实现读取合并单元格
Jul 09 Python
Pycharm github配置实现过程图解
Oct 13 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 #Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
You might like
php实现单链表的实例代码
2013/03/22 PHP
php输出1000以内质数(素数)示例
2014/02/16 PHP
自己写的兼容低于PHP 5.5版本的array_column()函数
2014/10/24 PHP
分享一则PHP定义函数代码
2015/02/26 PHP
Track Image Loading效果代码分析
2007/08/13 Javascript
javascript中input中readonly和disabled区别介绍
2012/10/23 Javascript
sogou地图API用法实例教程
2014/09/11 Javascript
jquery文档操作wrap()方法实例简述
2015/01/10 Javascript
jQuery中scrollLeft()方法用法实例
2015/01/16 Javascript
DOM基础教程之事件类型
2015/01/20 Javascript
有关文件上传 非ajax提交 得到后台数据问题
2016/10/12 Javascript
angular实现IM聊天图片发送实例
2017/05/08 Javascript
基于EasyUI的基础之上实现树形功能菜单
2017/06/28 Javascript
JavaScript中的return布尔值的用法和原理解析
2017/08/14 Javascript
JavaScript实现的原生态Tab标签页功能【兼容IE6】
2017/09/18 Javascript
浅谈webpack 自动刷新与解析
2018/04/09 Javascript
jquery判断滚动条距离顶部的距离方法
2018/09/05 jQuery
vue + typescript + video.js实现 流媒体播放 视频监控功能
2019/07/07 Javascript
[02:23]DOTA2英雄基础教程 幻影长矛手
2013/12/09 DOTA
Python使用SQLite和Excel操作进行数据分析
2018/01/20 Python
python调用xlsxwriter创建xlsx的方法
2018/05/03 Python
Django分页查询并返回jsons数据(中文乱码解决方法)
2018/08/02 Python
解决pycharm下os.system执行命令返回有中文乱码的问题
2019/07/07 Python
Python定时发送天气预报邮件代码实例
2019/09/09 Python
印度尼西亚最大和最全面的网络商城:Blibli.com
2017/10/04 全球购物
澳大利亚自然和有机的健康美容产品一站式商店:Ziani Beauty
2017/12/28 全球购物
中学生团员自我评价分享
2013/12/07 职场文书
工程专业求职自荐书范文
2014/02/18 职场文书
中西医专业毕业生职业规划书
2014/02/24 职场文书
安全协议书范本
2014/04/21 职场文书
个人考核材料
2014/05/15 职场文书
小学教师师德师风演讲稿
2014/08/22 职场文书
交通事故调解协议书
2015/05/20 职场文书
2015年国庆节广播稿
2015/08/19 职场文书
sql字段解析器的实现示例
2021/06/23 SQL Server
Elasticsearch 数据类型及管理
2022/04/19 Python