使用tensorflow实现VGG网络,训练mnist数据集方式


Posted in Python onMay 26, 2020

VGG作为流行的几个模型之一,训练图形数据效果不错,在mnist数据集是常用的入门集数据,VGG层数非常多,如果严格按照规范来实现,并用来训练mnist数据集,会出现各种问题,如,经过16层卷积后,28*28*1的图片几乎无法进行。

先介绍下VGG

ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman实现的卷积神经网络,现在称其为VGGNet。它主要的贡献是展示出网络的深度是算法优良性能的关键部分。

他们最好的网络包含了16个卷积/全连接层。网络的结构非常一致,从头到尾全部使用的是3x3的卷积和2x2的汇聚。他们的预训练模型是可以在网络上获得并在Caffe中使用的。

VGGNet不好的一点是它耗费更多计算资源,并且使用了更多的参数,导致更多的内存占用(140M)。其中绝大多数的参数都是来自于第一个全连接层。

模型结构:

使用tensorflow实现VGG网络,训练mnist数据集方式

本文在实现时候,尽量保存VGG原来模型结构,核心代码如下:

weights ={
  'wc1':tf.Variable(tf.random_normal([3,3,1,64])),
  'wc2':tf.Variable(tf.random_normal([3,3,64,64])),
  'wc3':tf.Variable(tf.random_normal([3,3,64,128])),
  'wc4':tf.Variable(tf.random_normal([3,3,128,128])),
  
  'wc5':tf.Variable(tf.random_normal([3,3,128,256])),
  'wc6':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc7':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc8':tf.Variable(tf.random_normal([3,3,256,256])),
  
  'wc9':tf.Variable(tf.random_normal([3,3,256,512])),
  'wc10':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc11':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc12':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc13':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc14':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc15':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc16':tf.Variable(tf.random_normal([3,3,512,256])),
  
  'wd1':tf.Variable(tf.random_normal([4096,4096])),
  'wd2':tf.Variable(tf.random_normal([4096,4096])),
  'out':tf.Variable(tf.random_normal([4096,nn_classes])),
}
 
biases ={
  'bc1':tf.Variable(tf.zeros([64])),
  'bc2':tf.Variable(tf.zeros([64])),
  'bc3':tf.Variable(tf.zeros([128])),
  'bc4':tf.Variable(tf.zeros([128])),
  'bc5':tf.Variable(tf.zeros([256])),
  'bc6':tf.Variable(tf.zeros([256])),
  'bc7':tf.Variable(tf.zeros([256])),
  'bc8':tf.Variable(tf.zeros([256])),
  'bc9':tf.Variable(tf.zeros([512])),
  'bc10':tf.Variable(tf.zeros([512])),
  'bc11':tf.Variable(tf.zeros([512])),
  'bc12':tf.Variable(tf.zeros([512])),
  'bc13':tf.Variable(tf.zeros([512])),
  'bc14':tf.Variable(tf.zeros([512])),
  'bc15':tf.Variable(tf.zeros([512])),
  'bc16':tf.Variable(tf.zeros([256])),
  
  
  'bd1':tf.Variable(tf.zeros([4096])),
  'bd2':tf.Variable(tf.zeros([4096])),
  'out':tf.Variable(tf.zeros([nn_classes])),
}

卷积实现:

def convLevel(i,input,type):
  num = i
  out = conv2D('conv'+str(num),input,weights['wc'+str(num)],biases['bc'+str(num)])
  if type=='p':
    out = maxPool2D('pool'+str(num),out, k=2) 
    out = norm('norm'+str(num),out, lsize=4)
  return out 
 
def VGG(x,weights,biases,dropout):
  x = tf.reshape(x,shape=[-1,28,28,1])
 
  input = x
 
  for i in range(16):
    i += 1
    if(i==2) or (i==4) or (i==12) : # 根据模型定义还需要更多的POOL化,但mnist图片大小不允许。
      input = convLevel(i,input,'p')
    else:
      input = convLevel(i,input,'c')

训练:

pred = VGG(x, weights, biases, keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy_ = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
 
init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  step = 1
  while step*batch_size < train_iters:
    batch_x,batch_y = mnist.train.next_batch(batch_size)
    sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
    print(step*batch_size)
    if step % display_step == 0 :
      #loss,acc = sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob=1.0})
      acc = sess.run(accuracy_, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      # 计算损失值
      
      loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      print("iter: "+str(step*batch_size)+"mini batch Loss="+"{:.6f}".format(loss)+",acc="+"{:6f}".format(acc))
 
    step += 1 
   
  print("training end!")

最终效果:

训练10000次后:结果如下:

iter: 12288 mini batch Loss=5088409.500000,acc=0.578125

iter: 12800 mini batch Loss=4514274.000000,acc=0.601562

iter: 13312 mini batch Loss=4483454.500000,acc=0.648438

这种深度的模型可以考虑循环10万次以上。目前效果还不错,本人没有GPU,心痛笔记本的CPU,100%的CPU利用率,听到风扇响就不忍心再训练,本文也借鉴了alex网络实现,当然我也实现了这个网络模型。在MNIST数据上,ALEX由于层数较少,收敛更快,当然MNIST,用CNN足够了。

以上这篇使用tensorflow实现VGG网络,训练mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现FTP服务器服务的方法
Apr 11 Python
Python中音频处理库pydub的使用教程
Jun 07 Python
使用python 爬虫抓站的一些技巧总结
Jan 10 Python
理论讲解python多进程并发编程
Feb 09 Python
python 解决动态的定义变量名,并给其赋值的方法(大数据处理)
Nov 10 Python
pandas去除重复列的实现方法
Jan 29 Python
pandas.DataFrame的pivot()和unstack()实现行转列
Jul 06 Python
Python 项目转化为so文件实例
Dec 23 Python
Tensorflow 实现分批量读取数据
Jan 04 Python
Pytorch 实现冻结指定卷积层的参数
Jan 06 Python
如何通过Python3和ssl实现加密通信功能
May 09 Python
Python接口自动化测试框架运行原理及流程
Nov 30 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 #Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
You might like
分享8个最佳的代码片段在线测试网站
2013/06/29 PHP
浅析php中三个等号(===)和两个等号(==)的区别
2013/08/06 PHP
浅析php与数据库代码开发规范
2013/08/08 PHP
PHP中source #N问题的解决方法
2014/01/27 PHP
PHP两个n位的二进制整数相加问题的解决
2018/08/26 PHP
laravel 判断查询数据库返回值的例子
2019/10/11 PHP
如何在指定的地方插入html内容和文本内容
2013/12/23 Javascript
javascript实现复选框超过限制即弹出警告框的方法
2015/02/25 Javascript
JavaScript代码性能优化总结篇
2016/05/15 Javascript
js addDqmForPP给标签内属性值加上双引号的函数
2016/12/24 Javascript
Bootstrap jquery.twbsPagination.js动态页码分页实例代码
2017/02/20 Javascript
jQuery实现全选、反选和不选功能
2017/08/16 jQuery
详述 Sublime Text 打开 GBK 格式中文乱码的解决方法
2017/10/26 Javascript
微信小程序实现选项卡功能
2020/06/19 Javascript
浅谈Vue-cli 命令行工具分析
2017/11/22 Javascript
Javascript实现运算符重载详解
2018/04/07 Javascript
原生JS实现无缝轮播图片
2020/06/24 Javascript
vue实现公告栏文字上下滚动效果的示例代码
2020/06/16 Javascript
vue 接口请求地址前缀本地开发和线上开发设置方式
2020/08/13 Javascript
js实现简单的点名器随机色实例代码
2020/09/20 Javascript
python网络编程学习笔记(四):域名系统
2014/06/09 Python
Python threading多线程编程实例
2014/09/18 Python
详解Django中的权限和组以及消息
2015/07/23 Python
Python如何获取系统iops示例代码
2016/09/06 Python
Python分治法定义与应用实例详解
2017/07/28 Python
python异常处理、自定义异常、断言原理与用法分析
2020/03/23 Python
Django中FilePathField字段的用法
2020/05/21 Python
Pandas DataFrame求差集的示例代码
2020/12/13 Python
termux中matplotlib无法显示中文问题的解决方法
2021/01/11 Python
Python3+SQLAlchemy+Sqlite3实现ORM教程
2021/02/16 Python
Hotels.com台湾:饭店订房网
2017/09/06 全球购物
印尼综合在线预订网站:Tiket.com(机票、酒店、火车、租车和娱乐)
2018/10/11 全球购物
《雨霖铃》教学反思
2014/02/22 职场文书
师德承诺书
2015/01/20 职场文书
Django实现在线无水印抖音视频下载(附源码及地址)
2021/05/06 Python
Redis 的查询很快的原因解析及Redis 如何保证查询的高效
2022/03/16 Redis