tensorflow学习笔记之简单的神经网络训练和测试


Posted in Python onApril 15, 2018

本文实例为大家分享了用简单的神经网络来训练和测试的具体代码,供大家参考,具体内容如下

刚开始学习tf时,我们从简单的地方开始。卷积神经网络(CNN)是由简单的神经网络(NN)发展而来的,因此,我们的第一个例子,就从神经网络开始。

神经网络没有卷积功能,只有简单的三层:输入层,隐藏层和输出层。

数据从输入层输入,在隐藏层进行加权变换,最后在输出层进行输出。输出的时候,我们可以使用softmax回归,输出属于每个类别的概率值。借用极客学院的图表示如下:

tensorflow学习笔记之简单的神经网络训练和测试

其中,x1,x2,x3为输入数据,经过运算后,得到三个数据属于某个类别的概率值y1,y2,y3. 用简单的公式表示如下:

tensorflow学习笔记之简单的神经网络训练和测试

在训练过程中,我们将真实的结果和预测的结果相比(交叉熵比较法),会得到一个残差。公式如下:

tensorflow学习笔记之简单的神经网络训练和测试

y是我们预测的概率值,y'是实际的值。这个残差越小越好,我们可以使用梯度下降法,不停地改变W和b的值,使得残差逐渐变小,最后收敛到最小值。这样训练就完成了,我们就得到了一个模型(W和b的最优化值)。

完整代码如下:

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
y_actual = tf.placeholder(tf.float32, shape=[None, 10])
W = tf.Variable(tf.zeros([784,10]))    #初始化权值W
b = tf.Variable(tf.zeros([10]))      #初始化偏置项b
y_predict = tf.nn.softmax(tf.matmul(x,W) + b)   #加权变换并进行softmax回归,得到预测概率
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_actual*tf.log(y_predict),reduction_indies=1))  #求交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  #用梯度下降法使得残差最小

correct_prediction = tf.equal(tf.argmax(y_predict,1), tf.argmax(y_actual,1))  #在测试阶段,测试准确度计算
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))        #多个批次的准确度均值

init = tf.initialize_all_variables()
with tf.Session() as sess:
  sess.run(init)
  for i in range(1000):        #训练阶段,迭代1000次
    batch_xs, batch_ys = mnist.train.next_batch(100)      #按批次训练,每批100行数据
    sess.run(train_step, feed_dict={x: batch_xs, y_actual: batch_ys})  #执行训练
    if(i%100==0):         #每训练100次,测试一次
      print "accuracy:",sess.run(accuracy, feed_dict={x: mnist.test.images, y_actual: mnist.test.labels})

每训练100次,测试一次,随着训练次数的增加,测试精度也在增加。训练结束后,1W行数据测试的平均精度为91%左右,不是太高,肯定没有CNN高。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python字符串连接方式汇总
Aug 21 Python
在Python下尝试多线程编程
Apr 28 Python
python协程用法实例分析
Jun 04 Python
详解Swift中属性的声明与作用
Jun 30 Python
小小聊天室Python代码实现
Aug 17 Python
基于python神经卷积网络的人脸识别
May 24 Python
Selenium定位元素操作示例
Aug 10 Python
Python异常处理操作实例详解
Aug 28 Python
python实现指定字符串补全空格、前面填充0的方法
Nov 16 Python
win10下python2和python3共存问题解决方法
Dec 23 Python
使用python接受tgam的脑波数据实例
Apr 09 Python
python变量的作用域是什么
May 26 Python
Pytorch入门之mnist分类实例
Apr 14 #Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
You might like
php基础知识:类与对象(1)
2006/12/13 PHP
自动生成文章摘要的代码[PHP 版本]
2007/03/20 PHP
PHP使用array_multisort对多个数组或多维数组进行排序
2014/12/16 PHP
Javascript 键盘keyCode键码值表
2009/12/24 Javascript
有关于JS构造函数的重载和工厂方法
2013/04/07 Javascript
动态创建script标签实现跨域资源访问的方法介绍
2014/02/28 Javascript
node.js中的fs.symlinkSync方法使用说明
2014/12/15 Javascript
JavaScript中Date对象的常用方法示例
2015/10/24 Javascript
JavaScript之map reduce_动力节点Java学院整理
2017/06/29 Javascript
探索Vue高阶组件的使用
2018/01/08 Javascript
vue中echarts3.0自适应的方法
2018/02/26 Javascript
VUE-cli3使用 svg-sprite-loader
2018/10/20 Javascript
JS/HTML5游戏常用算法之路径搜索算法 A*寻路算法完整实例
2018/12/14 Javascript
小程序实现多列选择器
2019/02/15 Javascript
javascript实现移动端触屏拖拽功能
2020/07/29 Javascript
解决vue+webpack项目接口跨域出现的问题
2020/08/10 Javascript
python赋值操作方法分享
2013/03/23 Python
Python程序中使用SQLAlchemy时出现乱码的解决方案
2015/04/24 Python
Django在win10下的安装并创建工程
2017/11/20 Python
PyCharm代码提示忽略大小写设置方法
2018/10/28 Python
用python脚本24小时刷浏览器的访问量方法
2018/12/07 Python
详解python selenium 爬取网易云音乐歌单名
2019/03/28 Python
Python 获取 datax 执行结果保存到数据库的方法
2019/07/11 Python
python如果快速判断数字奇数偶数
2019/11/13 Python
python__new__内置静态方法使用解析
2020/01/07 Python
Python程序控制语句用法实例分析
2020/01/14 Python
New Balance澳大利亚官网:运动鞋和健身服装
2019/02/23 全球购物
DERMAdoctor官网:美国著名皮肤护理品牌
2019/07/06 全球购物
运动会通讯稿100字
2014/01/31 职场文书
《画家乡》教学反思
2014/04/22 职场文书
2015年试用期自我评价范文
2015/03/10 职场文书
证劵公司反洗钱宣传活动总结
2015/05/08 职场文书
大学生读书笔记范文
2015/07/01 职场文书
小学运动会入场词
2015/07/18 职场文书
小学一年级语文教学反思
2016/03/03 职场文书
如何使用Maxwell实时同步mysql数据
2021/04/08 MySQL