使用tensorflow实现VGG网络,训练mnist数据集方式


Posted in Python onMay 26, 2020

VGG作为流行的几个模型之一,训练图形数据效果不错,在mnist数据集是常用的入门集数据,VGG层数非常多,如果严格按照规范来实现,并用来训练mnist数据集,会出现各种问题,如,经过16层卷积后,28*28*1的图片几乎无法进行。

先介绍下VGG

ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman实现的卷积神经网络,现在称其为VGGNet。它主要的贡献是展示出网络的深度是算法优良性能的关键部分。

他们最好的网络包含了16个卷积/全连接层。网络的结构非常一致,从头到尾全部使用的是3x3的卷积和2x2的汇聚。他们的预训练模型是可以在网络上获得并在Caffe中使用的。

VGGNet不好的一点是它耗费更多计算资源,并且使用了更多的参数,导致更多的内存占用(140M)。其中绝大多数的参数都是来自于第一个全连接层。

模型结构:

使用tensorflow实现VGG网络,训练mnist数据集方式

本文在实现时候,尽量保存VGG原来模型结构,核心代码如下:

weights ={
  'wc1':tf.Variable(tf.random_normal([3,3,1,64])),
  'wc2':tf.Variable(tf.random_normal([3,3,64,64])),
  'wc3':tf.Variable(tf.random_normal([3,3,64,128])),
  'wc4':tf.Variable(tf.random_normal([3,3,128,128])),
  
  'wc5':tf.Variable(tf.random_normal([3,3,128,256])),
  'wc6':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc7':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc8':tf.Variable(tf.random_normal([3,3,256,256])),
  
  'wc9':tf.Variable(tf.random_normal([3,3,256,512])),
  'wc10':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc11':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc12':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc13':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc14':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc15':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc16':tf.Variable(tf.random_normal([3,3,512,256])),
  
  'wd1':tf.Variable(tf.random_normal([4096,4096])),
  'wd2':tf.Variable(tf.random_normal([4096,4096])),
  'out':tf.Variable(tf.random_normal([4096,nn_classes])),
}
 
biases ={
  'bc1':tf.Variable(tf.zeros([64])),
  'bc2':tf.Variable(tf.zeros([64])),
  'bc3':tf.Variable(tf.zeros([128])),
  'bc4':tf.Variable(tf.zeros([128])),
  'bc5':tf.Variable(tf.zeros([256])),
  'bc6':tf.Variable(tf.zeros([256])),
  'bc7':tf.Variable(tf.zeros([256])),
  'bc8':tf.Variable(tf.zeros([256])),
  'bc9':tf.Variable(tf.zeros([512])),
  'bc10':tf.Variable(tf.zeros([512])),
  'bc11':tf.Variable(tf.zeros([512])),
  'bc12':tf.Variable(tf.zeros([512])),
  'bc13':tf.Variable(tf.zeros([512])),
  'bc14':tf.Variable(tf.zeros([512])),
  'bc15':tf.Variable(tf.zeros([512])),
  'bc16':tf.Variable(tf.zeros([256])),
  
  
  'bd1':tf.Variable(tf.zeros([4096])),
  'bd2':tf.Variable(tf.zeros([4096])),
  'out':tf.Variable(tf.zeros([nn_classes])),
}

卷积实现:

def convLevel(i,input,type):
  num = i
  out = conv2D('conv'+str(num),input,weights['wc'+str(num)],biases['bc'+str(num)])
  if type=='p':
    out = maxPool2D('pool'+str(num),out, k=2) 
    out = norm('norm'+str(num),out, lsize=4)
  return out 
 
def VGG(x,weights,biases,dropout):
  x = tf.reshape(x,shape=[-1,28,28,1])
 
  input = x
 
  for i in range(16):
    i += 1
    if(i==2) or (i==4) or (i==12) : # 根据模型定义还需要更多的POOL化,但mnist图片大小不允许。
      input = convLevel(i,input,'p')
    else:
      input = convLevel(i,input,'c')

训练:

pred = VGG(x, weights, biases, keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy_ = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
 
init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  step = 1
  while step*batch_size < train_iters:
    batch_x,batch_y = mnist.train.next_batch(batch_size)
    sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
    print(step*batch_size)
    if step % display_step == 0 :
      #loss,acc = sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob=1.0})
      acc = sess.run(accuracy_, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      # 计算损失值
      
      loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      print("iter: "+str(step*batch_size)+"mini batch Loss="+"{:.6f}".format(loss)+",acc="+"{:6f}".format(acc))
 
    step += 1 
   
  print("training end!")

最终效果:

训练10000次后:结果如下:

iter: 12288 mini batch Loss=5088409.500000,acc=0.578125

iter: 12800 mini batch Loss=4514274.000000,acc=0.601562

iter: 13312 mini batch Loss=4483454.500000,acc=0.648438

这种深度的模型可以考虑循环10万次以上。目前效果还不错,本人没有GPU,心痛笔记本的CPU,100%的CPU利用率,听到风扇响就不忍心再训练,本文也借鉴了alex网络实现,当然我也实现了这个网络模型。在MNIST数据上,ALEX由于层数较少,收敛更快,当然MNIST,用CNN足够了。

以上这篇使用tensorflow实现VGG网络,训练mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中数字以及算数运算符的相关使用
Oct 12 Python
Python排序算法实例代码
Aug 10 Python
Python排序搜索基本算法之希尔排序实例分析
Dec 09 Python
Python tkinter实现的图片移动碰撞动画效果【附源码下载】
Jan 04 Python
python dataframe 输出结果整行显示的方法
Jun 14 Python
使用python实现滑动验证码功能
Aug 05 Python
python3 selenium自动化测试 强大的CSS定位方法
Aug 23 Python
python socket通信编程实现文件上传代码实例
Dec 14 Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 Python
python pillow库的基础使用教程
Jan 13 Python
selenium3.0+python之环境搭建的方法步骤
Feb 01 Python
Python .py生成.pyd文件并打包.exe 的注意事项说明
Mar 04 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 #Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
You might like
php读取txt文件组成SQL并插入数据库的代码(原创自Zjmainstay)
2012/07/31 PHP
PHP中auto_prepend_file与auto_append_file用法实例分析
2014/09/22 PHP
PHP基于cookie与session统计网站访问量并输出显示的方法
2016/01/15 PHP
PHP中Laravel 关联查询返回错误id的解决方法
2017/04/01 PHP
浅谈laravel-admin form中的数据,在提交后,保存前,获取并进行编辑
2019/10/21 PHP
显示js对象所有属性和方法的函数
2009/10/16 Javascript
来自国外的14个图片放大编辑的jQuery插件整理
2010/10/20 Javascript
DOM节点深度克隆函数cloneNode()用法实例
2015/01/12 Javascript
jQuery表单域选择器用法分析
2015/02/10 Javascript
浅谈jQuery双事件多重加载的问题
2016/10/05 Javascript
简单实现JavaScript图片切换效果
2016/11/28 Javascript
ionic2自定义cordova插件开发以及使用(Android)
2017/06/19 Javascript
vue2利用Bus.js如何实现非父子组件通信详解
2017/08/25 Javascript
微信小程序class封装http代码实例
2019/08/24 Javascript
VUE页面中通过双击实现复制表格中内容的示例代码
2020/06/11 Javascript
Vue 实现创建全局组件,并且使用Vue.use() 载入方式
2020/08/11 Javascript
[00:34]TI7不朽珍藏III——地穴编织者不朽展示
2017/07/15 DOTA
Python实现发送email的几种常用方法
2014/08/18 Python
Python数据分析matplotlib设置多个子图的间距方法
2018/08/03 Python
Python搭建代理IP池实现获取IP的方法
2019/10/27 Python
python 遍历pd.Series的index和value
2019/11/26 Python
python+selenium+PhantomJS抓取网页动态加载内容
2020/02/25 Python
Pandas —— resample()重采样和asfreq()频度转换方式
2020/02/26 Python
Python判断字符串是否为空和null方法实例
2020/04/26 Python
HTML5标签嵌套规则详解【必看】
2016/04/26 HTML / CSS
香港卓悦化妆品官网:BONJOUR
2017/09/21 全球购物
长青弘远的面试题
2012/06/09 面试题
金融专业个人求职信
2013/09/22 职场文书
高级销售员求职信
2013/10/25 职场文书
个人合作协议书范本
2014/04/18 职场文书
教导主任竞聘演讲稿
2014/05/16 职场文书
工会换届选举方案
2014/05/21 职场文书
乡镇党员干部群众路线对照检查材料思想汇报
2014/09/28 职场文书
2015年药店工作总结
2015/04/20 职场文书
2015年度个人工作总结报告
2015/10/24 职场文书
Python字符串对齐方法使用(ljust()、rjust()和center())
2021/04/26 Python