Python利用全连接神经网络求解MNIST问题详解


Posted in Python onJanuary 14, 2020

本文实例讲述了Python利用全连接神经网络求解MNIST问题。分享给大家供大家参考,具体如下:

1、单隐藏层神经网络

人类的神经元在树突接受刺激信息后,经过细胞体处理,判断如果达到阈值,则将信息传递给下一个神经元或输出。类似地,神经元模型在输入层输入特征值x之后,与权重w相乘求和再加上b,经过激活函数判断后传递给下一层隐藏层或输出层。

单神经元的模型只有一个求和节点(如左下图所示)。全连接神经网络(Full Connected Networks)如右下图所示,中间层有多个神经元,并且每层的每个神经元都是与上一层和下一层的节点都对应连接。中间隐藏层只有一层的神经元网络称为单隐藏层神经网络。如果有多个中间隐藏层则称为多隐藏层神经网络。

Python利用全连接神经网络求解MNIST问题详解           Python利用全连接神经网络求解MNIST问题详解

常见的激活函数如下所示:

Python利用全连接神经网络求解MNIST问题详解

下面是在单个神经元逻辑回归求解MNIST手写数字识别问题的基础上,采用单隐藏层神经网络进行求解的过程。

首先载入数据,从Tensor FLow提供的数据库中导入MNIST数据

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist=input_data.read_data_sets('MNIST_data/',one_hot=True)

构建输入层,其中x是图像的特征值,由于是28×28=784个像素点,所有输入为未知行数、每行784的二维数组。y是图像的标签值,共有0~9十种可能,所有为[None,10]的二维数组

x=tf.placeholder(tf.float32,[None,784],name='x')
y=tf.placeholder(tf.float32,[None,10],name='y')

构建隐藏层,设置隐藏层神经元个数为256,由于输入层输入为784,而隐藏层神经元为h1_num,所以W1为[784,h1_num]形式的二维数组,b为[h1_num]的一维向量。此外采用ReLU作为激活函数处理输出。

h1_num=256                        #设置隐藏层神经元数量
W1=tf.Variable(tf.random_normal([784,h1_num]),name='W1')
b1=tf.Variable(tf.zeros([h1_num]),name='b1')
Y1=tf.nn.relu(tf.matmul(x,W1)+b1)             #激活函数

构建输出层,由于隐藏层有h1_num个神经元输出,输出层输出10种输出结果,所以W2为[h1_num,10]的二维数组,b2为[10]的一维向量。最后结果通过softmax将线性输出Y2转化为独热编码方式。

W2=tf.Variable(tf.random_normal([h1_num,10]),name='W2')
b2=tf.Variable(tf.zeros([10]),name='b2')
Y2=tf.matmul(Y1,W2)+b2
pred=tf.nn.softmax(Y2)

设置训练的超参数、损失函数、优化器,这里采用Adam Optimizer进行优化。准确率是通过比较预测值和标签值是否一致来定义。在定义损失函数时,如果直接使用交叉熵的方式定义,会出现log0值为NaN的情况,导致数据不稳定,无法得出结果。Tensor Flow提供了结合softmax定义交叉熵的方式softmax_cross_entropy_with_logits(),第一个参数为不经softmax处理的前向计算结果Y2,第二个参数为标签值y

train_epochs=20                    #训练轮数
batch_size=50                     #每个批次的样本数
batch_num=int(mnist.train.num_examples/batch_size)  #一轮需要训练多少批
learning_rate=0.01
#定义损失函数、优化器
loss_function=tf.reduce_mean(             #softmax交叉熵损失函数
       tf.nn.softmax_cross_entropy_with_logits(logits=Y2,labels=y)) 
optimizer=tf.train.AdamOptimizer(learning_rate).minimize(loss_function)
#定义准确率
correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

进行训练并输出损失值与准确率,训练进行多轮,每轮一开始分批次读入数据进行训练,每结束一轮输出一次损失和准确率。

ss=tf.Session()
ss.run(tf.global_variables_initializer())           #进行全部变量的初始化
 
for epoch in range(train_epochs):
  for batch in range(batch_num):              #分批次读取数据进行训练
    xs,ys=mnist.train.next_batch(batch_size)
    ss.run(optimizer,feed_dict={x:xs,y:ys})
  loss,acc=ss.run([loss_function,accuracy],\
          feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
  print('第%2d轮训练:损失为:%9f,准确率:%.4f'%(epoch+1,loss,acc))
 
ss.close()

运行结果如下图,与单个神经元相比,可以较快得到较高的准确率

Python利用全连接神经网络求解MNIST问题详解

评估模型,将测试集数据填充入占位符x,y去求准确率,

test_res=ss.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print('测试集的准确率为:%.4f'%(test_res))

2、多层神经网络

多层是指中间的隐藏层有多个,例如使用两层隐藏层,第一个隐藏层在计算后将结果输出到第二个隐藏层,再由第二个隐藏层计算后交给输出层,而第二个隐藏层的设置与第一个基本相同,例如:

#构建输入层
x=tf.placeholder(tf.float32,[None,784],name='x')
y=tf.placeholder(tf.float32,[None,10],name='y')
#构建第一个隐藏层
h1_num=256                            #第一隐藏层神经元数量256
W1=tf.Variable(tf.truncated_normal([784,h1_num],stddev=0.1),name='W1')
b1=tf.Variable(tf.zeros([h1_num]),name='b1')
Y1=tf.nn.relu(tf.matmul(x,W1)+b1)
#构建第二个隐藏层
h2_num=64                             #第二隐藏层神经元数量64
W2=tf.Variable(tf.random_normal([h1_num,h2_num],stddev=0.1),name='W2')
b2=tf.Variable(tf.zeros([h2_num]),name='b2')
Y2=tf.nn.relu(tf.matmul(Y1,W2)+b2)
#构建输出层
W3=tf.Variable(tf.random_normal([h2_num,10],stddev=0.1),name='W3')
b3=tf.Variable(tf.zeros([10]),name='b3')
Y3=tf.matmul(Y2,W3)+b3
pred=tf.nn.softmax(Y3)

在第一隐藏层产生参数W1时采用的是截断正态分布的随机函数tf.truncated_normal(),与普通正太分布相比,截断正态分布生成的值之间的差距不会太大。

设置的第一隐藏层的神经元256个,第二层64个,因此第二层的每个输入有256个特征值,并产生64个输出,相应的W2的shape为[h1_num,h2_num],b2的shape为[h2_num]。输出层W3的shape为[h2_num,10]。函数的其他部分与单层神经网络相同。

经过运算多层的神经网络训练的准确率不一定比单层的高,因为还涉及到训练的超参数的设置等多种因素。但是多层神经网络的运行速度比单层慢,越多层的神经网络意味着更加复杂的计算量。

全连接层函数

通过以上多层神经网络的定义可以看出两个隐藏层与输出层的构建方法基本类似,都是定义对应的变量W、b,在定义W时其shape为[输出维度,输出维度],因此可以将隐藏层与输出层统一定义为一个全连接层函数:

#定义一个通用的全连接层函数模型
def fcn_layer(inputs,in_dim,out_dim,activation=None):
  W=tf.Variable(tf.truncated_normal([in_dim,out_dim],stddev=0.1))
  b=tf.Variable(tf.zeros([out_dim]))
  Y=tf.matmul(inputs,W)+b
  if activation==None:
    output=Y
  else:
    output=activation(Y)
  return output
#构建第一个隐藏层
Y1=fcn_layer(x,784,256,tf.nn.relu)
#构建第二个隐藏层
Y2=fcn_layer(Y1,256,64,tf.nn.relu)
#构建输出层
Y3=fcn_layer(Y2,64,10)
pred=tf.nn.softmax(Y3)

其中inputs为本层的输入,in_dim为本层的输入维度,也就是上一层的输出维度,out_dim为本层的输出维度,activation为激活函数,默认为None。将输入与权重W叉乘再加上偏置值b得到Y,如果定义了激活函数,用激活函数处理Y,否则直接将Y赋给output输出。

3、模型的保存与读取

在模型训练结束后,如果希望下次继续使用或训练模型则需要将储存起来。

模型的储存

首先需要定义模型数据的保存路径:

import os
save_dir='D:/Temp/MachineLearning/ModelSaving/'    #定义模型的保存路径
if not os.path.exists(save_dir):            #如果不存在该路径则创建
  os.makedirs(save_dir)

定义储存粒度与saver,所谓储存粒度即每个几轮数据进行一次储存

save_step=5            #定义存储粒度
 
saver=tf.train.Saver()      #定义saver

在每轮训练结束后进行判断,每隔5轮储存一次,储存路径中拼接轮数信息,

if epoch%save_step==0:
    saver.save(ss,os.path.join(save_dir,'mnist_fcn_{:02d}.ckpt'.format(epoch+1)))

在所有迭代训练执行结束后,再整体储存一次

saver.save(ss,os.path.join(save_dir,'mnist_fcn.ckpt'))

这样就会在指定目录下生成模型的保存文件:Python利用全连接神经网络求解MNIST问题详解

模型的读取

从定义的模型目录中读取存盘点数据,并将其中的参数赋值给当前的session,然后便可以直接利用session进行测试,其准确率与保存时一致。

save_dir='D:/Temp/MachineLearning/ModelSaving/'    #定义模型的保存路径
saver=tf.train.Saver()                 #定义saver
 
ss=tf.Session()
ss.run(tf.global_variables_initializer())
 
ckpt=tf.train.get_checkpoint_state(save_dir)      #读取存盘点
if ckpt and ckpt.model_checkpoint_path:
  saver.restore(ss,ckpt.model_checkpoint_path)    #从存盘中恢复参数到当前的session
  print('数据恢复从',ckpt.model_checkpoint_path)
 
test_res=accuracy.eval(session=ss,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print('测试集的准确率为:%.4f'%(test_res))

在读取模型时有时候会遇到报错:

NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint.

这时只需重启kernel即可。

通过图来保存模型

也可以将训练好的模型以图的形式保存为.pb文件,下次直接可以使用,但不可以继续训练。

通过tf.train.write_graph函数来保存模型如下:

import tensorflow as tf
 
v=tf.Variable(1.0,'new_var')
with tf.Session() as ss:
  tf.train.write_graph(ss.graph_def,'D:\Temp\MachineLearning\ModelSaving\Graph',
            'test_graph.pb',as_text=False)

读取图文件并还原:

with tf.Session() as ss:
  with tf.gfile.GFile('D:/Temp\MachineLearning/ModelSaving/Graph/test_graph.pb','rb') as pb_file:
    graph_def=tf.GraphDef()
    graph_def.ParseFromString(pb_file.read())
    ss.graph.as_default()
    tf.import_graph_def(graph_def)
    print(graph_def)

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python中异常捕获方法详解
Mar 03 Python
基于python的Tkinter编写登陆注册界面
Jun 30 Python
详解python里使用正则表达式的分组命名方式
Oct 24 Python
Python3.4实现远程控制电脑开关机
Feb 22 Python
Python中文件的写入读取以及附加文字方法
Jan 23 Python
pandas修改DataFrame列名的实现方法
Feb 22 Python
python自制包并用pip免提交到pypi仅安装到本机【推荐】
Jun 03 Python
基于Python 中函数的 收集参数 机制
Dec 21 Python
From CSV to SQLite3 by python 导入csv到sqlite实例
Feb 14 Python
python脚本监控logstash进程并邮件告警实例
Apr 28 Python
python工具——Mimesis的简单使用教程
Jan 16 Python
详解Python调用系统命令的六种方法
Jan 28 Python
基于pytorch的lstm参数使用详解
Jan 14 #Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
Jan 14 #Python
np.random.seed() 的使用详解
Jan 14 #Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
Jan 14 #Python
Python selenium 自动化脚本打包成一个exe文件(推荐)
Jan 14 #Python
pytorch+lstm实现的pos示例
Jan 14 #Python
Python中sorted()排序与字母大小写的问题
Jan 14 #Python
You might like
php 删除记录实现代码
2009/03/12 PHP
php一次性删除前台checkbox多选内容的方法
2013/09/22 PHP
PHP实现自动对图片进行滚动显示的方法
2015/03/12 PHP
Yii2框架dropDownList下拉菜单用法实例分析
2016/07/18 PHP
PHP排序算法之快速排序(Quick Sort)及其优化算法详解
2018/04/21 PHP
php 中self,this的区别和操作方法实例分析
2019/11/04 PHP
js 点击页面其他地方关闭弹出层(示例代码)
2013/12/24 Javascript
JavaScript中的单引号和双引号报错的解决方法
2014/09/01 Javascript
8个超实用的jQuery功能代码分享
2015/01/08 Javascript
Javascript基础教程之数组 array
2015/01/18 Javascript
JQuery CheckBox(复选框)操作方法汇总
2015/04/15 Javascript
jQuery获取与设置iframe高度的方法
2016/08/01 Javascript
JavaScript实现点击按钮复制指定区域文本(推荐)
2016/11/25 Javascript
jQuery中select与datalist制作下拉菜单时的区别浅析
2016/12/30 Javascript
bootstrap制作jsp页面(根据值让table显示选中)
2017/01/05 Javascript
Angular6笔记之封装http的示例代码
2018/07/27 Javascript
vue: WebStorm设置快速编译运行的方法
2018/10/18 Javascript
如何为vuex实现带参数的 getter和state.commit
2019/01/04 Javascript
CKeditor4 字体颜色功能配置方法教程
2019/06/26 Javascript
一篇文章带你浅入webpack的DLL优化打包
2020/02/20 Javascript
node运行js获得输出的三种方式示例详解
2020/07/02 Javascript
[15:28]DOTA2 HEROS教学视频教你分分钟做大人-剧毒术士
2014/06/13 DOTA
[50:27]Secret vs VG 2018国际邀请赛小组赛BO2 第二场 8.17
2018/08/20 DOTA
Python打包可执行文件的方法详解
2016/09/19 Python
linux环境下python中MySQLdb模块的安装方法
2017/06/16 Python
Python多继承顺序实例分析
2018/05/26 Python
numpy:找到指定元素的索引示例
2019/11/26 Python
python实现数据清洗(缺失值与异常值处理)
2019/12/02 Python
施华洛世奇澳大利亚官网:SWAROVSKI澳大利亚
2017/01/06 全球购物
珠宝店促销方案
2014/03/21 职场文书
巾帼建功标兵事迹材料
2014/05/11 职场文书
教师党的群众路线教育实践活动个人对照检查材料
2014/09/23 职场文书
2015年留守儿童工作总结
2015/05/22 职场文书
导游词之镜泊湖
2019/12/09 职场文书
利用Python读取微信朋友圈的多种方法总结
2021/08/23 Python
Spring-cloud Config Server的3种配置方式
2021/09/25 Java/Android