tensorflow学习笔记之简单的神经网络训练和测试


Posted in Python onApril 15, 2018

本文实例为大家分享了用简单的神经网络来训练和测试的具体代码,供大家参考,具体内容如下

刚开始学习tf时,我们从简单的地方开始。卷积神经网络(CNN)是由简单的神经网络(NN)发展而来的,因此,我们的第一个例子,就从神经网络开始。

神经网络没有卷积功能,只有简单的三层:输入层,隐藏层和输出层。

数据从输入层输入,在隐藏层进行加权变换,最后在输出层进行输出。输出的时候,我们可以使用softmax回归,输出属于每个类别的概率值。借用极客学院的图表示如下:

tensorflow学习笔记之简单的神经网络训练和测试

其中,x1,x2,x3为输入数据,经过运算后,得到三个数据属于某个类别的概率值y1,y2,y3. 用简单的公式表示如下:

tensorflow学习笔记之简单的神经网络训练和测试

在训练过程中,我们将真实的结果和预测的结果相比(交叉熵比较法),会得到一个残差。公式如下:

tensorflow学习笔记之简单的神经网络训练和测试

y是我们预测的概率值,y'是实际的值。这个残差越小越好,我们可以使用梯度下降法,不停地改变W和b的值,使得残差逐渐变小,最后收敛到最小值。这样训练就完成了,我们就得到了一个模型(W和b的最优化值)。

完整代码如下:

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
y_actual = tf.placeholder(tf.float32, shape=[None, 10])
W = tf.Variable(tf.zeros([784,10]))    #初始化权值W
b = tf.Variable(tf.zeros([10]))      #初始化偏置项b
y_predict = tf.nn.softmax(tf.matmul(x,W) + b)   #加权变换并进行softmax回归,得到预测概率
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_actual*tf.log(y_predict),reduction_indies=1))  #求交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  #用梯度下降法使得残差最小

correct_prediction = tf.equal(tf.argmax(y_predict,1), tf.argmax(y_actual,1))  #在测试阶段,测试准确度计算
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))        #多个批次的准确度均值

init = tf.initialize_all_variables()
with tf.Session() as sess:
  sess.run(init)
  for i in range(1000):        #训练阶段,迭代1000次
    batch_xs, batch_ys = mnist.train.next_batch(100)      #按批次训练,每批100行数据
    sess.run(train_step, feed_dict={x: batch_xs, y_actual: batch_ys})  #执行训练
    if(i%100==0):         #每训练100次,测试一次
      print "accuracy:",sess.run(accuracy, feed_dict={x: mnist.test.images, y_actual: mnist.test.labels})

每训练100次,测试一次,随着训练次数的增加,测试精度也在增加。训练结束后,1W行数据测试的平均精度为91%左右,不是太高,肯定没有CNN高。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python自动化测试实例解析
Sep 28 Python
Python实现简单多线程任务队列
Feb 27 Python
Python判断文本中消息重复次数的方法
Apr 27 Python
Python安装第三方库及常见问题处理方法汇总
Sep 13 Python
详解Python给照片换底色(蓝底换红底)
Mar 22 Python
python3模拟实现xshell远程执行liunx命令的方法
Jul 12 Python
python自动化测试之异常及日志操作实例分析
Nov 09 Python
flask 实现上传图片并缩放作为头像的例子
Jan 09 Python
一篇文章搞懂python的转义字符及用法
Sep 03 Python
python中count函数知识点浅析
Dec 17 Python
selenium如何定位span元素的实现
Jan 13 Python
conda安装tensorflow和conda常用命令小结
Feb 20 Python
Pytorch入门之mnist分类实例
Apr 14 #Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
You might like
一个PHP操作Access类(PHP+ODBC+Access)
2007/01/02 PHP
PHP 文章中的远程图片采集到本地的代码
2009/07/30 PHP
深入PHP empty(),isset(),is_null()的实例测试详解
2013/06/06 PHP
php解析json数据实例
2014/08/19 PHP
php实现随机显示图片方法汇总
2015/05/21 PHP
PHP单例模式定义与使用实例详解
2017/02/06 PHP
PHP基于自定义类随机生成姓名的方法示例
2017/08/05 PHP
PHP unlink与rmdir删除目录及目录下所有文件实例代码
2018/02/07 PHP
在Laravel中使用MongoDB的方法示例
2019/11/11 PHP
关于UTF-8的客户端用AJAX方式获取GB2312的服务器端乱码问题的解决办法
2010/11/30 Javascript
jquery ajax 同步异步的执行 return值不能取得的解决方案
2012/01/08 Javascript
基于JavaScript实现动态创建表格和增加表格行数
2015/12/20 Javascript
纯js代码制作的网页时钟特效【附实例】
2016/03/30 Javascript
jquery基于layui实现二级联动下拉选择(省份城市选择)
2017/06/20 jQuery
webpack下实现动态引入文件方法
2018/02/22 Javascript
用vue快速开发app的脚手架工具
2018/06/11 Javascript
vue生成token并保存到本地存储中
2018/07/17 Javascript
ndm:NPM的桌面GUI应用程序
2018/10/15 Javascript
微信小程序中的上拉、下拉菜单功能
2020/03/13 Javascript
js抽奖转盘实现方法分析
2020/05/16 Javascript
python实现百万答题自动百度搜索答案
2018/01/16 Python
Python随机生成身份证号码及校验功能
2018/12/04 Python
Python3调用百度AI识别图片中的文字功能示例【测试可用】
2019/03/13 Python
python名片管理系统开发
2020/06/18 Python
python的json包位置及用法总结
2020/06/21 Python
python中id函数运行方式
2020/07/03 Python
用gpu训练好的神经网络,用tensorflow-cpu跑出错的原因及解决方案
2021/03/03 Python
HTML5之SVG 2D入门8—文档结构及相关元素总结
2013/01/30 HTML / CSS
全球最受追捧的运动服品牌领先数字目的地:Stylerunner
2020/11/25 全球购物
晚会邀请函范文
2014/01/24 职场文书
餐厅楼面主管岗位职责范本
2014/02/16 职场文书
2014财务年终工作总结
2014/12/08 职场文书
党委工作总结2015
2015/04/27 职场文书
工作证明格式范文
2015/06/15 职场文书
基层工作经历证明
2015/06/19 职场文书
浅谈sql_@SelectProvider及使用注意说明
2021/08/04 Java/Android