使用TensorFlow搭建一个全连接神经网络教程


Posted in Python onFebruary 06, 2020

说明

本例子利用TensorFlow搭建一个全连接神经网络,实现对MNIST手写数字的识别。

先上代码

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# prepare data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# the model of the fully-connected network
weights = tf.Variable(tf.random_normal([784, 10]))
biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# compute the accuracy
correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

代码解析

1. 读取MNIST数据

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

2. 建立占位符

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

xs 代表图片像素数据, 每张图片(28×28)被展开成(1×784), 有多少图片还未定, 所以shape为None×784.

ys 代表图片标签数据, 0-9十个数字被表示成One-hot形式, 即只有对应bit为1, 其余为0.

3. 建立模型

weights = tf.Variable(tf.random_normal([784, 10]))


biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

使用Softmax函数作为激活函数:

使用TensorFlow搭建一个全连接神经网络教程

4. 计算正确率

correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

5. 使用模型

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

运行结果

训练1000个循环, 准确率在87%左右.

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.1041
0.632
0.7357
0.7837
0.7971
0.8147
0.8283
0.8376
0.8423
0.8501
0.8501
0.8533
0.8567
0.8597
0.8552
0.8647
0.8654
0.8701
0.8712
0.8712

以上这篇使用TensorFlow搭建一个全连接神经网络教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
一则python3的简单爬虫代码
May 26 Python
TensorFlow的权值更新方法
Jun 14 Python
python绘制立方体的方法
Jul 02 Python
Python爬虫实现简单的爬取有道翻译功能示例
Jul 13 Python
python抖音表白程序源代码
Apr 07 Python
python如何实现数据的线性拟合
Jul 19 Python
Python 获取命令行参数内容及参数个数的实例
Dec 20 Python
PyTorch中permute的用法详解
Dec 30 Python
python无序链表删除重复项的方法
Jan 17 Python
使用python执行shell脚本 并动态传参 及subprocess的使用详解
Mar 06 Python
Django-rest-framework中过滤器的定制实例
Apr 01 Python
python要安装在哪个盘
Jun 15 Python
详解python 降级到3.6终极解决方案
Feb 06 #Python
PyCharm如何导入python项目的方法
Feb 06 #Python
tensorflow 环境变量设置方式
Feb 06 #Python
快速查找Python安装路径方法
Feb 06 #Python
运行tensorflow python程序,限制对GPU和CPU的占用操作
Feb 06 #Python
如何在django中添加日志功能
Feb 06 #Python
keras tensorflow 实现在python下多进程运行
Feb 06 #Python
You might like
PHP 简易输出CSV表格文件的方法详解
2013/06/20 PHP
win7 64位系统 配置php最新版开发环境(php+Apache+mysql)
2014/08/15 PHP
laravel如何开启跨域功能示例详解
2017/08/31 PHP
PHP设计模式之数据访问对象模式(DAO)原理与用法实例分析
2019/12/12 PHP
jquery 选取方法都有哪些
2014/05/18 Javascript
jQuery时间插件jquery.clock.js用法实例(5个示例)
2016/01/14 Javascript
jQuery on()方法绑定动态元素的点击事件无响应的解决办法
2016/07/07 Javascript
JavaScript表单焦点自动切换代码
2016/07/24 Javascript
基于touch.js手势库+zepto.js插件开发图片查看器(滑动、缩放、双击缩放)
2016/11/17 Javascript
JScript实现地址选择功能
2017/08/15 Javascript
vue draggable resizable gorkys与v-chart使用与总结
2019/09/05 Javascript
小程序实现长按保存图片的方法
2019/12/31 Javascript
javascript浅层克隆、深度克隆对比及实例解析
2020/02/09 Javascript
如何在vue中使用百度地图添加自定义覆盖物(水波纹)
2020/11/03 Javascript
在ironpython中利用装饰器执行SQL操作的例子
2015/05/02 Python
Python多线程结合队列下载百度音乐的方法
2015/07/27 Python
Python基于正则表达式实现检查文件内容的方法【文件检索】
2017/08/30 Python
浅谈Python Opencv中gamma变换的使用详解
2018/04/02 Python
Python3中的json模块使用详解
2018/05/05 Python
python监测当前联网状态并连接的实例
2018/12/18 Python
python实现贪吃蛇游戏
2020/03/21 Python
很酷的python表白工具 你喜欢我吗
2019/04/11 Python
Django 在iframe里跳转顶层url的例子
2019/08/21 Python
python 字符串常用函数详解
2019/09/11 Python
html5定位并在百度地图上显示的示例
2014/04/27 HTML / CSS
我想声明一个指针并为它分配一些空间, 但却不行。这些代码有什么 问题?char *p; *p = malloc(10);
2016/10/06 面试题
大学军训感言200字
2014/02/26 职场文书
捐助倡议书范文
2014/04/15 职场文书
激励员工的口号
2014/06/16 职场文书
学生打架检讨书
2014/10/20 职场文书
2015学生会文艺部工作总结
2015/04/03 职场文书
寻找成龙观后感
2015/06/12 职场文书
《风筝》教学反思
2016/02/23 职场文书
Python开发五子棋小游戏
2022/04/28 Python
排查Tomcat进程假死的问题
2022/05/06 Servers
python运行脚本文件的三种方法实例
2022/06/25 Python