使用tensorflow实现VGG网络,训练mnist数据集方式


Posted in Python onMay 26, 2020

VGG作为流行的几个模型之一,训练图形数据效果不错,在mnist数据集是常用的入门集数据,VGG层数非常多,如果严格按照规范来实现,并用来训练mnist数据集,会出现各种问题,如,经过16层卷积后,28*28*1的图片几乎无法进行。

先介绍下VGG

ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman实现的卷积神经网络,现在称其为VGGNet。它主要的贡献是展示出网络的深度是算法优良性能的关键部分。

他们最好的网络包含了16个卷积/全连接层。网络的结构非常一致,从头到尾全部使用的是3x3的卷积和2x2的汇聚。他们的预训练模型是可以在网络上获得并在Caffe中使用的。

VGGNet不好的一点是它耗费更多计算资源,并且使用了更多的参数,导致更多的内存占用(140M)。其中绝大多数的参数都是来自于第一个全连接层。

模型结构:

使用tensorflow实现VGG网络,训练mnist数据集方式

本文在实现时候,尽量保存VGG原来模型结构,核心代码如下:

weights ={
  'wc1':tf.Variable(tf.random_normal([3,3,1,64])),
  'wc2':tf.Variable(tf.random_normal([3,3,64,64])),
  'wc3':tf.Variable(tf.random_normal([3,3,64,128])),
  'wc4':tf.Variable(tf.random_normal([3,3,128,128])),
  
  'wc5':tf.Variable(tf.random_normal([3,3,128,256])),
  'wc6':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc7':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc8':tf.Variable(tf.random_normal([3,3,256,256])),
  
  'wc9':tf.Variable(tf.random_normal([3,3,256,512])),
  'wc10':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc11':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc12':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc13':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc14':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc15':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc16':tf.Variable(tf.random_normal([3,3,512,256])),
  
  'wd1':tf.Variable(tf.random_normal([4096,4096])),
  'wd2':tf.Variable(tf.random_normal([4096,4096])),
  'out':tf.Variable(tf.random_normal([4096,nn_classes])),
}
 
biases ={
  'bc1':tf.Variable(tf.zeros([64])),
  'bc2':tf.Variable(tf.zeros([64])),
  'bc3':tf.Variable(tf.zeros([128])),
  'bc4':tf.Variable(tf.zeros([128])),
  'bc5':tf.Variable(tf.zeros([256])),
  'bc6':tf.Variable(tf.zeros([256])),
  'bc7':tf.Variable(tf.zeros([256])),
  'bc8':tf.Variable(tf.zeros([256])),
  'bc9':tf.Variable(tf.zeros([512])),
  'bc10':tf.Variable(tf.zeros([512])),
  'bc11':tf.Variable(tf.zeros([512])),
  'bc12':tf.Variable(tf.zeros([512])),
  'bc13':tf.Variable(tf.zeros([512])),
  'bc14':tf.Variable(tf.zeros([512])),
  'bc15':tf.Variable(tf.zeros([512])),
  'bc16':tf.Variable(tf.zeros([256])),
  
  
  'bd1':tf.Variable(tf.zeros([4096])),
  'bd2':tf.Variable(tf.zeros([4096])),
  'out':tf.Variable(tf.zeros([nn_classes])),
}

卷积实现:

def convLevel(i,input,type):
  num = i
  out = conv2D('conv'+str(num),input,weights['wc'+str(num)],biases['bc'+str(num)])
  if type=='p':
    out = maxPool2D('pool'+str(num),out, k=2) 
    out = norm('norm'+str(num),out, lsize=4)
  return out 
 
def VGG(x,weights,biases,dropout):
  x = tf.reshape(x,shape=[-1,28,28,1])
 
  input = x
 
  for i in range(16):
    i += 1
    if(i==2) or (i==4) or (i==12) : # 根据模型定义还需要更多的POOL化,但mnist图片大小不允许。
      input = convLevel(i,input,'p')
    else:
      input = convLevel(i,input,'c')

训练:

pred = VGG(x, weights, biases, keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy_ = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
 
init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  step = 1
  while step*batch_size < train_iters:
    batch_x,batch_y = mnist.train.next_batch(batch_size)
    sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
    print(step*batch_size)
    if step % display_step == 0 :
      #loss,acc = sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob=1.0})
      acc = sess.run(accuracy_, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      # 计算损失值
      
      loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      print("iter: "+str(step*batch_size)+"mini batch Loss="+"{:.6f}".format(loss)+",acc="+"{:6f}".format(acc))
 
    step += 1 
   
  print("training end!")

最终效果:

训练10000次后:结果如下:

iter: 12288 mini batch Loss=5088409.500000,acc=0.578125

iter: 12800 mini batch Loss=4514274.000000,acc=0.601562

iter: 13312 mini batch Loss=4483454.500000,acc=0.648438

这种深度的模型可以考虑循环10万次以上。目前效果还不错,本人没有GPU,心痛笔记本的CPU,100%的CPU利用率,听到风扇响就不忍心再训练,本文也借鉴了alex网络实现,当然我也实现了这个网络模型。在MNIST数据上,ALEX由于层数较少,收敛更快,当然MNIST,用CNN足够了。

以上这篇使用tensorflow实现VGG网络,训练mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
遍历python字典几种方法总结(推荐)
Sep 11 Python
Python 逐行分割大txt文件的方法
Oct 10 Python
Python使用matplotlib绘图无法显示中文问题的解决方法
Mar 14 Python
python Opencv将图片转为字符画
Feb 19 Python
Python字符串的一些操作方法总结
Jun 10 Python
Django 重写用户模型的实现
Jul 29 Python
使用coverage统计python web项目代码覆盖率的方法详解
Aug 05 Python
Python 读取有公式cell的结果内容实例方法
Feb 17 Python
python 最简单的实现适配器设计模式的示例
Jun 30 Python
python制作一个简单的gui 数据库查询界面
Nov 19 Python
Django如何重置migration的几种情景
Feb 24 Python
Python几种酷炫的进度条的方式
Apr 11 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 #Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
You might like
PHP 设计模式之观察者模式介绍
2012/02/22 PHP
php版微信自定义回复功能示例
2016/12/05 PHP
php实现用户登陆简单实例
2017/04/04 PHP
PHP实现的获取文件mimes类型工具类示例
2018/04/08 PHP
PHP大文件分割分片上传实现代码
2020/12/09 PHP
关于Javascript 的 prototype问题。
2007/01/03 Javascript
也说JavaScript中String类的replace函数
2011/09/22 Javascript
javascript实现简单的Map示例介绍
2013/12/23 Javascript
借助JavaScript脚本判断浏览器Flash Player信息的方法
2014/07/09 Javascript
jQuery插件expander实现图片翻转特效
2015/05/21 Javascript
jQuery 3.0 的 setter和getter 模式详解
2016/07/11 Javascript
node.js路径处理方法以及绝对路径详解
2021/03/04 Javascript
js实现音乐播放控制条
2017/09/09 Javascript
javascript实现循环广告条效果
2017/12/12 Javascript
vue源码学习之Object.defineProperty 对数组监听
2018/05/30 Javascript
vue之a-table中实现清空选中的数据
2019/11/07 Javascript
[05:08]第一届“网鱼杯”DOTA2比赛精彩集锦
2014/09/05 DOTA
用Python制作检测Linux运行信息的工具的教程
2015/04/01 Python
Python设置Socket代理及实现远程摄像头控制的例子
2015/11/13 Python
Python工程师面试题 与Python Web相关
2016/01/14 Python
TensorFlow在MAC环境下的安装及环境搭建
2017/11/14 Python
解决python3爬虫无法显示中文的问题
2018/04/12 Python
Python使用pyodbc访问数据库操作方法详解
2018/07/05 Python
python自定义线程池控制线程数量的示例
2019/02/22 Python
Python定义函数功能与用法实例详解
2019/04/08 Python
python中upper是做什么用的
2020/07/20 Python
python爬虫 requests-html的使用
2020/11/30 Python
python爬虫搭配起Bilibili唧唧的流程分析
2020/12/01 Python
美国一家主营日韩美妆护肤品的在线商店:iMomoko
2016/09/11 全球购物
英国巧克力贸易公司:Chocolate Trading Company
2017/03/21 全球购物
最新大学生创业计划书写作攻略
2014/04/02 职场文书
《学会合作》教学反思
2014/04/12 职场文书
售后前台接待岗位职责
2015/04/03 职场文书
入党介绍人考察意见
2015/06/01 职场文书
导游词之西安骊山
2019/12/20 职场文书
Python中requests库的用法详解
2022/06/05 Python