使用TensorFlow搭建一个全连接神经网络教程


Posted in Python onFebruary 06, 2020

说明

本例子利用TensorFlow搭建一个全连接神经网络,实现对MNIST手写数字的识别。

先上代码

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# prepare data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# the model of the fully-connected network
weights = tf.Variable(tf.random_normal([784, 10]))
biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# compute the accuracy
correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

代码解析

1. 读取MNIST数据

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

2. 建立占位符

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

xs 代表图片像素数据, 每张图片(28×28)被展开成(1×784), 有多少图片还未定, 所以shape为None×784.

ys 代表图片标签数据, 0-9十个数字被表示成One-hot形式, 即只有对应bit为1, 其余为0.

3. 建立模型

weights = tf.Variable(tf.random_normal([784, 10]))


biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

使用Softmax函数作为激活函数:

使用TensorFlow搭建一个全连接神经网络教程

4. 计算正确率

correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

5. 使用模型

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

运行结果

训练1000个循环, 准确率在87%左右.

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.1041
0.632
0.7357
0.7837
0.7971
0.8147
0.8283
0.8376
0.8423
0.8501
0.8501
0.8533
0.8567
0.8597
0.8552
0.8647
0.8654
0.8701
0.8712
0.8712

以上这篇使用TensorFlow搭建一个全连接神经网络教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python2.x中对Unicode编码的使用
Apr 03 Python
python实现数据导出到excel的示例--普通格式
May 03 Python
Python使用jsonpath-rw模块处理Json对象操作示例
Jul 31 Python
Python在图片中插入大量文字并且自动换行
Jan 02 Python
pybind11和numpy进行交互的方法
Jul 04 Python
Python 获取 datax 执行结果保存到数据库的方法
Jul 11 Python
Python imread、newaxis用法详解
Nov 04 Python
django3.02模板中的超链接配置实例代码
Feb 04 Python
pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作)
Jun 24 Python
python求解汉诺塔游戏
Jul 09 Python
python urllib和urllib3知识点总结
Feb 08 Python
教你用python实现一个无界面的小型图书管理系统
May 21 Python
详解python 降级到3.6终极解决方案
Feb 06 #Python
PyCharm如何导入python项目的方法
Feb 06 #Python
tensorflow 环境变量设置方式
Feb 06 #Python
快速查找Python安装路径方法
Feb 06 #Python
运行tensorflow python程序,限制对GPU和CPU的占用操作
Feb 06 #Python
如何在django中添加日志功能
Feb 06 #Python
keras tensorflow 实现在python下多进程运行
Feb 06 #Python
You might like
访问编码后的中文URL返回404错误的解决方法
2014/08/20 PHP
php实现求相对时间函数
2015/06/15 PHP
JS刷新框架外页面七种实现代码
2013/02/18 Javascript
javascript scrollTop正解使用方法
2013/11/14 Javascript
JavaScript实现N皇后问题算法谜题解答
2014/12/29 Javascript
jquery模拟alert的弹窗插件
2015/07/31 Javascript
Jquery 1.9.1源码分析系列(十二)之筛选操作
2015/12/02 Javascript
基于JavaScript实现瀑布流效果(循环渐近)
2016/01/27 Javascript
全面解析Bootstrap中nav、collapse的使用方法
2016/05/22 Javascript
JS/jquery实现一个网页内同时调用多个倒计时的方法
2017/04/27 jQuery
基于JavaScript实现数码时钟效果
2020/03/30 Javascript
ES6中let 和 const 的新特性
2018/09/03 Javascript
微信小程序使用setData修改数组中单个对象的方法分析
2018/12/30 Javascript
原生js实现密码强度验证功能
2020/03/18 Javascript
Openlayers3实现车辆轨迹回放功能
2020/09/29 Javascript
Python实现获取操作系统版本信息方法
2015/04/08 Python
Python实现的朴素贝叶斯分类器示例
2018/01/06 Python
python中不能连接超时的问题及解决方法
2018/06/10 Python
Python打包方法Pyinstaller的使用
2018/10/09 Python
django小技巧之html模板中调用对象属性或对象的方法
2018/11/30 Python
Python使用pymongo库操作MongoDB数据库的方法实例
2019/02/22 Python
Python实用库 PrettyTable 学习笔记
2019/08/06 Python
tensorflow实现对张量数据的切片操作方式
2020/01/19 Python
Python使用Chrome插件实现爬虫过程图解
2020/06/09 Python
Django:使用filter的pk进行多值查询操作
2020/07/15 Python
PyQt5多线程防卡死和多窗口用法的实现
2020/09/15 Python
DHC中国官方购物网站:日本通信销售No.1化妆品
2016/08/20 全球购物
新西兰领先的鞋类和靴子网上商城:Merchant 1948
2017/09/08 全球购物
数控技术应届生求职信
2013/11/13 职场文书
法学毕业生自荐信
2013/11/13 职场文书
幼儿园教师教学反思
2014/02/06 职场文书
高中军训感言800字
2014/03/05 职场文书
物流专业专科生职业生涯规划书
2014/09/14 职场文书
运动会广播稿100字
2015/08/19 职场文书
新手入门Mysql--概念
2021/06/18 MySQL
根德5570型九灯四波段立体声收音机是电子管收音机的楷模 ? 再论5570
2022/04/05 无线电