使用TensorFlow搭建一个全连接神经网络教程


Posted in Python onFebruary 06, 2020

说明

本例子利用TensorFlow搭建一个全连接神经网络,实现对MNIST手写数字的识别。

先上代码

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# prepare data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# the model of the fully-connected network
weights = tf.Variable(tf.random_normal([784, 10]))
biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# compute the accuracy
correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

代码解析

1. 读取MNIST数据

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

2. 建立占位符

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

xs 代表图片像素数据, 每张图片(28×28)被展开成(1×784), 有多少图片还未定, 所以shape为None×784.

ys 代表图片标签数据, 0-9十个数字被表示成One-hot形式, 即只有对应bit为1, 其余为0.

3. 建立模型

weights = tf.Variable(tf.random_normal([784, 10]))


biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

使用Softmax函数作为激活函数:

使用TensorFlow搭建一个全连接神经网络教程

4. 计算正确率

correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

5. 使用模型

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

运行结果

训练1000个循环, 准确率在87%左右.

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.1041
0.632
0.7357
0.7837
0.7971
0.8147
0.8283
0.8376
0.8423
0.8501
0.8501
0.8533
0.8567
0.8597
0.8552
0.8647
0.8654
0.8701
0.8712
0.8712

以上这篇使用TensorFlow搭建一个全连接神经网络教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的简单FTP上传下载文件实例
Jun 30 Python
微信小程序跳一跳游戏 python脚本跳一跳刷高分技巧
Jan 04 Python
python 自动去除空行的实例
Jul 24 Python
Pycharm新手教程(只需要看这篇就够了)
Jun 18 Python
pytorch 输出中间层特征的实例
Aug 17 Python
Python如何使用argparse模块处理命令行参数
Dec 11 Python
tensorflow模型保存、加载之变量重命名实例
Jan 21 Python
Python 实现一个计时器
Jul 28 Python
Flask缓存静态文件的具体方法
Aug 02 Python
Django windows使用Apache实现部署流程解析
Oct 12 Python
python两种获取剪贴板内容的方法
Nov 06 Python
Python实现对word文档添加密码去除密码的示例代码
Dec 29 Python
详解python 降级到3.6终极解决方案
Feb 06 #Python
PyCharm如何导入python项目的方法
Feb 06 #Python
tensorflow 环境变量设置方式
Feb 06 #Python
快速查找Python安装路径方法
Feb 06 #Python
运行tensorflow python程序,限制对GPU和CPU的占用操作
Feb 06 #Python
如何在django中添加日志功能
Feb 06 #Python
keras tensorflow 实现在python下多进程运行
Feb 06 #Python
You might like
php代码架构的八点注意事项
2016/01/25 PHP
Laravel多域名下字段验证的方法
2019/04/04 PHP
传递参数的标准方法(jQuery.ajax)
2008/11/19 Javascript
JavaScript 面向对象编程(1) 基础
2010/05/18 Javascript
js遍历td tr等html元素
2012/12/13 Javascript
判断js中各种数据的类型方法之typeof与0bject.prototype.toString讲解
2013/11/07 Javascript
无闪烁更新网页内容JS实现
2013/12/19 Javascript
Jquery如何实现点击时高亮显示代码
2014/01/22 Javascript
仿百度联盟对联广告实现代码
2014/08/30 Javascript
Js为表单动态添加节点内容的方法
2015/02/10 Javascript
JavaScript判断是否为数字的4种方法及效率比较
2015/04/01 Javascript
JS两个数组比较,删除重复值的巧妙方法(推荐)
2016/06/03 Javascript
AngularJS实现与Java Web服务器交互操作示例【附demo源码下载】
2016/11/02 Javascript
Ajax的概述与实现过程
2016/11/18 Javascript
JS实现自定义弹窗功能
2018/08/08 Javascript
Vue自定义属性实例分析
2019/02/23 Javascript
使用express来代理服务的方法
2019/06/21 Javascript
js图数据结构处理 迪杰斯特拉算法代码实例
2019/09/11 Javascript
[01:25]2014DOTA2国际邀请赛 zhou分析LGD比赛情况
2014/07/14 DOTA
python实现ping的方法
2015/07/06 Python
Python实现Youku视频批量下载功能
2017/03/14 Python
机器学习python实战之决策树
2017/11/01 Python
python实现最长公共子序列
2018/05/22 Python
python+opencv打开摄像头,保存视频、拍照功能的实现方法
2019/01/08 Python
pyCharm 设置调试输出窗口中文显示方式(字符码转换)
2020/06/09 Python
CSS 3.0文字悬停跳动特效代码
2020/10/26 HTML / CSS
Snapfish爱尔兰:在线照片打印和个性化照片礼品
2018/09/17 全球购物
法国在线购买汽车轮胎网站:123pneus.fr
2019/02/25 全球购物
征婚广告词
2014/03/17 职场文书
教师党的群众路线学习心得体会
2014/11/04 职场文书
2015新员工试用期工作总结
2014/12/12 职场文书
初中英语教师个人工作总结
2015/02/09 职场文书
2016年高校自主招生自荐信范文
2015/03/24 职场文书
领导离职感言
2015/08/03 职场文书
CSS3 制作精美的定价表
2021/04/06 HTML / CSS
Win11怎么添加用户?Win11添加用户账户的方法
2022/07/15 数码科技