tensorflow实现残差网络方式(mnist数据集)


Posted in Python onMay 26, 2020

介绍

残差网络是何凯明大神的神作,效果非常好,深度可以达到1000层。但是,其实现起来并没有那末难,在这里以tensorflow作为框架,实现基于mnist数据集上的残差网络,当然只是比较浅层的。

如下图所示:

tensorflow实现残差网络方式(mnist数据集)

实线的Connection部分,表示通道相同,如上图的第一个粉色矩形和第三个粉色矩形,都是3x3x64的特征图,由于通道相同,所以采用计算方式为H(x)=F(x)+x

虚线的的Connection部分,表示通道不同,如上图的第一个绿色矩形和第三个绿色矩形,分别是3x3x64和3x3x128的特征图,通道不同,采用的计算方式为H(x)=F(x)+Wx,其中W是卷积操作,用来调整x维度的。

根据输入和输出尺寸是否相同,又分为identity_block和conv_block,每种block有上图两种模式,三卷积和二卷积,三卷积速度更快些,因此在这里选择该种方式。

具体实现见如下代码:

#tensorflow基于mnist数据集上的VGG11网络,可以直接运行
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
#tensorflow基于mnist实现VGG11
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

#x=mnist.train.images
#y=mnist.train.labels
#X=mnist.test.images
#Y=mnist.test.labels
x = tf.placeholder(tf.float32, [None,784])
y = tf.placeholder(tf.float32, [None, 10])
sess = tf.InteractiveSession()

def weight_variable(shape):
#这里是构建初始变量
 initial = tf.truncated_normal(shape, mean=0,stddev=0.1)
#创建变量
 return tf.Variable(initial)

def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)

#在这里定义残差网络的id_block块,此时输入和输出维度相同
def identity_block(X_input, kernel_size, in_filter, out_filters, stage, block):
 """
 Implementation of the identity block as defined in Figure 3

 Arguments:
 X -- input tensor of shape (m, n_H_prev, n_W_prev, n_C_prev)
 kernel_size -- integer, specifying the shape of the middle CONV's window for the main path
 filters -- python list of integers, defining the number of filters in the CONV layers of the main path
 stage -- integer, used to name the layers, depending on their position in the network
 block -- string/character, used to name the layers, depending on their position in the network
 training -- train or test

 Returns:
 X -- output of the identity block, tensor of shape (n_H, n_W, n_C)
 """

 # defining name basis
 block_name = 'res' + str(stage) + block
 f1, f2, f3 = out_filters
 with tf.variable_scope(block_name):
  X_shortcut = X_input

  #first
  W_conv1 = weight_variable([1, 1, in_filter, f1])
  X = tf.nn.conv2d(X_input, W_conv1, strides=[1, 1, 1, 1], padding='SAME')
  b_conv1 = bias_variable([f1])
  X = tf.nn.relu(X+ b_conv1)

  #second
  W_conv2 = weight_variable([kernel_size, kernel_size, f1, f2])
  X = tf.nn.conv2d(X, W_conv2, strides=[1, 1, 1, 1], padding='SAME')
  b_conv2 = bias_variable([f2])
  X = tf.nn.relu(X+ b_conv2)

  #third

  W_conv3 = weight_variable([1, 1, f2, f3])
  X = tf.nn.conv2d(X, W_conv3, strides=[1, 1, 1, 1], padding='SAME')
  b_conv3 = bias_variable([f3])
  X = tf.nn.relu(X+ b_conv3)
  #final step
  add = tf.add(X, X_shortcut)
  b_conv_fin = bias_variable([f3])
  add_result = tf.nn.relu(add+b_conv_fin)

 return add_result


#这里定义conv_block模块,由于该模块定义时输入和输出尺度不同,故需要进行卷积操作来改变尺度,从而得以相加
def convolutional_block( X_input, kernel_size, in_filter,
    out_filters, stage, block, stride=2):
 """
 Implementation of the convolutional block as defined in Figure 4

 Arguments:
 X -- input tensor of shape (m, n_H_prev, n_W_prev, n_C_prev)
 kernel_size -- integer, specifying the shape of the middle CONV's window for the main path
 filters -- python list of integers, defining the number of filters in the CONV layers of the main path
 stage -- integer, used to name the layers, depending on their position in the network
 block -- string/character, used to name the layers, depending on their position in the network
 training -- train or test
 stride -- Integer, specifying the stride to be used

 Returns:
 X -- output of the convolutional block, tensor of shape (n_H, n_W, n_C)
 """

 # defining name basis
 block_name = 'res' + str(stage) + block
 with tf.variable_scope(block_name):
  f1, f2, f3 = out_filters

  x_shortcut = X_input
  #first
  W_conv1 = weight_variable([1, 1, in_filter, f1])
  X = tf.nn.conv2d(X_input, W_conv1,strides=[1, stride, stride, 1],padding='SAME')
  b_conv1 = bias_variable([f1])
  X = tf.nn.relu(X + b_conv1)

  #second
  W_conv2 =weight_variable([kernel_size, kernel_size, f1, f2])
  X = tf.nn.conv2d(X, W_conv2, strides=[1,1,1,1], padding='SAME')
  b_conv2 = bias_variable([f2])
  X = tf.nn.relu(X+b_conv2)

  #third
  W_conv3 = weight_variable([1,1, f2,f3])
  X = tf.nn.conv2d(X, W_conv3, strides=[1, 1, 1,1], padding='SAME')
  b_conv3 = bias_variable([f3])
  X = tf.nn.relu(X+b_conv3)
  #shortcut path
  W_shortcut =weight_variable([1, 1, in_filter, f3])
  x_shortcut = tf.nn.conv2d(x_shortcut, W_shortcut, strides=[1, stride, stride, 1], padding='VALID')

  #final
  add = tf.add(x_shortcut, X)
  #建立最后融合的权重
  b_conv_fin = bias_variable([f3])
  add_result = tf.nn.relu(add+ b_conv_fin)


 return add_result



x = tf.reshape(x, [-1,28,28,1])
w_conv1 = weight_variable([2, 2, 1, 64])
x = tf.nn.conv2d(x, w_conv1, strides=[1, 2, 2, 1], padding='SAME')
b_conv1 = bias_variable([64])
x = tf.nn.relu(x+b_conv1)
#这里操作后变成14x14x64
x = tf.nn.max_pool(x, ksize=[1, 3, 3, 1],
    strides=[1, 1, 1, 1], padding='SAME')


#stage 2
x = convolutional_block(X_input=x, kernel_size=3, in_filter=64, out_filters=[64, 64, 256], stage=2, block='a', stride=1)
#上述conv_block操作后,尺寸变为14x14x256
x = identity_block(x, 3, 256, [64, 64, 256], stage=2, block='b' )
x = identity_block(x, 3, 256, [64, 64, 256], stage=2, block='c')
#上述操作后张量尺寸变成14x14x256
x = tf.nn.max_pool(x, [1, 2, 2, 1], strides=[1,2,2,1], padding='SAME')
#变成7x7x256
flat = tf.reshape(x, [-1,7*7*256])

w_fc1 = weight_variable([7 * 7 *256, 1024])
b_fc1 = bias_variable([1024])

h_fc1 = tf.nn.relu(tf.matmul(flat, w_fc1) + b_fc1)
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
w_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2


#建立损失函数,在这里采用交叉熵函数
cross_entropy = tf.reduce_mean(
 tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_conv))

train_step = tf.train.AdamOptimizer(1e-3).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#初始化变量

sess.run(tf.global_variables_initializer())

print("cuiwei")
for i in range(2000):
 batch = mnist.train.next_batch(10)
 if i%100 == 0:
 train_accuracy = accuracy.eval(feed_dict={
 x:batch[0], y: batch[1], keep_prob: 1.0})
 print("step %d, training accuracy %g"%(i, train_accuracy))
 train_step.run(feed_dict={x: batch[0], y: batch[1], keep_prob: 0.5})

以上这篇tensorflow实现残差网络方式(mnist数据集)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python操作CouchDB数据库简单示例
Mar 10 Python
python安装mysql-python简明笔记(ubuntu环境)
Jun 25 Python
Python实现按特定格式对文件进行读写的方法示例
Nov 30 Python
PyQt5每天必学之切换按钮
Aug 20 Python
Python实现的求解最小公倍数算法示例
May 03 Python
对python 操作solr索引数据的实例详解
Dec 07 Python
opencv与numpy的图像基本操作
Mar 08 Python
Python闭包和装饰器用法实例详解
May 22 Python
Python数据可视化处理库PyEcharts柱状图,饼图,线性图,词云图常用实例详解
Feb 10 Python
python删除文件、清空目录的实现方法
Sep 23 Python
python 动态绘制爱心的示例
Sep 27 Python
python接口自动化框架实战
Dec 23 Python
Python中格式化字符串的四种实现
May 26 #Python
使用tensorflow实现VGG网络,训练mnist数据集方式
May 26 #Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 #Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
You might like
《魔兽争霸3:重制版》翻车了?你想要的我们都没有
2019/11/07 魔兽争霸
PHP 输出简单动态WAP页面
2009/06/09 PHP
有关于PHP中常见数据类型的汇总分享
2014/01/06 PHP
codeigniter数据库操作函数汇总
2014/06/12 PHP
PHP中构造函数和析构函数解析
2014/10/10 PHP
thinkphp实现163、QQ邮箱收发邮件的方法
2015/12/18 PHP
PHP快速排序quicksort实例详解
2016/09/28 PHP
PHP表单验证内容是否为空的实现代码
2016/11/14 PHP
PHP开发之归档格式phar文件概念与用法详解【创建,使用,解包还原提取】
2017/11/17 PHP
PHP 8新特性简介
2020/08/18 PHP
PHP后门隐藏的一些技巧总结
2020/11/04 PHP
Javascript 面向对象 对象(Object)
2010/05/13 Javascript
JS控制输入框内字符串长度
2014/05/21 Javascript
Node.js中使用mongoskin操作mongoDB实例
2014/09/28 Javascript
JavaScript中的值类型转换介绍
2014/12/31 Javascript
js中取得变量绝对值的方法
2015/01/03 Javascript
jQuery中DOM树操作之使用反向插入方法实例分析
2015/01/23 Javascript
js实现iPhone界面风格的单选框和复选框按钮实例
2015/08/18 Javascript
JavaScript小技巧整理篇(非常全)
2016/01/26 Javascript
分享一个精简的vue.js 图片lazyload插件实例
2017/03/13 Javascript
微信小程序 检查接口状态实例详解
2017/06/23 Javascript
Vue实现本地购物车功能
2018/12/05 Javascript
layer.open提交子页面的form和layedit文本编辑内容的方法
2019/09/27 Javascript
vscode 调试 node.js的方法步骤
2020/09/15 Javascript
[56:57]LGD vs VP 2019DOTA2国际邀请赛淘汰赛 胜者组赛BO3 第一场 8.20.mp4
2019/08/22 DOTA
Python读取sqlite数据库文件的方法分析
2017/08/07 Python
CSS3 旋转立方体问题详解
2020/01/09 HTML / CSS
使用phonegap创建联系人的实现方法
2017/03/30 HTML / CSS
AmazeUI 按钮交互的实现示例
2020/08/24 HTML / CSS
美体小铺印度官网:The Body Shop印度
2019/10/17 全球购物
幼儿园春季开学寄语
2014/04/03 职场文书
金融系应届毕业生求职信
2014/05/26 职场文书
竞聘演讲稿怎么写
2014/08/28 职场文书
作文批改评语
2014/12/25 职场文书
公司文体活动总结
2015/05/07 职场文书
2016年校长新年寄语
2015/08/17 职场文书