使用tensorflow实现VGG网络,训练mnist数据集方式


Posted in Python onMay 26, 2020

VGG作为流行的几个模型之一,训练图形数据效果不错,在mnist数据集是常用的入门集数据,VGG层数非常多,如果严格按照规范来实现,并用来训练mnist数据集,会出现各种问题,如,经过16层卷积后,28*28*1的图片几乎无法进行。

先介绍下VGG

ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman实现的卷积神经网络,现在称其为VGGNet。它主要的贡献是展示出网络的深度是算法优良性能的关键部分。

他们最好的网络包含了16个卷积/全连接层。网络的结构非常一致,从头到尾全部使用的是3x3的卷积和2x2的汇聚。他们的预训练模型是可以在网络上获得并在Caffe中使用的。

VGGNet不好的一点是它耗费更多计算资源,并且使用了更多的参数,导致更多的内存占用(140M)。其中绝大多数的参数都是来自于第一个全连接层。

模型结构:

使用tensorflow实现VGG网络,训练mnist数据集方式

本文在实现时候,尽量保存VGG原来模型结构,核心代码如下:

weights ={
  'wc1':tf.Variable(tf.random_normal([3,3,1,64])),
  'wc2':tf.Variable(tf.random_normal([3,3,64,64])),
  'wc3':tf.Variable(tf.random_normal([3,3,64,128])),
  'wc4':tf.Variable(tf.random_normal([3,3,128,128])),
  
  'wc5':tf.Variable(tf.random_normal([3,3,128,256])),
  'wc6':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc7':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc8':tf.Variable(tf.random_normal([3,3,256,256])),
  
  'wc9':tf.Variable(tf.random_normal([3,3,256,512])),
  'wc10':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc11':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc12':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc13':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc14':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc15':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc16':tf.Variable(tf.random_normal([3,3,512,256])),
  
  'wd1':tf.Variable(tf.random_normal([4096,4096])),
  'wd2':tf.Variable(tf.random_normal([4096,4096])),
  'out':tf.Variable(tf.random_normal([4096,nn_classes])),
}
 
biases ={
  'bc1':tf.Variable(tf.zeros([64])),
  'bc2':tf.Variable(tf.zeros([64])),
  'bc3':tf.Variable(tf.zeros([128])),
  'bc4':tf.Variable(tf.zeros([128])),
  'bc5':tf.Variable(tf.zeros([256])),
  'bc6':tf.Variable(tf.zeros([256])),
  'bc7':tf.Variable(tf.zeros([256])),
  'bc8':tf.Variable(tf.zeros([256])),
  'bc9':tf.Variable(tf.zeros([512])),
  'bc10':tf.Variable(tf.zeros([512])),
  'bc11':tf.Variable(tf.zeros([512])),
  'bc12':tf.Variable(tf.zeros([512])),
  'bc13':tf.Variable(tf.zeros([512])),
  'bc14':tf.Variable(tf.zeros([512])),
  'bc15':tf.Variable(tf.zeros([512])),
  'bc16':tf.Variable(tf.zeros([256])),
  
  
  'bd1':tf.Variable(tf.zeros([4096])),
  'bd2':tf.Variable(tf.zeros([4096])),
  'out':tf.Variable(tf.zeros([nn_classes])),
}

卷积实现:

def convLevel(i,input,type):
  num = i
  out = conv2D('conv'+str(num),input,weights['wc'+str(num)],biases['bc'+str(num)])
  if type=='p':
    out = maxPool2D('pool'+str(num),out, k=2) 
    out = norm('norm'+str(num),out, lsize=4)
  return out 
 
def VGG(x,weights,biases,dropout):
  x = tf.reshape(x,shape=[-1,28,28,1])
 
  input = x
 
  for i in range(16):
    i += 1
    if(i==2) or (i==4) or (i==12) : # 根据模型定义还需要更多的POOL化,但mnist图片大小不允许。
      input = convLevel(i,input,'p')
    else:
      input = convLevel(i,input,'c')

训练:

pred = VGG(x, weights, biases, keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy_ = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
 
init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  step = 1
  while step*batch_size < train_iters:
    batch_x,batch_y = mnist.train.next_batch(batch_size)
    sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
    print(step*batch_size)
    if step % display_step == 0 :
      #loss,acc = sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob=1.0})
      acc = sess.run(accuracy_, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      # 计算损失值
      
      loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      print("iter: "+str(step*batch_size)+"mini batch Loss="+"{:.6f}".format(loss)+",acc="+"{:6f}".format(acc))
 
    step += 1 
   
  print("training end!")

最终效果:

训练10000次后:结果如下:

iter: 12288 mini batch Loss=5088409.500000,acc=0.578125

iter: 12800 mini batch Loss=4514274.000000,acc=0.601562

iter: 13312 mini batch Loss=4483454.500000,acc=0.648438

这种深度的模型可以考虑循环10万次以上。目前效果还不错,本人没有GPU,心痛笔记本的CPU,100%的CPU利用率,听到风扇响就不忍心再训练,本文也借鉴了alex网络实现,当然我也实现了这个网络模型。在MNIST数据上,ALEX由于层数较少,收敛更快,当然MNIST,用CNN足够了。

以上这篇使用tensorflow实现VGG网络,训练mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 实现数组list 添加、修改、删除的方法
Apr 04 Python
python将txt等文件中的数据读为numpy数组的方法
Dec 22 Python
python实现多层感知器MLP(基于双月数据集)
Jan 18 Python
python挖矿算力测试程序详解
Jul 03 Python
python3 deque 双向队列创建与使用方法分析
Mar 24 Python
Pytorch转onnx、torchscript方式
May 25 Python
Matplotlib自定义坐标轴刻度的实现示例
Jun 18 Python
基于python实现ROC曲线绘制广场解析
Jun 28 Python
python中plt.imshow与cv2.imshow显示颜色问题
Jul 16 Python
Python json格式化打印实现过程解析
Jul 21 Python
基于Python爬取51cto博客页面信息过程解析
Aug 25 Python
python opencv常用图形绘制方法(线段、矩形、圆形、椭圆、文本)
Apr 12 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 #Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
You might like
B2K与车机的中波PK
2021/03/02 无线电
捕获关闭窗口的脚本
2009/01/10 Javascript
jquery 事件执行检测代码
2009/12/09 Javascript
javascipt基础内容--需要注意的细节
2013/04/10 Javascript
location对象的属性和方法应用(解析URL)
2013/04/12 Javascript
解析Jquery中如何把一段html代码动态写入到DIV中(实例说明)
2013/07/09 Javascript
js清空form表单中的内容示例
2014/05/20 Javascript
jquery实现submit提交表单
2015/02/03 Javascript
jquery使用经验小结
2015/05/20 Javascript
JS基于Ajax实现的网页Loading效果代码
2015/10/27 Javascript
jQuery基于ajax方式实现用户名存在性检查功能示例
2017/02/10 Javascript
jQuery用户头像裁剪插件cropbox.js使用详解
2017/06/07 jQuery
node跨域请求方法小结
2017/08/25 Javascript
解决v-for中使用v-if或者v-bind:class失效的问题
2018/09/25 Javascript
修改Vue打包后的默认文件名操作
2020/08/12 Javascript
Vue axios 跨域请求无法带上cookie的解决
2020/09/08 Javascript
Django Admin实现上传图片校验功能
2016/03/06 Python
使用Numpy读取CSV文件,并进行行列删除的操作方法
2018/07/04 Python
python排序函数sort()与sorted()的区别
2018/09/18 Python
python解析json串与正则匹配对比方法
2018/12/20 Python
python如何统计代码运行的时长
2019/07/24 Python
python快速排序的实现及运行时间比较
2019/11/22 Python
Python自定义聚合函数merge与transform区别详解
2020/05/26 Python
python Matplotlib数据可视化(2):详解三大容器对象与常用设置
2020/09/30 Python
jupyter notebook更换皮肤主题的实现
2021/01/07 Python
M1芯片安装python3.9.1的实现
2021/02/02 Python
Guess荷兰官网:美国服饰品牌
2020/01/22 全球购物
为什么会有内存对齐
2016/10/10 面试题
个人生活学习自我评价范文
2013/11/26 职场文书
《争吵》教学反思
2014/02/15 职场文书
某某同志考察材料
2014/05/28 职场文书
村党支部书记四风问题个人对照检查材料思想汇报
2014/10/06 职场文书
2016幼儿园教师年度考核评语
2015/12/01 职场文书
PyQt5 显示超清高分辨率图片的方法
2021/04/11 Python
python实现进度条的多种实现
2021/04/29 Python
MySQL生成千万测试数据以及遇到的问题
2022/08/05 MySQL