使用TensorFlow搭建一个全连接神经网络教程


Posted in Python onFebruary 06, 2020

说明

本例子利用TensorFlow搭建一个全连接神经网络,实现对MNIST手写数字的识别。

先上代码

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# prepare data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# the model of the fully-connected network
weights = tf.Variable(tf.random_normal([784, 10]))
biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# compute the accuracy
correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

代码解析

1. 读取MNIST数据

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

2. 建立占位符

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

xs 代表图片像素数据, 每张图片(28×28)被展开成(1×784), 有多少图片还未定, 所以shape为None×784.

ys 代表图片标签数据, 0-9十个数字被表示成One-hot形式, 即只有对应bit为1, 其余为0.

3. 建立模型

weights = tf.Variable(tf.random_normal([784, 10]))


biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

使用Softmax函数作为激活函数:

使用TensorFlow搭建一个全连接神经网络教程

4. 计算正确率

correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

5. 使用模型

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

运行结果

训练1000个循环, 准确率在87%左右.

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.1041
0.632
0.7357
0.7837
0.7971
0.8147
0.8283
0.8376
0.8423
0.8501
0.8501
0.8533
0.8567
0.8597
0.8552
0.8647
0.8654
0.8701
0.8712
0.8712

以上这篇使用TensorFlow搭建一个全连接神经网络教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之折腾一下目录
Oct 24 Python
Python map和reduce函数用法示例
Feb 26 Python
通过数据库对Django进行删除字段和删除模型的操作
Jul 21 Python
Python实现的计数排序算法示例
Nov 29 Python
pyqt5简介及安装方法介绍
Jan 31 Python
python PyTorch预训练示例
Feb 11 Python
12个Python程序员面试必备问题与答案(小结)
Jun 24 Python
Django中多种重定向方法使用详解
Jul 17 Python
PyQt使用QPropertyAnimation开发简单动画
Apr 02 Python
Jupyter notebook设置背景主题,字体大小及自动补全代码的操作
Apr 13 Python
如何更换python默认编辑器的背景色
Aug 10 Python
python飞机大战游戏实例讲解
Dec 04 Python
详解python 降级到3.6终极解决方案
Feb 06 #Python
PyCharm如何导入python项目的方法
Feb 06 #Python
tensorflow 环境变量设置方式
Feb 06 #Python
快速查找Python安装路径方法
Feb 06 #Python
运行tensorflow python程序,限制对GPU和CPU的占用操作
Feb 06 #Python
如何在django中添加日志功能
Feb 06 #Python
keras tensorflow 实现在python下多进程运行
Feb 06 #Python
You might like
PHP5.0正式发布 不完全兼容PHP4 新增多项功能
2006/10/09 PHP
织梦模板标记简介
2007/03/11 PHP
fleaphp crud操作之findByField函数的使用方法
2011/04/23 PHP
浅析PHP中的字符串编码转换(自动识别原编码)
2013/07/02 PHP
PHP读取文件内容的五种方式
2015/12/28 PHP
CI框架数据库查询缓存优化的方法
2016/11/21 PHP
php多文件打包下载的实例代码
2017/07/12 PHP
PHP开发实现微信退款功能示例
2017/11/25 PHP
防止网站内容被拷贝的一些方法与优缺点好处与坏处分析
2007/11/30 Javascript
Javascript 调试利器 Firebug使用详解六
2009/07/05 Javascript
Jquery+JSon 无刷新分页实现代码
2010/04/01 Javascript
js实现鼠标拖动图片并兼容IE/FF火狐/谷歌等主流浏览器
2013/06/06 Javascript
jquery ajax修改全局变量示例代码
2013/11/08 Javascript
jQuery Mobile 触摸事件实例
2016/06/04 Javascript
结合代码图文讲解JavaScript中的作用域与作用域链
2016/07/05 Javascript
jQuery实现判断控件是否显示的方法
2017/01/11 Javascript
js实现无缝滚动图
2017/02/22 Javascript
强大的 Angular 表单验证功能详细介绍
2017/05/23 Javascript
详解Vue文档中几个易忽视部分的剖析
2018/03/24 Javascript
详解ES6中的 Set Map 数据结构学习总结
2018/11/06 Javascript
使用vue for时为什么要key【推荐】
2019/07/11 Javascript
微信小程序官方动态自定义底部tabBar的例子
2019/09/04 Javascript
layer.open组件获取弹出层页面变量、函数的实例
2019/09/25 Javascript
Python批量创建迅雷任务及创建多个文件
2016/02/13 Python
利用ctypes提高Python的执行速度
2016/09/09 Python
python 删除指定时间间隔之前的文件实例
2018/04/24 Python
pytorch神经网络之卷积层与全连接层参数的设置方法
2019/08/18 Python
python实现微信小程序用户登录、模板推送
2019/08/28 Python
Python使用matplotlib 画矩形的三种方式分析
2019/10/31 Python
python selenium循环登陆网站的实现
2019/11/04 Python
python matplotlib如何给图中的点加标签
2019/11/14 Python
Pyecharts 中Geo函数常用参数的用法说明
2021/02/01 Python
农药学硕士毕业生自荐信
2013/09/25 职场文书
秋季运动会稿件
2014/01/30 职场文书
学生上课迟到检讨书
2015/01/01 职场文书
python opencv检测直线 cv2.HoughLinesP的实现
2021/06/18 Python