使用TensorFlow搭建一个全连接神经网络教程


Posted in Python onFebruary 06, 2020

说明

本例子利用TensorFlow搭建一个全连接神经网络,实现对MNIST手写数字的识别。

先上代码

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# prepare data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# the model of the fully-connected network
weights = tf.Variable(tf.random_normal([784, 10]))
biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# compute the accuracy
correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

代码解析

1. 读取MNIST数据

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

2. 建立占位符

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

xs 代表图片像素数据, 每张图片(28×28)被展开成(1×784), 有多少图片还未定, 所以shape为None×784.

ys 代表图片标签数据, 0-9十个数字被表示成One-hot形式, 即只有对应bit为1, 其余为0.

3. 建立模型

weights = tf.Variable(tf.random_normal([784, 10]))


biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

使用Softmax函数作为激活函数:

使用TensorFlow搭建一个全连接神经网络教程

4. 计算正确率

correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

5. 使用模型

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

运行结果

训练1000个循环, 准确率在87%左右.

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.1041
0.632
0.7357
0.7837
0.7971
0.8147
0.8283
0.8376
0.8423
0.8501
0.8501
0.8533
0.8567
0.8597
0.8552
0.8647
0.8654
0.8701
0.8712
0.8712

以上这篇使用TensorFlow搭建一个全连接神经网络教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
有趣的python小程序分享
Dec 05 Python
Python机器学习logistic回归代码解析
Jan 17 Python
python中format()函数的简单使用教程
Mar 14 Python
对python中for、if、while的区别与比较方法
Jun 25 Python
Python爬虫爬取新浪微博内容示例【基于代理IP】
Aug 03 Python
python单例模式获取IP代理的方法详解
Sep 13 Python
Python下简易的单例模式详解
Apr 08 Python
Python列表的切片实例讲解
Aug 20 Python
python 实现矩阵按对角线打印
Nov 29 Python
PyQt5+Pycharm安装和配置图文教程详解
Mar 24 Python
django 实现手动存储文件到model的FileField
Mar 30 Python
python之PySide2安装使用及QT Designer UI设计案例教程
Jul 26 Python
详解python 降级到3.6终极解决方案
Feb 06 #Python
PyCharm如何导入python项目的方法
Feb 06 #Python
tensorflow 环境变量设置方式
Feb 06 #Python
快速查找Python安装路径方法
Feb 06 #Python
运行tensorflow python程序,限制对GPU和CPU的占用操作
Feb 06 #Python
如何在django中添加日志功能
Feb 06 #Python
keras tensorflow 实现在python下多进程运行
Feb 06 #Python
You might like
php.ini中文版
2006/10/09 PHP
实用函数7
2007/11/08 PHP
PHP中exec与system用法区别分析
2014/09/22 PHP
php使用for语句输出三角形的方法
2015/06/09 PHP
Zend Framework入门教程之Zend_Config组件用法详解
2016/12/09 PHP
利用php的ob缓存机制实现页面静态化方法
2017/07/09 PHP
Ext.MessageBox工具类简介
2009/12/10 Javascript
鼠标移到div,浮层显示明细,弹出层与div的上边距左边距重合(示例代码)
2013/12/14 Javascript
Javascript 实现图片无缝滚动
2014/12/19 Javascript
Jquery中find与each方法用法实例
2015/02/04 Javascript
JavaScript实现数组在指定位置插入若干元素的方法
2015/04/06 Javascript
Nodejs多站点切换Htpps协议详解及简单实例
2017/02/23 NodeJs
js实现把时间戳转换为yyyy-MM-dd hh:mm 格式(es6语法)
2017/12/28 Javascript
基于js文件加载优化(详解)
2018/01/03 Javascript
从parcel.js打包出错到选择nvm的全部过程
2018/01/23 Javascript
JS字符串与二进制的相互转化实例代码详解
2019/06/28 Javascript
微信小程序开发之获取用户手机号码(php接口解密)
2020/05/17 Javascript
[06:50]DSPL次级职业联赛十强晋级之路
2014/11/18 DOTA
Python爬取商家联系电话以及各种数据的方法
2018/11/10 Python
Html5实现首页动态视频背景的示例代码
2019/09/25 HTML / CSS
Sneaker Studio匈牙利:购买运动鞋
2018/03/26 全球购物
在C++ 程序中调用被C 编译器编译后的函数,为什么要加extern "C"
2014/08/09 面试题
高中的职业生涯规划书
2013/12/28 职场文书
给同学的道歉信
2014/01/16 职场文书
音乐器材管理制度
2014/01/31 职场文书
GMP办公室主任岗位职责
2014/03/14 职场文书
特教教师先进事迹
2014/05/21 职场文书
地理信息科学专业推荐信
2014/09/08 职场文书
超市开店计划书
2014/09/15 职场文书
银行反四风对照检查材料
2014/09/29 职场文书
幼儿园园长工作总结2015
2015/05/25 职场文书
公司借款担保书
2015/09/22 职场文书
如何拟写通知正文?
2019/04/02 职场文书
《孙子兵法》:欲成大事者,需读懂这些致胜策略
2019/08/23 职场文书
2019年消防宣传标语集锦
2019/11/21 职场文书
Nginx 反向代理解决跨域问题多种情况分析
2022/01/18 Servers