tensorflow使用神经网络实现mnist分类


Posted in Python onSeptember 08, 2018

本文实例为大家分享了tensorflow神经网络实现mnist分类的具体代码,供大家参考,具体内容如下

只有两层的神经网络,直接上代码

#引入包
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
#引入input_data文件
from tensorflow.examples.tutorials.mnist import input_data
#读取文件
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot=True)

#定义第一个隐藏层和第二个隐藏层,输入层输出层
n_hidden_1 = 256
n_hidden_2 = 128
n_input = 784
n_classes = 10

#由于不知道输入图片个数,所以用placeholder
x = tf.placeholder("float",[None,n_input])
y = tf.placeholder("float",[None,n_classes])

stddev = 0.1

#定义权重
weights = {
    'w1':tf.Variable(tf.random_normal([n_input,n_hidden_1],stddev = stddev)),
    'w2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2],stddev=stddev)),
    'out':tf.Variable(tf.random_normal([n_hidden_2,n_classes],stddev=stddev))    
    }

#定义偏置
biases = {
    'b1':tf.Variable(tf.random_normal([n_hidden_1])),
    'b2':tf.Variable(tf.random_normal([n_hidden_2])),
    'out':tf.Variable(tf.random_normal([n_classes])), 
    }
print("Network is Ready")


#前向传播
def multilayer_perceptrin(_X,_weights,_biases):
  layer1 = tf.nn.sigmoid(tf.add(tf.matmul(_X,_weights['w1']),_biases['b1']))
  layer2 = tf.nn.sigmoid(tf.add(tf.matmul(layer1,_weights['w2']),_biases['b2']))
  return (tf.matmul(layer2,_weights['out'])+_biases['out'])

#定义优化函数,精准度等
pred = multilayer_perceptrin(x,weights,biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred,labels=y))
optm = tf.train.GradientDescentOptimizer(learning_rate = 0.001).minimize(cost)
corr = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accr = tf.reduce_mean(tf.cast(corr,"float"))
print("Functions is ready")

#定义超参数
training_epochs = 80
batch_size = 200
display_step = 4

#会话开始
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

#优化
for epoch in range(training_epochs):
  avg_cost=0.
  total_batch = int(mnist.train.num_examples/batch_size)

  for i in range(total_batch):
    batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    feeds = {x:batch_xs,y:batch_ys}
    sess.run(optm,feed_dict = feeds)
    avg_cost += sess.run(cost,feed_dict=feeds)
  avg_cost = avg_cost/total_batch

  if (epoch+1) % display_step ==0:
    print("Epoch:%03d/%03d cost:%.9f"%(epoch,training_epochs,avg_cost))
    feeds = {x:batch_xs,y:batch_ys}
    train_acc = sess.run(accr,feed_dict = feeds)
    print("Train accuracy:%.3f"%(train_acc))
    feeds = {x:mnist.test.images,y:mnist.test.labels}
    test_acc = sess.run(accr,feed_dict = feeds)
    print("Test accuracy:%.3f"%(test_acc))
print("Optimization Finished")

程序部分运行结果如下:

Train accuracy:0.605
Test accuracy:0.633
Epoch:071/080 cost:1.810029302
Train accuracy:0.600
Test accuracy:0.645
Epoch:075/080 cost:1.761531130
Train accuracy:0.690
Test accuracy:0.649
Epoch:079/080 cost:1.711757494
Train accuracy:0.640
Test accuracy:0.660
Optimization Finished

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
巧用Python装饰器 免去调用父类构造函数的麻烦
May 18 Python
使用grappelli为django admin后台添加模板
Nov 18 Python
Python操作MongoDB数据库PyMongo库使用方法
Apr 27 Python
Windows下python3.7安装教程
Jul 31 Python
python绘制简单彩虹图
Nov 19 Python
python 获取页面表格数据存放到csv中的方法
Dec 26 Python
python 采用paramiko 远程执行命令及报错解决
Oct 21 Python
wxpython多线程防假死与线程间传递消息实例详解
Dec 13 Python
在pycharm中为项目导入anacodna环境的操作方法
Feb 12 Python
QML用PathView实现轮播图
Jun 03 Python
Python实现哲学家就餐问题实例代码
Nov 09 Python
教你怎么用PyCharm为同一服务器配置多个python解释器
May 31 Python
Python unittest单元测试框架总结
Sep 08 #Python
tensorflow实现加载mnist数据集
Sep 08 #Python
使用tensorflow实现线性回归
Sep 08 #Python
Python  unittest单元测试框架的使用
Sep 08 #Python
tensorflow实现逻辑回归模型
Sep 08 #Python
Django实现表单验证
Sep 08 #Python
python实现排序算法解析
Sep 08 #Python
You might like
解析PHP留言本模块主要功能的函数说明(代码可实现)
2013/06/25 PHP
PHP实现将textarea的值根据回车换行拆分至数组
2015/06/10 PHP
Zend Framework教程之连接数据库并执行增删查的方法(附demo源码下载)
2016/03/21 PHP
php7函数,声明,返回值等新特性介绍
2018/05/25 PHP
PHP中实现中文字串截取无乱码的解决方法
2018/05/29 PHP
简单实用的js调试logger组件实现代码
2010/11/20 Javascript
JavaScript 垃圾回收机制分析
2013/10/10 Javascript
生成二维码方法汇总
2014/12/26 Javascript
JavaScript学习笔记之Cookie对象
2015/01/22 Javascript
jQuery版本升级踩坑大全
2016/01/12 Javascript
仅30行代码实现Javascript中的MVC
2016/02/15 Javascript
让你一句话理解闭包(简单易懂)
2016/06/03 Javascript
解析浏览器端的AJAX缓存机制
2016/06/21 Javascript
深入分析javascript中console命令
2016/08/14 Javascript
JavaScript中localStorage对象存储方式实例分析
2017/01/12 Javascript
使用vue.js2.0 + ElementUI开发后台管理系统详细教程(一)
2017/01/21 Javascript
详解nodejs微信公众号开发——4.自动回复各种消息
2017/04/11 NodeJs
Node.js如何实现注册邮箱激活功能 (常见)
2017/07/23 Javascript
JS手机端touch事件计算滑动距离的方法示例
2017/10/26 Javascript
vue异步加载高德地图的实现
2018/06/19 Javascript
Echart折线图手柄触发事件示例详解
2018/12/16 Javascript
原来JS还可以这样拆箱转换详解
2019/02/01 Javascript
详解vue的双向绑定原理及实现
2019/05/05 Javascript
vue配置接口域名方法总结
2019/05/12 Javascript
基于vue.js实现购物车
2020/01/15 Javascript
koa-passport实现本地验证的方法示例
2020/02/20 Javascript
python通过Windows下远程控制Linux系统
2018/06/20 Python
python基于C/S模式实现聊天室功能
2019/01/09 Python
简单了解python中的与或非运算
2019/09/18 Python
Django 实现将图片转为Base64,然后使用json传输
2020/03/27 Python
拥有超过850家商店的美国在线派对商店:Party City
2018/10/21 全球购物
SmartBuyGlasses德国:购买太阳镜和眼镜
2019/08/20 全球购物
物流管理专业求职信
2014/05/29 职场文书
签约仪式策划方案
2014/06/02 职场文书
高二学年自我鉴定范文(2篇)
2014/09/26 职场文书
离职告别感言
2015/08/04 职场文书