tensorflow学习笔记之简单的神经网络训练和测试


Posted in Python onApril 15, 2018

本文实例为大家分享了用简单的神经网络来训练和测试的具体代码,供大家参考,具体内容如下

刚开始学习tf时,我们从简单的地方开始。卷积神经网络(CNN)是由简单的神经网络(NN)发展而来的,因此,我们的第一个例子,就从神经网络开始。

神经网络没有卷积功能,只有简单的三层:输入层,隐藏层和输出层。

数据从输入层输入,在隐藏层进行加权变换,最后在输出层进行输出。输出的时候,我们可以使用softmax回归,输出属于每个类别的概率值。借用极客学院的图表示如下:

tensorflow学习笔记之简单的神经网络训练和测试

其中,x1,x2,x3为输入数据,经过运算后,得到三个数据属于某个类别的概率值y1,y2,y3. 用简单的公式表示如下:

tensorflow学习笔记之简单的神经网络训练和测试

在训练过程中,我们将真实的结果和预测的结果相比(交叉熵比较法),会得到一个残差。公式如下:

tensorflow学习笔记之简单的神经网络训练和测试

y是我们预测的概率值,y'是实际的值。这个残差越小越好,我们可以使用梯度下降法,不停地改变W和b的值,使得残差逐渐变小,最后收敛到最小值。这样训练就完成了,我们就得到了一个模型(W和b的最优化值)。

完整代码如下:

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
y_actual = tf.placeholder(tf.float32, shape=[None, 10])
W = tf.Variable(tf.zeros([784,10]))    #初始化权值W
b = tf.Variable(tf.zeros([10]))      #初始化偏置项b
y_predict = tf.nn.softmax(tf.matmul(x,W) + b)   #加权变换并进行softmax回归,得到预测概率
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_actual*tf.log(y_predict),reduction_indies=1))  #求交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  #用梯度下降法使得残差最小

correct_prediction = tf.equal(tf.argmax(y_predict,1), tf.argmax(y_actual,1))  #在测试阶段,测试准确度计算
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))        #多个批次的准确度均值

init = tf.initialize_all_variables()
with tf.Session() as sess:
  sess.run(init)
  for i in range(1000):        #训练阶段,迭代1000次
    batch_xs, batch_ys = mnist.train.next_batch(100)      #按批次训练,每批100行数据
    sess.run(train_step, feed_dict={x: batch_xs, y_actual: batch_ys})  #执行训练
    if(i%100==0):         #每训练100次,测试一次
      print "accuracy:",sess.run(accuracy, feed_dict={x: mnist.test.images, y_actual: mnist.test.labels})

每训练100次,测试一次,随着训练次数的增加,测试精度也在增加。训练结束后,1W行数据测试的平均精度为91%左右,不是太高,肯定没有CNN高。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础教程之popen函数操作其它程序的输入和输出示例
Feb 10 Python
Python实现扫描局域网活动ip(扫描在线电脑)
Apr 28 Python
Python实现判断字符串中包含某个字符的判断函数示例
Jan 08 Python
python实现超简单的视频对象提取功能
Jun 04 Python
Python3.6简单反射操作示例
Jun 14 Python
python实现自动网页截图并裁剪图片
Jul 30 Python
python获取url的返回信息方法
Dec 17 Python
Python 实现中值滤波、均值滤波的方法
Jan 09 Python
在django中实现页面倒数几秒后自动跳转的例子
Aug 16 Python
Python通过递归函数输出嵌套列表元素
Oct 15 Python
python 实现音频叠加的示例
Oct 29 Python
基于python模拟bfs和dfs代码实例
Nov 19 Python
Pytorch入门之mnist分类实例
Apr 14 #Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
You might like
php auth_http类库进行身份效验
2009/03/19 PHP
解析php中const与define的应用区别
2013/06/18 PHP
PHP 验证码不显示只有一个小红叉的解决方法
2013/09/30 PHP
Windows下的PHP安装pear教程
2014/10/24 PHP
Yii列表定义与使用分页方法小结(3种方法)
2016/07/15 PHP
img onload事件绑定各浏览器均可执行
2012/12/19 Javascript
自制的文件上传JS控件可支持IE、chrome、firefox etc
2014/04/18 Javascript
javascript实现base64 md5 sha1 密码加密
2015/09/09 Javascript
JS+CSS实现的蓝色table选项卡效果
2015/10/08 Javascript
jQuery中使用animate自定义动画的方法
2016/05/29 Javascript
jQuery Select下拉框操作小结(推荐)
2016/07/22 Javascript
jQuery实现拖动剪裁图片作为头像
2016/12/28 Javascript
微信小程序 Canvas增强组件实例详解及源码分享
2017/01/04 Javascript
JavaScript使用原型和原型链实现对象继承的方法详解
2017/04/05 Javascript
详解Webpack DLL用法以及功能
2017/07/11 Javascript
微信小程序如何获取openid及用户信息
2018/01/26 Javascript
修改Nodejs内置的npm默认配置路径方法
2018/05/13 NodeJs
你可能不知道的CORS跨域资源共享
2019/03/13 Javascript
[03:06]V社市场总监Dota2项目负责人Erik专访:希望更多中国玩家加入DOTA2
2014/07/11 DOTA
Python 命令行非阻塞输入的小例子
2013/09/27 Python
在python中的socket模块使用代理实例
2014/05/29 Python
Python实现的径向基(RBF)神经网络示例
2018/02/06 Python
Python键盘输入转换为列表的实例
2018/06/23 Python
django中间键重定向实例方法
2019/11/10 Python
python-web根据元素属性进行定位的方法
2019/12/13 Python
使用keras实现densenet和Xception的模型融合
2020/05/23 Python
六种酷炫Python运行进度条效果的实现代码
2020/07/17 Python
Python 如何测试文件是否存在
2020/07/31 Python
HTML5-WebSocket实现聊天室示例
2016/12/15 HTML / CSS
元旦晚会邀请函
2014/01/27 职场文书
环保倡议书50字
2014/05/15 职场文书
岗位职责说明书模板
2014/07/30 职场文书
大学生职业生涯十年规划书范文
2014/09/17 职场文书
党员干部公开承诺书范文
2015/04/27 职场文书
2015年安全生产月工作总结
2015/07/27 职场文书
Python基本的内置数据类型及使用方法
2022/04/13 Python