TensorFlow实现简单的CNN的方法


Posted in Python onJuly 18, 2019

这里,我们将采用Tensor Flow内建函数实现简单的CNN,并用MNIST数据集进行测试

第1步:加载相应的库并创建计算图会话

import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
import matplotlib.pyplot as plt
 
#创建计算图会话
sess = tf.Session()

第2步:加载MNIST数据集,这里采用TensorFlow自带数据集,MNIST数据为28×28的图像,因此将其转化为相应二维矩阵

#数据集
data_dir = 'MNIST_data'
mnist = read_data_sets(data_dir)
 
train_xdata = np.array([np.reshape(x,[28,28]) for x in mnist.train.images] )
test_xdata = np.array([np.reshape(x,[28,28]) for x in mnist.test.images] )
 
train_labels = mnist.train.labels
test_labels = mnist.test.labels

第3步:设置模型参数

这里采用随机批量训练的方法,每训练10次对测试集进行测试,共迭代1500次,学习率采用指数下降的方式,初始学习率为0.1,每训练10次,学习率乘0.9,为了进行对比,后面会给出固定学习率为0.01的损失曲线图和准确率图

#设置模型参数
 
batch_size = 100 #批量训练图像张数
initial_learning_rate = 0.1 #学习率
global_step = tf.Variable(0, trainable=False) ;
learning_rate = tf.train.exponential_decay(initial_learning_rate,
                      global_step=global_step,
                      decay_steps=10,decay_rate=0.9)
 
evaluation_size = 500 #测试图像张数
 
image_width = 28 #图像的宽和高
image_height = 28
 
target_size = 10  #图像的目标为0~9共10个目标
num_channels = 1    #灰度图,颜色通道为1
generations = 1500  #迭代500次
evaluation_step = 10 #每训练十次进行一次测试
 
conv1_features = 25  #卷积层的特征个数
conv2_features = 50
 
max_pool_size1 = 2  #池化层大小
max_pool_size2 = 2
 
fully_connected_size = 100 #全连接层的神经元个数

第4步:声明占位符,注意这里的目标y_target类型为int32整型

#声明占位符
 
x_input_shape = [batch_size,image_width,image_height,num_channels]
x_input = tf.placeholder(tf.float32,shape=x_input_shape)
y_target = tf.placeholder(tf.int32,shape=[batch_size])
 
evaluation_input_shape = [evaluation_size,image_width,image_height,num_channels]
evaluation_input = tf.placeholder(tf.float32,shape=evaluation_input_shape)
evaluation_target = tf.placeholder(tf.int32,shape=[evaluation_size])

第5步:声明卷积层和全连接层的权重和偏置,这里采用2层卷积层和1层隐含全连接层

#声明卷积层的权重和偏置
#卷积层1
#采用滤波器为4X4滤波器,输入通道为1,输出通道为25
conv1_weight = tf.Variable(tf.truncated_normal([4,4,num_channels,conv1_features],stddev=0.1,dtype=tf.float32))
conv1_bias = tf.Variable(tf.truncated_normal([conv1_features],stddev=0.1,dtype=tf.float32))
 
#卷积层2
#采用滤波器为4X4滤波器,输入通道为25,输出通道为50
conv2_weight = tf.Variable(tf.truncated_normal([4,4,conv1_features,conv2_features],stddev=0.1,dtype=tf.float32))
conv2_bias = tf.Variable(tf.truncated_normal([conv2_features],stddev=0.1,dtype=tf.float32))
 
#声明全连接层权重和偏置
 
#卷积层过后图像的宽和高
conv_output_width = image_width // (max_pool_size1 * max_pool_size2) #//表示整除
conv_output_height = image_height // (max_pool_size1 * max_pool_size2)
 
#全连接层的输入大小
full1_input_size = conv_output_width * conv_output_height *conv2_features
 
full1_weight = tf.Variable(tf.truncated_normal([full1_input_size,fully_connected_size],stddev=0.1,dtype=tf.float32))
full1_bias = tf.Variable(tf.truncated_normal([fully_connected_size],stddev=0.1,dtype=tf.float32))
 
full2_weight = tf.Variable(tf.truncated_normal([fully_connected_size,target_size],stddev=0.1,dtype=tf.float32))
full2_bias = tf.Variable(tf.truncated_normal([target_size],stddev=0.1,dtype=tf.float32))

第6步:声明CNN模型,这里的两层卷积层均采用Conv-ReLU-MaxPool的结构,步长为[1,1,1,1],padding为SAME

全连接层隐层神经元为100个,输出层为目标个数10

def my_conv_net(input_data):
 
  #第一层:Conv-ReLU-MaxPool
  conv1 = tf.nn.conv2d(input_data,conv1_weight,strides=[1,1,1,1],padding='SAME')
  relu1 = tf.nn.relu(tf.nn.bias_add(conv1,conv1_bias))
  max_pool1 = tf.nn.max_pool(relu1,ksize=[1,max_pool_size1,max_pool_size1,1],strides=[1,max_pool_size1,max_pool_size1,1],padding='SAME')
 
  #第二层:Conv-ReLU-MaxPool
  conv2 = tf.nn.conv2d(max_pool1, conv2_weight, strides=[1, 1, 1, 1], padding='SAME')
  relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_bias))
  max_pool2 = tf.nn.max_pool(relu2, ksize=[1, max_pool_size2, max_pool_size2, 1],
                strides=[1, max_pool_size2, max_pool_size2, 1], padding='SAME')
 
  #全连接层
  #先将数据转化为1*N的形式
  #获取数据大小
  conv_output_shape = max_pool2.get_shape().as_list()
  #全连接层输入数据大小
  fully_input_size = conv_output_shape[1]*conv_output_shape[2]*conv_output_shape[3] #这三个shape就是图像的宽高和通道数
  full1_input_data = tf.reshape(max_pool2,[conv_output_shape[0],fully_input_size])  #转化为batch_size*fully_input_size二维矩阵
  #第一层全连接
  fully_connected1 = tf.nn.relu(tf.add(tf.matmul(full1_input_data,full1_weight),full1_bias))
  #第二层全连接输出
  model_output = tf.nn.relu(tf.add(tf.matmul(fully_connected1,full2_weight),full2_bias))#shape = [batch_size,target_size]
 
  return model_output
 
model_output = my_conv_net(x_input)
test_model_output = my_conv_net(evaluation_input)

第7步:定义损失函数,这里采用softmax函数作为损失函数

#损失函数
 
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=model_output,labels=y_target))

第8步:建立测评与评估函数,这里对输出层进行softmax,再通过np.argmax找出每行最大的数所在位置,再与目标值进行比对,统计准确率

#预测与评估
prediction = tf.nn.softmax(model_output)
test_prediction = tf.nn.softmax(test_model_output)
 
def get_accuracy(logits,targets):
  batch_predictions = np.argmax(logits,axis=1)#返回每行最大的数所在位置
  num_correct = np.sum(np.equal(batch_predictions,targets))
  return 100*num_correct/batch_predictions.shape[0]

第9步:初始化模型变量并创建优化器

#创建优化器
opt = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_step = opt.minimize(loss)
 
#初始化变量
init = tf.initialize_all_variables()
sess.run(init)

第10步:随机批量训练并进行绘图

#开始训练
 
train_loss = []
train_acc = []
test_acc = []
Learning_rate_vec = []
for i in range(generations):
  rand_index = np.random.choice(len(train_xdata),size=batch_size)
  rand_x = train_xdata[rand_index]
  rand_x = np.expand_dims(rand_x,3)
  rand_y = train_labels[rand_index]
  Learning_rate_vec.append(sess.run(learning_rate, feed_dict={global_step: i}))
  train_dict = {x_input:rand_x,y_target:rand_y}
 
  sess.run(train_step,feed_dict={x_input:rand_x,y_target:rand_y,global_step:i})
  temp_train_loss = sess.run(loss,feed_dict=train_dict)
  temp_train_prediction = sess.run(prediction,feed_dict=train_dict)
  temp_train_acc = get_accuracy(temp_train_prediction,rand_y)
 
  #测试集
  if (i+1)%evaluation_step ==0:
    eval_index = np.random.choice(len(test_xdata),size=evaluation_size)
    eval_x = test_xdata[eval_index]
    eval_x = np.expand_dims(eval_x,3)
    eval_y = test_labels[eval_index]
 
 
    test_dict = {evaluation_input:eval_x,evaluation_target:eval_y}
    temp_test_preds = sess.run(test_prediction,feed_dict=test_dict)
    temp_test_acc = get_accuracy(temp_test_preds,eval_y)
 
    test_acc.append(temp_test_acc)
  train_acc.append(temp_train_acc)
  train_loss.append(temp_train_loss)
 
 
 
 
#画损失曲线
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(train_loss,'k-')
ax.set_xlabel('Generation')
ax.set_ylabel('Softmax Loss')
fig.suptitle('Softmax Loss per Generation')
 
#画准确度曲线
index = np.arange(start=1,stop=generations+1,step=evaluation_step)
fig2 = plt.figure()
ax2 = fig2.add_subplot(111)
ax2.plot(train_acc,'k-',label='Train Set Accuracy')
ax2.plot(index,test_acc,'r--',label='Test Set Accuracy')
ax2.set_xlabel('Generation')
ax2.set_ylabel('Accuracy')
fig2.suptitle('Train and Test Set Accuracy')
 
 
#画图
fig3 = plt.figure()
actuals = rand_y[0:6]
train_predictions = np.argmax(temp_train_prediction,axis=1)[0:6]
images = np.squeeze(rand_x[0:6])
Nrows = 2
Ncols =3
 
for i in range(6):
  ax3 = fig3.add_subplot(Nrows,Ncols,i+1)
  ax3.imshow(np.reshape(images[i],[28,28]),cmap='Greys_r')
  ax3.set_title('Actual: '+str(actuals[i]) +' pred: '+str(train_predictions[i]))
 
 
#画学习率
fig4 = plt.figure()
ax4 = fig4.add_subplot(111)
ax4.plot(Learning_rate_vec,'k-')
ax4.set_xlabel('step')
ax4.set_ylabel('Learning_rate')
fig4.suptitle('Learning_rate')
 
 
 
plt.show()

下面给出固定学习率图像和学习率随迭代次数下降的图像:

首先给出固定学习率图像:

下面是损失曲线

TensorFlow实现简单的CNN的方法

下面是准确率

TensorFlow实现简单的CNN的方法

我们可以看出,固定学习率损失函数下降速度较缓,同时其最终准确率为80%~90%之间就不再提高了

下面给出学习率随迭代次数降低的曲线:

首先给出学习率随迭代次数降低的损失曲线

TensorFlow实现简单的CNN的方法

然后给出相应的准确率曲线

TensorFlow实现简单的CNN的方法

我们可以看出其损失函数下降很快,同时准确率也可以达到90%以上

下面给出随机抓取的图像相应的识别情况:

TensorFlow实现简单的CNN的方法

至此我们实现了简单的CNN来实现MNIST手写图数据集的识别,如果想进一步提高其准确率,可以通过改变CNN网络参数,如通道数、全连接层神经元个数,过滤器大小,学习率,训练次数,加入dropout层等等,也可以通过增加CNN网络深度来进一步提高其准确率

下面给出一组参数:

初始学习率:initial_learning_rate=0.05

迭代步长:decay_steps=50,每50步改变一次学习率

下面是仿真结果:

TensorFlow实现简单的CNN的方法

TensorFlow实现简单的CNN的方法

TensorFlow实现简单的CNN的方法

TensorFlow实现简单的CNN的方法

我们可以看出,通过调整超参数,其既保证了损失函数能够快速下降,又进一步提高了其模型准确率,我们在训练次数为1500次的基础上,准确率已经达到97%以上。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中的文本处理
Apr 11 Python
Python黑魔法@property装饰器的使用技巧解析
Jun 16 Python
Python实现复杂对象转JSON的方法示例
Jun 22 Python
Python实现字符串逆序输出功能示例
Jun 24 Python
基于python3 类的属性、方法、封装、继承实例讲解
Sep 19 Python
python 3.7.0 下pillow安装方法
Aug 27 Python
关于python中plt.hist参数的使用详解
Nov 28 Python
python基于TCP实现的文件下载器功能案例
Dec 10 Python
python 读写文件包含多种编码格式的解决方式
Dec 20 Python
Python实现病毒仿真器的方法示例(附demo)
Feb 19 Python
如何用python 操作zookeeper
Dec 28 Python
浅析Python模块之间的相互引用问题
Feb 26 Python
windows上安装python3教程以及环境变量配置详解
Jul 18 #Python
Django 开发环境配置过程详解
Jul 18 #Python
解决Django中多条件查询的问题
Jul 18 #Python
python openpyxl使用方法详解
Jul 18 #Python
Python Django基础二之URL路由系统
Jul 18 #Python
使用django的objects.filter()方法匹配多个关键字的方法
Jul 18 #Python
Django基础三之视图函数的使用方法
Jul 18 #Python
You might like
php 遍历数据表数据并列表横向排列的代码
2009/09/05 PHP
PHP字符过滤函数去除字符串最后一个逗号(rtrim)
2013/03/26 PHP
php中json_encode处理gbk与gb2312中文乱码问题的解决方法
2014/07/10 PHP
php实现的支持imagemagick及gd库两种处理的缩略图生成类
2014/09/23 PHP
php强大的时间转换函数strtotime
2016/02/18 PHP
[原创]php常用字符串输出方法分析(echo,print,printf及sprintf)
2016/07/09 PHP
PHP curl 或 file_get_contents 获取需要授权页面的方法
2017/05/05 PHP
PHP7扩展开发教程之Hello World实现方法示例
2017/08/03 PHP
extjs 学习笔记 四 带分页的grid
2009/10/20 Javascript
IE6不能修改NAME问题的解决方法
2010/09/03 Javascript
利用javascript的面向对象的特性实现限制试用期
2011/08/04 Javascript
js修改input的type属性及浏览器兼容问题探讨与解决
2013/01/23 Javascript
jQuery如何取id有.的值一般的方法是取不到的
2014/04/18 Javascript
js根据手机客户端浏览器类型,判断跳转官网/手机网站多个实例代码
2016/04/30 Javascript
JS组件系列之使用HTML标签的data属性初始化JS组件
2016/09/14 Javascript
webpack配置之后端渲染详解
2017/10/26 Javascript
JavaScript设计模式之建造者模式实例教程
2018/07/02 Javascript
vue防止花括号{{}}闪烁v-text和v-html、v-cloak用法示例
2019/03/13 Javascript
微信小程序自定义tabbar custom-tab-bar 6s出不来解决方案(cover-view不兼容)
2019/11/01 Javascript
基于python进行桶排序与基数排序的总结
2018/05/29 Python
如何在python字符串中输入纯粹的{}
2018/08/22 Python
python如何实现视频转代码视频
2019/06/17 Python
python交易记录整合交易类详解
2019/07/03 Python
利用pandas向一个csv文件追加写入数据的实现示例
2020/04/23 Python
详解CSS3中@media的实际使用
2015/08/04 HTML / CSS
Zavvi荷兰:英国大型音像制品和图书游戏零售商
2018/03/22 全球购物
Strawberrynet草莓网新加坡站:护肤、彩妆、香水及美发产品
2018/08/31 全球购物
SCHIESSER荷兰官方网站:德国内衣专家
2020/10/09 全球购物
英语专业推荐信
2013/11/16 职场文书
中专毕业生的自荐书
2014/07/01 职场文书
工会趣味活动方案
2014/08/18 职场文书
师德师风个人整改措施
2014/10/27 职场文书
2014年初级职称工作总结
2014/12/08 职场文书
护士节慰问信
2015/02/15 职场文书
商务英语求职信范文
2015/03/19 职场文书
详解CSS中的特指度和层叠问题
2021/07/15 HTML / CSS