使用TensorFlow搭建一个全连接神经网络教程


Posted in Python onFebruary 06, 2020

说明

本例子利用TensorFlow搭建一个全连接神经网络,实现对MNIST手写数字的识别。

先上代码

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# prepare data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# the model of the fully-connected network
weights = tf.Variable(tf.random_normal([784, 10]))
biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# compute the accuracy
correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

代码解析

1. 读取MNIST数据

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

2. 建立占位符

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

xs 代表图片像素数据, 每张图片(28×28)被展开成(1×784), 有多少图片还未定, 所以shape为None×784.

ys 代表图片标签数据, 0-9十个数字被表示成One-hot形式, 即只有对应bit为1, 其余为0.

3. 建立模型

weights = tf.Variable(tf.random_normal([784, 10]))


biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

使用Softmax函数作为激活函数:

使用TensorFlow搭建一个全连接神经网络教程

4. 计算正确率

correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

5. 使用模型

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

运行结果

训练1000个循环, 准确率在87%左右.

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.1041
0.632
0.7357
0.7837
0.7971
0.8147
0.8283
0.8376
0.8423
0.8501
0.8501
0.8533
0.8567
0.8597
0.8552
0.8647
0.8654
0.8701
0.8712
0.8712

以上这篇使用TensorFlow搭建一个全连接神经网络教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python如何读取MySQL数据库表数据
Mar 11 Python
python中实现k-means聚类算法详解
Nov 11 Python
Python实现按当前日期(年、月、日)创建多级目录的方法
Apr 26 Python
对python过滤器和lambda函数的用法详解
Jan 21 Python
有关Tensorflow梯度下降常用的优化方法分享
Feb 04 Python
Python3中configparser模块读写ini文件并解析配置的用法详解
Feb 18 Python
在python image 中实现安装中文字体
May 16 Python
python+requests接口自动化框架的实现
Aug 31 Python
Python调用JavaScript代码的方法
Oct 27 Python
python如何利用paramiko执行服务器命令
Nov 07 Python
Python类class参数self原理解析
Nov 19 Python
Python使用Opencv打开笔记本电脑摄像头报错解问题及解决
Jun 21 Python
详解python 降级到3.6终极解决方案
Feb 06 #Python
PyCharm如何导入python项目的方法
Feb 06 #Python
tensorflow 环境变量设置方式
Feb 06 #Python
快速查找Python安装路径方法
Feb 06 #Python
运行tensorflow python程序,限制对GPU和CPU的占用操作
Feb 06 #Python
如何在django中添加日志功能
Feb 06 #Python
keras tensorflow 实现在python下多进程运行
Feb 06 #Python
You might like
Win下如何安装PHP的APC拓展
2013/08/07 PHP
[原创]PHP实现字节数Byte转换为KB、MB、GB、TB的方法
2017/08/31 PHP
PHP的new static和new self的区别与使用
2019/11/27 PHP
js模拟点击事件实现代码
2012/11/06 Javascript
JavaScript实现的日期控件具体代码
2013/11/18 Javascript
jquery取消选择select下拉框示例代码
2014/02/22 Javascript
将数字转换成大写的人民币表达式的js函数
2014/09/21 Javascript
jQuery实现获取绑定自定义事件元素的方法
2015/12/02 Javascript
使用jQuery+EasyUI实现CheckBoxTree的级联选中特效
2015/12/06 Javascript
jQuery延迟执行的实现方法
2016/12/21 Javascript
基于javascript实现数字英文验证码
2017/01/25 Javascript
Angular.Js中ng-include指令的使用与实现
2017/05/07 Javascript
快速掌握jquery分页插件jqPaginator的使用方法
2017/08/09 jQuery
基于jQuery Ajax实现下拉框无刷新联动
2017/12/06 jQuery
jQuery+Datatables实现表格批量删除功能【推荐】
2018/10/24 jQuery
微信小程序服务器日期格式化问题
2020/01/07 Javascript
如何构建一个Vue插件并生成npm包
2020/10/26 Javascript
[10:21]DOTA2-DPC中国联赛 正赛 PSG.LGD vs Aster 选手采访
2021/03/11 DOTA
PYTHON正则表达式 re模块使用说明
2011/05/19 Python
python对html代码进行escape编码的方法
2015/05/04 Python
Python多进程并发(multiprocessing)用法实例详解
2015/06/02 Python
Python获取文件所在目录和文件名的方法
2017/01/12 Python
python 读取excel文件生成sql文件实例详解
2017/05/12 Python
Python callable()函数用法实例分析
2018/03/17 Python
解决Pycharm中import时无法识别自己写的程序方法
2018/05/18 Python
基于anaconda下强大的conda命令介绍
2018/06/11 Python
python画图的函数用法以及技巧
2019/06/28 Python
Python 时间戳之获取整点凌晨时间戳的操作方法
2020/01/28 Python
jupyter 使用Pillow包显示图像时inline显示方式
2020/04/24 Python
记一次Django响应超慢的解决过程
2020/09/17 Python
记录一下scrapy中settings的一些配置小结
2020/09/28 Python
澳大利亚家具和家居用品在线:BROSA
2017/11/02 全球购物
Wedgwood英国官方网站:英式精致骨瓷餐具、礼品与生活精品,源于1759年
2019/09/02 全球购物
2014年社区学雷锋活动总结
2014/03/09 职场文书
竞选生活委员演讲稿
2014/04/28 职场文书
2014班子成员自我剖析材料思想汇报
2014/10/01 职场文书