使用TensorFlow搭建一个全连接神经网络教程


Posted in Python onFebruary 06, 2020

说明

本例子利用TensorFlow搭建一个全连接神经网络,实现对MNIST手写数字的识别。

先上代码

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# prepare data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# the model of the fully-connected network
weights = tf.Variable(tf.random_normal([784, 10]))
biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# compute the accuracy
correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

代码解析

1. 读取MNIST数据

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

2. 建立占位符

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

xs 代表图片像素数据, 每张图片(28×28)被展开成(1×784), 有多少图片还未定, 所以shape为None×784.

ys 代表图片标签数据, 0-9十个数字被表示成One-hot形式, 即只有对应bit为1, 其余为0.

3. 建立模型

weights = tf.Variable(tf.random_normal([784, 10]))


biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

使用Softmax函数作为激活函数:

使用TensorFlow搭建一个全连接神经网络教程

4. 计算正确率

correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

5. 使用模型

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

运行结果

训练1000个循环, 准确率在87%左右.

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.1041
0.632
0.7357
0.7837
0.7971
0.8147
0.8283
0.8376
0.8423
0.8501
0.8501
0.8533
0.8567
0.8597
0.8552
0.8647
0.8654
0.8701
0.8712
0.8712

以上这篇使用TensorFlow搭建一个全连接神经网络教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于python实现学生管理系统
Oct 17 Python
使用Python向C语言的链接库传递数组、结构体、指针类型的数据
Jan 29 Python
详解python中的time和datetime的常用方法
Jul 08 Python
50行Python代码获取高考志愿信息的实现方法
Jul 23 Python
Python使用微信itchat接口实现查看自己微信的信息功能详解
Aug 22 Python
pyftplib中文乱码问题解决方案
Jan 11 Python
PyCharm GUI界面开发和exe文件生成的实现
Mar 04 Python
Ubuntu18.04安装 PyCharm并使用 Anaconda 管理的Python环境
Apr 08 Python
python 如何将office文件转换为PDF
Sep 22 Python
python3排序的实例方法
Oct 20 Python
Django2.1.7 查询数据返回json格式的实现
Dec 29 Python
Python爬取网站图片并保存的实现示例
Feb 26 Python
详解python 降级到3.6终极解决方案
Feb 06 #Python
PyCharm如何导入python项目的方法
Feb 06 #Python
tensorflow 环境变量设置方式
Feb 06 #Python
快速查找Python安装路径方法
Feb 06 #Python
运行tensorflow python程序,限制对GPU和CPU的占用操作
Feb 06 #Python
如何在django中添加日志功能
Feb 06 #Python
keras tensorflow 实现在python下多进程运行
Feb 06 #Python
You might like
adodb与adodb_lite之比较
2006/12/31 PHP
php冒泡排序、快速排序、快速查找、二维数组去重实例分享
2014/04/24 PHP
php基于curl扩展制作跨平台的restfule 接口
2015/05/11 PHP
laravel实现按时间日期进行分组统计方法示例
2019/03/23 PHP
13个绚丽的Jquery 界面设计网站推荐
2010/09/28 Javascript
浅谈 jQuery 事件源码定位问题
2014/06/18 Javascript
javascript伸缩型菜单实现代码
2015/11/16 Javascript
jQuery实现的简单拖拽功能示例
2016/09/13 Javascript
JavaScript继承与多继承实例分析
2018/05/26 Javascript
使用javascript函数编写简单银行取钱存钱流程
2018/05/26 Javascript
深入理解JavaScript 中的匿名函数((function() {})();)与变量的作用域
2018/08/28 Javascript
Vue通过ref父子组件拿值方法
2018/09/12 Javascript
Vue 页面状态保持页面间数据传输的一种方法(推荐)
2018/11/01 Javascript
原生JS实现逼真的图片3D旋转效果详解
2019/02/16 Javascript
如何通过JS实现转码与解码
2020/02/21 Javascript
vue项目中使用vue-layer弹框插件的方法
2020/03/11 Javascript
探究一道价值25k的蚂蚁金服异步串行面试题
2020/08/21 Javascript
vue中实现点击空白区域关闭弹窗的两种方法
2020/12/30 Vue.js
[01:47]2018年度DOTA2最具人气解说-完美盛典
2018/12/16 DOTA
详解python3中的真值测试
2018/08/13 Python
Python3 SSH远程连接服务器的方法示例
2018/12/29 Python
Python高级编程之继承问题详解(super与mro)
2019/11/19 Python
tensorflow指定GPU与动态分配GPU memory设置
2020/02/03 Python
使用ITK-SNAP进行抠图操作并保存mask的实例
2020/07/01 Python
基于selenium及python实现下拉选项定位select
2020/07/22 Python
css3 中实现炫酷的loading效果
2019/04/26 HTML / CSS
CSS3——齿轮转动关键代码
2013/05/02 HTML / CSS
一款利用纯css3实现的win8加载动画的实例分析
2014/12/11 HTML / CSS
德国在线订购鲜花:Fleurop
2018/08/25 全球购物
信息服务专业毕业生求职信
2014/03/02 职场文书
外语系毕业生求职自荐信
2014/04/12 职场文书
2014年大学班长工作总结
2014/11/14 职场文书
幼儿教师年度个人总结
2015/02/05 职场文书
2015年母亲节寄语
2015/03/23 职场文书
初三毕业感言
2015/07/31 职场文书
Java基础-封装和继承
2021/07/02 Java/Android