使用tensorflow实现VGG网络,训练mnist数据集方式


Posted in Python onMay 26, 2020

VGG作为流行的几个模型之一,训练图形数据效果不错,在mnist数据集是常用的入门集数据,VGG层数非常多,如果严格按照规范来实现,并用来训练mnist数据集,会出现各种问题,如,经过16层卷积后,28*28*1的图片几乎无法进行。

先介绍下VGG

ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman实现的卷积神经网络,现在称其为VGGNet。它主要的贡献是展示出网络的深度是算法优良性能的关键部分。

他们最好的网络包含了16个卷积/全连接层。网络的结构非常一致,从头到尾全部使用的是3x3的卷积和2x2的汇聚。他们的预训练模型是可以在网络上获得并在Caffe中使用的。

VGGNet不好的一点是它耗费更多计算资源,并且使用了更多的参数,导致更多的内存占用(140M)。其中绝大多数的参数都是来自于第一个全连接层。

模型结构:

使用tensorflow实现VGG网络,训练mnist数据集方式

本文在实现时候,尽量保存VGG原来模型结构,核心代码如下:

weights ={
  'wc1':tf.Variable(tf.random_normal([3,3,1,64])),
  'wc2':tf.Variable(tf.random_normal([3,3,64,64])),
  'wc3':tf.Variable(tf.random_normal([3,3,64,128])),
  'wc4':tf.Variable(tf.random_normal([3,3,128,128])),
  
  'wc5':tf.Variable(tf.random_normal([3,3,128,256])),
  'wc6':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc7':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc8':tf.Variable(tf.random_normal([3,3,256,256])),
  
  'wc9':tf.Variable(tf.random_normal([3,3,256,512])),
  'wc10':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc11':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc12':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc13':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc14':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc15':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc16':tf.Variable(tf.random_normal([3,3,512,256])),
  
  'wd1':tf.Variable(tf.random_normal([4096,4096])),
  'wd2':tf.Variable(tf.random_normal([4096,4096])),
  'out':tf.Variable(tf.random_normal([4096,nn_classes])),
}
 
biases ={
  'bc1':tf.Variable(tf.zeros([64])),
  'bc2':tf.Variable(tf.zeros([64])),
  'bc3':tf.Variable(tf.zeros([128])),
  'bc4':tf.Variable(tf.zeros([128])),
  'bc5':tf.Variable(tf.zeros([256])),
  'bc6':tf.Variable(tf.zeros([256])),
  'bc7':tf.Variable(tf.zeros([256])),
  'bc8':tf.Variable(tf.zeros([256])),
  'bc9':tf.Variable(tf.zeros([512])),
  'bc10':tf.Variable(tf.zeros([512])),
  'bc11':tf.Variable(tf.zeros([512])),
  'bc12':tf.Variable(tf.zeros([512])),
  'bc13':tf.Variable(tf.zeros([512])),
  'bc14':tf.Variable(tf.zeros([512])),
  'bc15':tf.Variable(tf.zeros([512])),
  'bc16':tf.Variable(tf.zeros([256])),
  
  
  'bd1':tf.Variable(tf.zeros([4096])),
  'bd2':tf.Variable(tf.zeros([4096])),
  'out':tf.Variable(tf.zeros([nn_classes])),
}

卷积实现:

def convLevel(i,input,type):
  num = i
  out = conv2D('conv'+str(num),input,weights['wc'+str(num)],biases['bc'+str(num)])
  if type=='p':
    out = maxPool2D('pool'+str(num),out, k=2) 
    out = norm('norm'+str(num),out, lsize=4)
  return out 
 
def VGG(x,weights,biases,dropout):
  x = tf.reshape(x,shape=[-1,28,28,1])
 
  input = x
 
  for i in range(16):
    i += 1
    if(i==2) or (i==4) or (i==12) : # 根据模型定义还需要更多的POOL化,但mnist图片大小不允许。
      input = convLevel(i,input,'p')
    else:
      input = convLevel(i,input,'c')

训练:

pred = VGG(x, weights, biases, keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy_ = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
 
init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  step = 1
  while step*batch_size < train_iters:
    batch_x,batch_y = mnist.train.next_batch(batch_size)
    sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
    print(step*batch_size)
    if step % display_step == 0 :
      #loss,acc = sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob=1.0})
      acc = sess.run(accuracy_, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      # 计算损失值
      
      loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      print("iter: "+str(step*batch_size)+"mini batch Loss="+"{:.6f}".format(loss)+",acc="+"{:6f}".format(acc))
 
    step += 1 
   
  print("training end!")

最终效果:

训练10000次后:结果如下:

iter: 12288 mini batch Loss=5088409.500000,acc=0.578125

iter: 12800 mini batch Loss=4514274.000000,acc=0.601562

iter: 13312 mini batch Loss=4483454.500000,acc=0.648438

这种深度的模型可以考虑循环10万次以上。目前效果还不错,本人没有GPU,心痛笔记本的CPU,100%的CPU利用率,听到风扇响就不忍心再训练,本文也借鉴了alex网络实现,当然我也实现了这个网络模型。在MNIST数据上,ALEX由于层数较少,收敛更快,当然MNIST,用CNN足够了。

以上这篇使用tensorflow实现VGG网络,训练mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python多层嵌套list的递归处理方法(推荐)
Jun 08 Python
Python制作词云的方法
Jan 03 Python
Python语言描述连续子数组的最大和
Jan 04 Python
mac下pycharm设置python版本的图文教程
Jun 13 Python
浅谈Tensorflow由于版本问题出现的几种错误及解决方法
Jun 13 Python
Python构建图像分类识别器的方法
Jan 12 Python
Win10系统下安装labelme及json文件批量转化方法
Jul 30 Python
使用 Django Highcharts 实现数据可视化过程解析
Jul 31 Python
MNIST数据集转化为二维图片的实现示例
Jan 10 Python
使用 Python 遍历目录树的方法
Feb 29 Python
python 在threading中如何处理主进程和子线程的关系
Apr 25 Python
Python FuzzyWuzzy实现模糊匹配
Apr 28 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 #Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
You might like
第十节--抽象方法和抽象类
2006/11/16 PHP
PHP Session_Regenerate_ID函数双释放内存破坏漏洞
2011/01/27 PHP
解析yii数据库的增删查改
2013/06/20 PHP
根据中文裁减字符串函数的php代码
2013/12/03 PHP
ThinkPHP的Widget扩展实例
2014/06/19 PHP
php利用ob_start()清除输出和选择性输出的方法
2018/01/18 PHP
jQuery+CSS 实现的超Sexy下拉菜单
2010/01/17 Javascript
js实时获取系统当前时间实例代码
2013/06/28 Javascript
jquery map方法使用示例
2014/04/23 Javascript
浅析JavaScript中的事件机制
2015/06/04 Javascript
javascript常用经典算法实例详解
2015/11/25 Javascript
jQuery的实例及必知重要的jQuery选择器详解
2016/05/20 Javascript
angularjs中使用ng-bind-html和ng-include的实例
2017/04/28 Javascript
使用Browserify来实现CommonJS的浏览器加载方法
2017/05/14 Javascript
js下拉菜单生成器dropMenu使用方法详解
2017/08/01 Javascript
在vue项目中使用md5加密的方法
2018/09/14 Javascript
详解vue开发中调用微信jssdk的问题
2019/04/16 Javascript
Vue实现搜索结果高亮显示关键字
2019/05/28 Javascript
微信小程序 this.triggerEvent()的具体使用
2019/12/10 Javascript
vue 解决在微信内置浏览器中调用支付宝支付的情况
2020/11/09 Javascript
python实现的生成随机迷宫算法核心代码分享(含游戏完整代码)
2014/07/11 Python
深入理解Python 代码优化详解
2014/10/27 Python
详解python3中socket套接字的编码问题解决
2017/07/01 Python
python利用正则表达式排除集合中字符的功能示例
2017/10/10 Python
numpy 声明空数组详解
2019/12/05 Python
python装饰器代码深入讲解
2021/03/01 Python
美国折扣宠物药房:Total Pet Supply
2018/05/27 全球购物
西班牙著名的珠宝首饰品牌:P D PAOLA
2018/09/15 全球购物
会计师事务所审计实习自我鉴定
2013/09/20 职场文书
学历公证委托书
2014/04/09 职场文书
护士年终个人总结
2015/02/13 职场文书
行政诉讼答辩状
2015/05/21 职场文书
毕业证明模板
2015/06/19 职场文书
JavaScript实现简单拖拽效果
2021/09/15 Javascript
JS实现数组去重的11种方法总结
2022/04/04 Javascript
Nginx如何配置多个服务域名解析共用80端口详解
2022/09/23 Servers