使用TensorFlow搭建一个全连接神经网络教程


Posted in Python onFebruary 06, 2020

说明

本例子利用TensorFlow搭建一个全连接神经网络,实现对MNIST手写数字的识别。

先上代码

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# prepare data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# the model of the fully-connected network
weights = tf.Variable(tf.random_normal([784, 10]))
biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# compute the accuracy
correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

代码解析

1. 读取MNIST数据

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

2. 建立占位符

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

xs 代表图片像素数据, 每张图片(28×28)被展开成(1×784), 有多少图片还未定, 所以shape为None×784.

ys 代表图片标签数据, 0-9十个数字被表示成One-hot形式, 即只有对应bit为1, 其余为0.

3. 建立模型

weights = tf.Variable(tf.random_normal([784, 10]))


biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

使用Softmax函数作为激活函数:

使用TensorFlow搭建一个全连接神经网络教程

4. 计算正确率

correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

5. 使用模型

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

运行结果

训练1000个循环, 准确率在87%左右.

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.1041
0.632
0.7357
0.7837
0.7971
0.8147
0.8283
0.8376
0.8423
0.8501
0.8501
0.8533
0.8567
0.8597
0.8552
0.8647
0.8654
0.8701
0.8712
0.8712

以上这篇使用TensorFlow搭建一个全连接神经网络教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现倒计时的示例
Feb 14 Python
python使用BeautifulSoup分析网页信息的方法
Apr 04 Python
致Python初学者 Anaconda入门使用指南完整版
Apr 05 Python
python解决pandas处理缺失值为空字符串的问题
Apr 08 Python
Win10下Python3.7.3安装教程图解
Jul 08 Python
python实现输入任意一个大写字母生成金字塔的示例
Oct 27 Python
wxPython窗体拆分布局基础组件
Nov 19 Python
使用Tensorflow将自己的数据分割成batch训练实例
Jan 20 Python
Python内存映射文件读写方式
Apr 24 Python
使用Keras预训练好的模型进行目标类别预测详解
Jun 27 Python
Python学习笔记之装饰器
Aug 06 Python
Python如何创建装饰器时保留函数元信息
Aug 07 Python
详解python 降级到3.6终极解决方案
Feb 06 #Python
PyCharm如何导入python项目的方法
Feb 06 #Python
tensorflow 环境变量设置方式
Feb 06 #Python
快速查找Python安装路径方法
Feb 06 #Python
运行tensorflow python程序,限制对GPU和CPU的占用操作
Feb 06 #Python
如何在django中添加日志功能
Feb 06 #Python
keras tensorflow 实现在python下多进程运行
Feb 06 #Python
You might like
Windows下的PHP5.0详解
2006/11/18 PHP
PHP捕获Fatal error错误的方法
2014/06/11 PHP
PHP预定义变量9大超全局数组用法详解
2016/04/23 PHP
关于laravel 日志写入失败问题汇总
2019/10/17 PHP
JavaScript 入门基础知识 想学习js的朋友可以参考下
2009/12/26 Javascript
javascript中的绑定与解绑函数应用示例
2013/06/24 Javascript
javascript利用apply和arguments复用方法
2013/11/25 Javascript
jquery图片切换实例分析
2015/04/15 Javascript
jQuery实现动画效果circle实例
2015/08/06 Javascript
JavaScript实现点击按钮字体放大、缩小
2016/02/29 Javascript
jQuery中的一些常见方法小结(推荐)
2016/06/13 Javascript
layer实现关闭弹出层刷新父界面功能详解
2017/11/15 Javascript
Vuex提升学习篇
2018/01/11 Javascript
js数组常用最重要的方法
2018/02/04 Javascript
JS通过位运算实现权限加解密
2018/08/14 Javascript
Vue项目引发的「过滤器」使用教程
2019/03/12 Javascript
利用soaplib搭建webservice详细步骤和实例代码
2013/11/20 Python
Linux下将Python的Django项目部署到Apache服务器
2015/12/24 Python
使用PM2+nginx部署python项目的方法示例
2018/11/07 Python
python-itchat 获取微信群用户信息的实例
2019/02/21 Python
Django文件存储 自己定制存储系统解析
2019/08/02 Python
Python装饰器使用你可能不知道的几种姿势
2019/10/25 Python
使用python接受tgam的脑波数据实例
2020/04/09 Python
Python基于xlrd模块处理合并单元格
2020/07/28 Python
使用canvas绘制超炫时钟
2014/12/17 HTML / CSS
乌克兰珠宝大卖场:Zlato.ua
2020/09/27 全球购物
一份软件工程师的面试试题
2016/02/01 面试题
医学生自我鉴定范文
2013/11/08 职场文书
法律专业推荐信范文
2013/11/29 职场文书
四风问题自查报告剖析材料
2014/02/08 职场文书
电焊工岗位职责
2014/03/06 职场文书
会计求职自荐信范文
2015/03/04 职场文书
2015年圣诞节活动总结
2015/03/24 职场文书
项目备案申请报告
2015/05/15 职场文书
导游词之北京明十三陵
2019/10/28 职场文书
pytorch损失反向传播后梯度为none的问题
2021/05/12 Python