tensorflow学习笔记之简单的神经网络训练和测试


Posted in Python onApril 15, 2018

本文实例为大家分享了用简单的神经网络来训练和测试的具体代码,供大家参考,具体内容如下

刚开始学习tf时,我们从简单的地方开始。卷积神经网络(CNN)是由简单的神经网络(NN)发展而来的,因此,我们的第一个例子,就从神经网络开始。

神经网络没有卷积功能,只有简单的三层:输入层,隐藏层和输出层。

数据从输入层输入,在隐藏层进行加权变换,最后在输出层进行输出。输出的时候,我们可以使用softmax回归,输出属于每个类别的概率值。借用极客学院的图表示如下:

tensorflow学习笔记之简单的神经网络训练和测试

其中,x1,x2,x3为输入数据,经过运算后,得到三个数据属于某个类别的概率值y1,y2,y3. 用简单的公式表示如下:

tensorflow学习笔记之简单的神经网络训练和测试

在训练过程中,我们将真实的结果和预测的结果相比(交叉熵比较法),会得到一个残差。公式如下:

tensorflow学习笔记之简单的神经网络训练和测试

y是我们预测的概率值,y'是实际的值。这个残差越小越好,我们可以使用梯度下降法,不停地改变W和b的值,使得残差逐渐变小,最后收敛到最小值。这样训练就完成了,我们就得到了一个模型(W和b的最优化值)。

完整代码如下:

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
y_actual = tf.placeholder(tf.float32, shape=[None, 10])
W = tf.Variable(tf.zeros([784,10]))    #初始化权值W
b = tf.Variable(tf.zeros([10]))      #初始化偏置项b
y_predict = tf.nn.softmax(tf.matmul(x,W) + b)   #加权变换并进行softmax回归,得到预测概率
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_actual*tf.log(y_predict),reduction_indies=1))  #求交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  #用梯度下降法使得残差最小

correct_prediction = tf.equal(tf.argmax(y_predict,1), tf.argmax(y_actual,1))  #在测试阶段,测试准确度计算
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))        #多个批次的准确度均值

init = tf.initialize_all_variables()
with tf.Session() as sess:
  sess.run(init)
  for i in range(1000):        #训练阶段,迭代1000次
    batch_xs, batch_ys = mnist.train.next_batch(100)      #按批次训练,每批100行数据
    sess.run(train_step, feed_dict={x: batch_xs, y_actual: batch_ys})  #执行训练
    if(i%100==0):         #每训练100次,测试一次
      print "accuracy:",sess.run(accuracy, feed_dict={x: mnist.test.images, y_actual: mnist.test.labels})

每训练100次,测试一次,随着训练次数的增加,测试精度也在增加。训练结束后,1W行数据测试的平均精度为91%左右,不是太高,肯定没有CNN高。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中为什么要用self探讨
Apr 14 Python
Python序列化基础知识(json/pickle)
Oct 19 Python
浅谈tensorflow中几个随机函数的用法
Jul 27 Python
python修改txt文件中的某一项方法
Dec 29 Python
使用Python做定时任务及时了解互联网动态
May 15 Python
解决Pycharm中恢复被exclude的项目问题(pycharm source root)
Feb 14 Python
Python安装与卸载流程详细步骤(图解)
Feb 20 Python
Python代码需要缩进吗
Jul 01 Python
pycharm如何使用anaconda中的各种包(操作步骤)
Jul 31 Python
python 解决Windows平台上路径有空格的问题
Nov 10 Python
python-jwt用户认证食用教学的实现方法
Jan 19 Python
python3 hdf5文件 遍历代码
May 19 Python
Pytorch入门之mnist分类实例
Apr 14 #Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
You might like
根德YB400的电路分析
2021/03/02 无线电
PHP读取Excel类文件
2017/05/15 PHP
PHP goto语句用法实例
2019/08/06 PHP
PHP 使用位运算实现四则运算的代码
2021/03/09 PHP
初试jQuery EasyUI 使用介绍
2010/04/01 Javascript
纯CSS打造的导航菜单(附jquery版)
2010/08/07 Javascript
点击表单提交时出现jQuery没有权限的解决方法
2014/07/23 Javascript
JS脚本根据手机浏览器类型跳转WAP手机网站(两种方式)
2015/08/04 Javascript
第九章之路径分页标签与徽章组件
2016/04/25 Javascript
基于css3新属性transform及原生js实现鼠标拖动3d立方体旋转
2016/06/12 Javascript
浅谈javascript中执行环境(作用域)与作用域链
2016/12/08 Javascript
各种选择框jQuery的选中方法(实例讲解)
2017/06/27 jQuery
微信小程序实现联动选择器
2019/02/15 Javascript
[01:10]DOTA2次级职业联赛 - EP战队宣传片
2014/12/01 DOTA
解决Python出现_warn_unsafe_extraction问题的方法
2016/03/24 Python
LRUCache的实现原理及利用python实现的方法
2017/11/21 Python
Python时间戳使用和相互转换详解
2017/12/11 Python
python实现将读入的多维list转为一维list的方法
2018/06/28 Python
Python 实现交换矩阵的行示例
2019/06/26 Python
Numpy中对向量、矩阵的使用详解
2019/10/29 Python
用python的turtle模块实现给女票画个小心心
2019/11/23 Python
Python的形参和实参使用方式
2019/12/24 Python
keras tensorflow 实现在python下多进程运行
2020/02/06 Python
如何解决pycharm调试报错的问题
2020/08/06 Python
python matplotlib工具栏源码探析三之添加、删除自定义工具项的案例详解
2021/02/25 Python
如何使用amaze ui的分页样式封装一个通用的JS分页控件
2020/08/21 HTML / CSS
如何写出高质量、高性能的MySQL查询
2014/11/17 面试题
个人考核材料
2014/05/15 职场文书
保护黄河倡议书
2014/05/16 职场文书
优秀班组长事迹
2014/05/31 职场文书
出国签证在职证明范本
2014/11/24 职场文书
护士先进个人总结
2015/02/13 职场文书
2015年教师自我评价范文
2015/03/04 职场文书
写给媳妇的检讨书
2015/05/06 职场文书
预备党员入党感言
2015/08/01 职场文书
MYSQL事务的隔离级别与MVCC
2022/05/25 MySQL