使用TensorFlow搭建一个全连接神经网络教程


Posted in Python onFebruary 06, 2020

说明

本例子利用TensorFlow搭建一个全连接神经网络,实现对MNIST手写数字的识别。

先上代码

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# prepare data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# the model of the fully-connected network
weights = tf.Variable(tf.random_normal([784, 10]))
biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# compute the accuracy
correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

代码解析

1. 读取MNIST数据

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

2. 建立占位符

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

xs 代表图片像素数据, 每张图片(28×28)被展开成(1×784), 有多少图片还未定, 所以shape为None×784.

ys 代表图片标签数据, 0-9十个数字被表示成One-hot形式, 即只有对应bit为1, 其余为0.

3. 建立模型

weights = tf.Variable(tf.random_normal([784, 10]))


biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

使用Softmax函数作为激活函数:

使用TensorFlow搭建一个全连接神经网络教程

4. 计算正确率

correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

5. 使用模型

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

运行结果

训练1000个循环, 准确率在87%左右.

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.1041
0.632
0.7357
0.7837
0.7971
0.8147
0.8283
0.8376
0.8423
0.8501
0.8501
0.8533
0.8567
0.8597
0.8552
0.8647
0.8654
0.8701
0.8712
0.8712

以上这篇使用TensorFlow搭建一个全连接神经网络教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python解析中国天气网的天气数据
Mar 21 Python
Python中基础的socket编程实战攻略
Jun 01 Python
使用Python的Flask框架构建大型Web应用程序的结构示例
Jun 04 Python
centos 安装python3.6环境并配置虚拟环境的详细教程
Feb 22 Python
详解PyTorch批训练及优化器比较
Apr 28 Python
python版本五子棋的实现代码
Dec 11 Python
Python和Java的语法对比分析语法简洁上python的确完美胜出
May 10 Python
win10系统下python3安装及pip换源和使用教程
Jan 06 Python
在django admin详情表单显示中添加自定义控件的实现
Mar 11 Python
python数据库编程 ODBC方式实现通讯录
Mar 27 Python
用python制作个视频下载器
Feb 01 Python
Python机器学习算法之决策树算法的实现与优缺点
May 13 Python
详解python 降级到3.6终极解决方案
Feb 06 #Python
PyCharm如何导入python项目的方法
Feb 06 #Python
tensorflow 环境变量设置方式
Feb 06 #Python
快速查找Python安装路径方法
Feb 06 #Python
运行tensorflow python程序,限制对GPU和CPU的占用操作
Feb 06 #Python
如何在django中添加日志功能
Feb 06 #Python
keras tensorflow 实现在python下多进程运行
Feb 06 #Python
You might like
PHP比你想象的好得多
2014/11/27 PHP
php文件上传简单实现方法
2015/01/24 PHP
php获取目录下所有文件及目录(多种方法)(推荐)
2019/05/14 PHP
js checkbox(复选框) 使用集锦
2009/04/28 Javascript
一组JS创建和操作表格的函数集合
2009/05/07 Javascript
九种js弹出对话框的方法总结
2013/03/12 Javascript
innerText 使用示例
2014/01/23 Javascript
jquery Ajax 实现加载数据前动画效果的示例代码
2014/02/07 Javascript
JS合并数组的几种方法及优劣比较
2014/09/19 Javascript
jquery实现的点击翻书效果代码
2015/11/04 Javascript
JS阻止事件冒泡行为和闭包的方法
2016/06/16 Javascript
一个仿微博登陆邮箱提示框js开发案例
2016/07/28 Javascript
浅谈js继承的实现及公有、私有、静态方法的书写
2016/10/28 Javascript
javaScript中定义类或对象的五种方式总结
2016/12/04 Javascript
详解node如何让一个端口同时支持https与http
2017/07/04 Javascript
详解利用 Vue.js 实现前后端分离的RBAC角色权限管理
2017/09/15 Javascript
python结合opencv实现人脸检测与跟踪
2015/06/08 Python
python做接口测试的必要性
2019/11/20 Python
通过python连接Linux命令行代码实例
2020/02/18 Python
解决Pycharm双击图标启动不了的问题(JetBrains全家桶通用)
2020/08/07 Python
Numpy实现卷积神经网络(CNN)的示例
2020/10/09 Python
Django如何实现防止XSS攻击
2020/10/13 Python
英国在线照明超市:Castlegate Lights
2019/10/30 全球购物
指针和引用有什么区别
2013/01/13 面试题
学期自我鉴定范文
2013/10/01 职场文书
信息技术课后反思
2014/04/27 职场文书
知识竞赛拉拉队口号
2014/06/16 职场文书
兽医医药专业求职信
2014/07/27 职场文书
七年级上册语文教学计划
2015/01/22 职场文书
文体活动总结
2015/02/04 职场文书
个人总结格式范文
2015/03/09 职场文书
2016年五一促销广告语
2016/01/28 职场文书
MySQL 发生同步延迟时Seconds_Behind_Master还为0的原因
2021/06/21 MySQL
Mysql binlog日志文件过大的解决
2021/10/05 MySQL
Mysql 8.x 创建用户以及授予权限的操作记录
2022/04/18 MySQL
阿里云国际版 使用Nginx作为HTTPS转发代理服务器
2022/05/11 Servers