TensorFlow实现简单的CNN的方法


Posted in Python onJuly 18, 2019

这里,我们将采用Tensor Flow内建函数实现简单的CNN,并用MNIST数据集进行测试

第1步:加载相应的库并创建计算图会话

import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
import matplotlib.pyplot as plt
 
#创建计算图会话
sess = tf.Session()

第2步:加载MNIST数据集,这里采用TensorFlow自带数据集,MNIST数据为28×28的图像,因此将其转化为相应二维矩阵

#数据集
data_dir = 'MNIST_data'
mnist = read_data_sets(data_dir)
 
train_xdata = np.array([np.reshape(x,[28,28]) for x in mnist.train.images] )
test_xdata = np.array([np.reshape(x,[28,28]) for x in mnist.test.images] )
 
train_labels = mnist.train.labels
test_labels = mnist.test.labels

第3步:设置模型参数

这里采用随机批量训练的方法,每训练10次对测试集进行测试,共迭代1500次,学习率采用指数下降的方式,初始学习率为0.1,每训练10次,学习率乘0.9,为了进行对比,后面会给出固定学习率为0.01的损失曲线图和准确率图

#设置模型参数
 
batch_size = 100 #批量训练图像张数
initial_learning_rate = 0.1 #学习率
global_step = tf.Variable(0, trainable=False) ;
learning_rate = tf.train.exponential_decay(initial_learning_rate,
                      global_step=global_step,
                      decay_steps=10,decay_rate=0.9)
 
evaluation_size = 500 #测试图像张数
 
image_width = 28 #图像的宽和高
image_height = 28
 
target_size = 10  #图像的目标为0~9共10个目标
num_channels = 1    #灰度图,颜色通道为1
generations = 1500  #迭代500次
evaluation_step = 10 #每训练十次进行一次测试
 
conv1_features = 25  #卷积层的特征个数
conv2_features = 50
 
max_pool_size1 = 2  #池化层大小
max_pool_size2 = 2
 
fully_connected_size = 100 #全连接层的神经元个数

第4步:声明占位符,注意这里的目标y_target类型为int32整型

#声明占位符
 
x_input_shape = [batch_size,image_width,image_height,num_channels]
x_input = tf.placeholder(tf.float32,shape=x_input_shape)
y_target = tf.placeholder(tf.int32,shape=[batch_size])
 
evaluation_input_shape = [evaluation_size,image_width,image_height,num_channels]
evaluation_input = tf.placeholder(tf.float32,shape=evaluation_input_shape)
evaluation_target = tf.placeholder(tf.int32,shape=[evaluation_size])

第5步:声明卷积层和全连接层的权重和偏置,这里采用2层卷积层和1层隐含全连接层

#声明卷积层的权重和偏置
#卷积层1
#采用滤波器为4X4滤波器,输入通道为1,输出通道为25
conv1_weight = tf.Variable(tf.truncated_normal([4,4,num_channels,conv1_features],stddev=0.1,dtype=tf.float32))
conv1_bias = tf.Variable(tf.truncated_normal([conv1_features],stddev=0.1,dtype=tf.float32))
 
#卷积层2
#采用滤波器为4X4滤波器,输入通道为25,输出通道为50
conv2_weight = tf.Variable(tf.truncated_normal([4,4,conv1_features,conv2_features],stddev=0.1,dtype=tf.float32))
conv2_bias = tf.Variable(tf.truncated_normal([conv2_features],stddev=0.1,dtype=tf.float32))
 
#声明全连接层权重和偏置
 
#卷积层过后图像的宽和高
conv_output_width = image_width // (max_pool_size1 * max_pool_size2) #//表示整除
conv_output_height = image_height // (max_pool_size1 * max_pool_size2)
 
#全连接层的输入大小
full1_input_size = conv_output_width * conv_output_height *conv2_features
 
full1_weight = tf.Variable(tf.truncated_normal([full1_input_size,fully_connected_size],stddev=0.1,dtype=tf.float32))
full1_bias = tf.Variable(tf.truncated_normal([fully_connected_size],stddev=0.1,dtype=tf.float32))
 
full2_weight = tf.Variable(tf.truncated_normal([fully_connected_size,target_size],stddev=0.1,dtype=tf.float32))
full2_bias = tf.Variable(tf.truncated_normal([target_size],stddev=0.1,dtype=tf.float32))

第6步:声明CNN模型,这里的两层卷积层均采用Conv-ReLU-MaxPool的结构,步长为[1,1,1,1],padding为SAME

全连接层隐层神经元为100个,输出层为目标个数10

def my_conv_net(input_data):
 
  #第一层:Conv-ReLU-MaxPool
  conv1 = tf.nn.conv2d(input_data,conv1_weight,strides=[1,1,1,1],padding='SAME')
  relu1 = tf.nn.relu(tf.nn.bias_add(conv1,conv1_bias))
  max_pool1 = tf.nn.max_pool(relu1,ksize=[1,max_pool_size1,max_pool_size1,1],strides=[1,max_pool_size1,max_pool_size1,1],padding='SAME')
 
  #第二层:Conv-ReLU-MaxPool
  conv2 = tf.nn.conv2d(max_pool1, conv2_weight, strides=[1, 1, 1, 1], padding='SAME')
  relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_bias))
  max_pool2 = tf.nn.max_pool(relu2, ksize=[1, max_pool_size2, max_pool_size2, 1],
                strides=[1, max_pool_size2, max_pool_size2, 1], padding='SAME')
 
  #全连接层
  #先将数据转化为1*N的形式
  #获取数据大小
  conv_output_shape = max_pool2.get_shape().as_list()
  #全连接层输入数据大小
  fully_input_size = conv_output_shape[1]*conv_output_shape[2]*conv_output_shape[3] #这三个shape就是图像的宽高和通道数
  full1_input_data = tf.reshape(max_pool2,[conv_output_shape[0],fully_input_size])  #转化为batch_size*fully_input_size二维矩阵
  #第一层全连接
  fully_connected1 = tf.nn.relu(tf.add(tf.matmul(full1_input_data,full1_weight),full1_bias))
  #第二层全连接输出
  model_output = tf.nn.relu(tf.add(tf.matmul(fully_connected1,full2_weight),full2_bias))#shape = [batch_size,target_size]
 
  return model_output
 
model_output = my_conv_net(x_input)
test_model_output = my_conv_net(evaluation_input)

第7步:定义损失函数,这里采用softmax函数作为损失函数

#损失函数
 
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=model_output,labels=y_target))

第8步:建立测评与评估函数,这里对输出层进行softmax,再通过np.argmax找出每行最大的数所在位置,再与目标值进行比对,统计准确率

#预测与评估
prediction = tf.nn.softmax(model_output)
test_prediction = tf.nn.softmax(test_model_output)
 
def get_accuracy(logits,targets):
  batch_predictions = np.argmax(logits,axis=1)#返回每行最大的数所在位置
  num_correct = np.sum(np.equal(batch_predictions,targets))
  return 100*num_correct/batch_predictions.shape[0]

第9步:初始化模型变量并创建优化器

#创建优化器
opt = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_step = opt.minimize(loss)
 
#初始化变量
init = tf.initialize_all_variables()
sess.run(init)

第10步:随机批量训练并进行绘图

#开始训练
 
train_loss = []
train_acc = []
test_acc = []
Learning_rate_vec = []
for i in range(generations):
  rand_index = np.random.choice(len(train_xdata),size=batch_size)
  rand_x = train_xdata[rand_index]
  rand_x = np.expand_dims(rand_x,3)
  rand_y = train_labels[rand_index]
  Learning_rate_vec.append(sess.run(learning_rate, feed_dict={global_step: i}))
  train_dict = {x_input:rand_x,y_target:rand_y}
 
  sess.run(train_step,feed_dict={x_input:rand_x,y_target:rand_y,global_step:i})
  temp_train_loss = sess.run(loss,feed_dict=train_dict)
  temp_train_prediction = sess.run(prediction,feed_dict=train_dict)
  temp_train_acc = get_accuracy(temp_train_prediction,rand_y)
 
  #测试集
  if (i+1)%evaluation_step ==0:
    eval_index = np.random.choice(len(test_xdata),size=evaluation_size)
    eval_x = test_xdata[eval_index]
    eval_x = np.expand_dims(eval_x,3)
    eval_y = test_labels[eval_index]
 
 
    test_dict = {evaluation_input:eval_x,evaluation_target:eval_y}
    temp_test_preds = sess.run(test_prediction,feed_dict=test_dict)
    temp_test_acc = get_accuracy(temp_test_preds,eval_y)
 
    test_acc.append(temp_test_acc)
  train_acc.append(temp_train_acc)
  train_loss.append(temp_train_loss)
 
 
 
 
#画损失曲线
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(train_loss,'k-')
ax.set_xlabel('Generation')
ax.set_ylabel('Softmax Loss')
fig.suptitle('Softmax Loss per Generation')
 
#画准确度曲线
index = np.arange(start=1,stop=generations+1,step=evaluation_step)
fig2 = plt.figure()
ax2 = fig2.add_subplot(111)
ax2.plot(train_acc,'k-',label='Train Set Accuracy')
ax2.plot(index,test_acc,'r--',label='Test Set Accuracy')
ax2.set_xlabel('Generation')
ax2.set_ylabel('Accuracy')
fig2.suptitle('Train and Test Set Accuracy')
 
 
#画图
fig3 = plt.figure()
actuals = rand_y[0:6]
train_predictions = np.argmax(temp_train_prediction,axis=1)[0:6]
images = np.squeeze(rand_x[0:6])
Nrows = 2
Ncols =3
 
for i in range(6):
  ax3 = fig3.add_subplot(Nrows,Ncols,i+1)
  ax3.imshow(np.reshape(images[i],[28,28]),cmap='Greys_r')
  ax3.set_title('Actual: '+str(actuals[i]) +' pred: '+str(train_predictions[i]))
 
 
#画学习率
fig4 = plt.figure()
ax4 = fig4.add_subplot(111)
ax4.plot(Learning_rate_vec,'k-')
ax4.set_xlabel('step')
ax4.set_ylabel('Learning_rate')
fig4.suptitle('Learning_rate')
 
 
 
plt.show()

下面给出固定学习率图像和学习率随迭代次数下降的图像:

首先给出固定学习率图像:

下面是损失曲线

TensorFlow实现简单的CNN的方法

下面是准确率

TensorFlow实现简单的CNN的方法

我们可以看出,固定学习率损失函数下降速度较缓,同时其最终准确率为80%~90%之间就不再提高了

下面给出学习率随迭代次数降低的曲线:

首先给出学习率随迭代次数降低的损失曲线

TensorFlow实现简单的CNN的方法

然后给出相应的准确率曲线

TensorFlow实现简单的CNN的方法

我们可以看出其损失函数下降很快,同时准确率也可以达到90%以上

下面给出随机抓取的图像相应的识别情况:

TensorFlow实现简单的CNN的方法

至此我们实现了简单的CNN来实现MNIST手写图数据集的识别,如果想进一步提高其准确率,可以通过改变CNN网络参数,如通道数、全连接层神经元个数,过滤器大小,学习率,训练次数,加入dropout层等等,也可以通过增加CNN网络深度来进一步提高其准确率

下面给出一组参数:

初始学习率:initial_learning_rate=0.05

迭代步长:decay_steps=50,每50步改变一次学习率

下面是仿真结果:

TensorFlow实现简单的CNN的方法

TensorFlow实现简单的CNN的方法

TensorFlow实现简单的CNN的方法

TensorFlow实现简单的CNN的方法

我们可以看出,通过调整超参数,其既保证了损失函数能够快速下降,又进一步提高了其模型准确率,我们在训练次数为1500次的基础上,准确率已经达到97%以上。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python获取网页上图片下载地址的方法
Mar 11 Python
详解Python中的文件操作
Aug 28 Python
Python 编码处理-str与Unicode的区别
Sep 06 Python
利用python打开摄像头及颜色检测方法
Aug 03 Python
django 信号调度机制详解
Jul 19 Python
Python实现Singleton模式的方式详解
Aug 08 Python
Python全面分析系统的时域特性和频率域特性
Feb 26 Python
pyMySQL SQL语句传参问题,单个参数或多个参数说明
Jun 06 Python
查看keras的默认backend实现方式
Jun 19 Python
Python3压缩和解压缩实现代码
Mar 01 Python
Python如何利用正则表达式爬取网页信息及图片
Apr 17 Python
整理Python中常用的conda命令操作
Jun 15 Python
windows上安装python3教程以及环境变量配置详解
Jul 18 #Python
Django 开发环境配置过程详解
Jul 18 #Python
解决Django中多条件查询的问题
Jul 18 #Python
python openpyxl使用方法详解
Jul 18 #Python
Python Django基础二之URL路由系统
Jul 18 #Python
使用django的objects.filter()方法匹配多个关键字的方法
Jul 18 #Python
Django基础三之视图函数的使用方法
Jul 18 #Python
You might like
php数组函数序列之array_pop() - 删除数组中的最后一个元素
2011/11/07 PHP
php引用传值实例详解学习
2013/11/06 PHP
PHP模板引擎Smarty的缓存使用总结
2014/04/24 PHP
PHP实现的线索二叉树及二叉树遍历方法详解
2016/04/25 PHP
基于laravel缓冲cache的用法详解
2019/10/23 PHP
PHP 自动加载类原理与用法实例分析
2020/04/14 PHP
jQuery 扩展对input的一些操作方法
2009/10/30 Javascript
JavaScript设计模式之策略模式实例
2014/10/10 Javascript
jQuery fancybox在ie浏览器下无法显示关闭按钮的解决办法
2016/02/19 Javascript
AngularJS中的过滤器filter用法完全解析
2016/04/22 Javascript
JQuery validate插件验证用户注册信息
2016/05/11 Javascript
JavaScript鼠标特效大全
2016/09/13 Javascript
快速理解 JavaScript 中的 LHS 和 RHS 查询的用法
2017/08/24 Javascript
解决Layui 表单提交数据为空的问题
2018/08/15 Javascript
js事件触发操作实例分析
2019/06/21 Javascript
JavaScript自定义超时API代码实例
2020/04/30 Javascript
jQuery 添加元素和删除元素的方法
2020/07/15 jQuery
用vue设计一个日历表
2020/12/03 Vue.js
Python的迭代器和生成器
2015/07/29 Python
python实现汉诺塔方法汇总
2016/07/25 Python
利用numpy实现一、二维数组的拼接简单代码示例
2017/12/15 Python
Python3实现带附件的定时发送邮件功能
2020/12/22 Python
python 控制Asterisk AMI接口外呼电话的例子
2019/08/08 Python
利用 PyCharm 实现本地代码和远端的实时同步功能
2020/03/23 Python
Mixbook加拿大:照片书,照片卡,剪贴簿,年历和日历
2017/02/21 全球购物
国培远程培训感言
2014/03/08 职场文书
班训口号大全
2014/06/18 职场文书
模具设计与制造专业自荐书
2014/07/01 职场文书
校运动会广播稿(100篇)
2014/09/12 职场文书
县委常委班子专题民主生活会查摆问题及整改措施
2014/09/27 职场文书
2014年企业团支部工作总结
2014/12/10 职场文书
大学生学习十八届五中全会精神心得体会
2016/01/05 职场文书
Mysql基础知识点汇总
2021/05/26 MySQL
JS中如何优雅的使用async await详解
2021/10/05 Javascript
JAVA SpringMVC实现自定义拦截器
2022/03/16 Python
java开发双人五子棋游戏
2022/05/06 Java/Android