tensorflow使用神经网络实现mnist分类


Posted in Python onSeptember 08, 2018

本文实例为大家分享了tensorflow神经网络实现mnist分类的具体代码,供大家参考,具体内容如下

只有两层的神经网络,直接上代码

#引入包
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
#引入input_data文件
from tensorflow.examples.tutorials.mnist import input_data
#读取文件
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot=True)

#定义第一个隐藏层和第二个隐藏层,输入层输出层
n_hidden_1 = 256
n_hidden_2 = 128
n_input = 784
n_classes = 10

#由于不知道输入图片个数,所以用placeholder
x = tf.placeholder("float",[None,n_input])
y = tf.placeholder("float",[None,n_classes])

stddev = 0.1

#定义权重
weights = {
    'w1':tf.Variable(tf.random_normal([n_input,n_hidden_1],stddev = stddev)),
    'w2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2],stddev=stddev)),
    'out':tf.Variable(tf.random_normal([n_hidden_2,n_classes],stddev=stddev))    
    }

#定义偏置
biases = {
    'b1':tf.Variable(tf.random_normal([n_hidden_1])),
    'b2':tf.Variable(tf.random_normal([n_hidden_2])),
    'out':tf.Variable(tf.random_normal([n_classes])), 
    }
print("Network is Ready")


#前向传播
def multilayer_perceptrin(_X,_weights,_biases):
  layer1 = tf.nn.sigmoid(tf.add(tf.matmul(_X,_weights['w1']),_biases['b1']))
  layer2 = tf.nn.sigmoid(tf.add(tf.matmul(layer1,_weights['w2']),_biases['b2']))
  return (tf.matmul(layer2,_weights['out'])+_biases['out'])

#定义优化函数,精准度等
pred = multilayer_perceptrin(x,weights,biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred,labels=y))
optm = tf.train.GradientDescentOptimizer(learning_rate = 0.001).minimize(cost)
corr = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accr = tf.reduce_mean(tf.cast(corr,"float"))
print("Functions is ready")

#定义超参数
training_epochs = 80
batch_size = 200
display_step = 4

#会话开始
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

#优化
for epoch in range(training_epochs):
  avg_cost=0.
  total_batch = int(mnist.train.num_examples/batch_size)

  for i in range(total_batch):
    batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    feeds = {x:batch_xs,y:batch_ys}
    sess.run(optm,feed_dict = feeds)
    avg_cost += sess.run(cost,feed_dict=feeds)
  avg_cost = avg_cost/total_batch

  if (epoch+1) % display_step ==0:
    print("Epoch:%03d/%03d cost:%.9f"%(epoch,training_epochs,avg_cost))
    feeds = {x:batch_xs,y:batch_ys}
    train_acc = sess.run(accr,feed_dict = feeds)
    print("Train accuracy:%.3f"%(train_acc))
    feeds = {x:mnist.test.images,y:mnist.test.labels}
    test_acc = sess.run(accr,feed_dict = feeds)
    print("Test accuracy:%.3f"%(test_acc))
print("Optimization Finished")

程序部分运行结果如下:

Train accuracy:0.605
Test accuracy:0.633
Epoch:071/080 cost:1.810029302
Train accuracy:0.600
Test accuracy:0.645
Epoch:075/080 cost:1.761531130
Train accuracy:0.690
Test accuracy:0.649
Epoch:079/080 cost:1.711757494
Train accuracy:0.640
Test accuracy:0.660
Optimization Finished

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用__slots__方法的详细教程
Apr 28 Python
Python Socket使用实例
Dec 18 Python
一份python入门应该看的学习资料
Apr 11 Python
Python实现删除时保留特定文件夹和文件的示例
Apr 27 Python
使用PyCharm进行远程开发和调试的实现
Nov 04 Python
python-numpy-指数分布实例详解
Dec 07 Python
以SQLite和PySqlite为例来学习Python DB API
Feb 05 Python
pycharm设置python文件模板信息过程图解
Mar 10 Python
Python捕获异常堆栈信息的几种方法(小结)
May 18 Python
Python代码需要缩进吗
Jul 01 Python
python3.8动态人脸识别的实现示例
Sep 21 Python
Python用SSH连接到网络设备
Feb 18 Python
Python unittest单元测试框架总结
Sep 08 #Python
tensorflow实现加载mnist数据集
Sep 08 #Python
使用tensorflow实现线性回归
Sep 08 #Python
Python  unittest单元测试框架的使用
Sep 08 #Python
tensorflow实现逻辑回归模型
Sep 08 #Python
Django实现表单验证
Sep 08 #Python
python实现排序算法解析
Sep 08 #Python
You might like
基于PHP后台的Android新闻浏览客户端
2016/05/23 PHP
使用JavaScript创建新样式表和新样式规则
2016/06/14 PHP
PHP合并数组的2种方法小结
2016/11/24 PHP
PHP常见数组排序方法小结
2018/08/20 PHP
用window.location.href实现刷新另个框架页面
2007/03/07 Javascript
js 效率组装字符串 StringBuffer
2009/12/23 Javascript
jquery 经典动画菜单效果代码
2010/01/26 Javascript
TextArea不支持maxlength的解决办法(jquery)
2011/09/13 Javascript
解析瀑布流布局:JS+绝对定位的实现
2013/05/08 Javascript
详解AngularJS中的表格使用
2015/06/16 Javascript
Angular 根据 service 的状态更新 directive
2016/04/03 Javascript
js微信扫描二维码登录网站技术原理
2016/12/01 Javascript
微信小程序实现导航栏选项卡效果
2020/06/19 Javascript
JS/HTML5游戏常用算法之碰撞检测 地图格子算法实例详解
2018/12/12 Javascript
Vue唯一可以更改vuex实例中state数据状态的属性对象Mutation的讲解
2019/01/18 Javascript
Python修改MP3文件的方法
2015/06/15 Python
使用python批量化音乐文件格式转换的实例
2019/01/09 Python
pyside+pyqt实现鼠标右键菜单功能
2020/12/08 Python
安装多个版本的TensorFlow的方法步骤
2020/04/21 Python
Python如何使用27行代码绘制星星图
2020/07/20 Python
flask开启多线程的具体方法
2020/08/02 Python
Python自动创建Excel并获取内容
2020/09/16 Python
python 获取计算机的网卡信息
2021/02/18 Python
Kickers鞋英国官网:男士、女士和儿童鞋
2021/03/08 全球购物
企划经理的岗位职责
2013/11/17 职场文书
父亲追悼会答谢词
2014/01/17 职场文书
自荐信写法介绍
2014/01/25 职场文书
程序员求职信
2014/04/16 职场文书
会计系毕业生求职信
2014/05/28 职场文书
煤矿安全协议书
2014/08/20 职场文书
老公写给老婆的检讨书
2015/05/06 职场文书
《刷子李》教学反思
2016/02/20 职场文书
Redis数据结构之链表与字典的使用
2021/05/11 Redis
Golang标准库syscall详解(什么是系统调用)
2021/05/25 Golang
Pytorch 如何加速Dataloader提升数据读取速度
2021/05/28 Python
Java各种比较对象的方式的对比总结
2021/06/20 Java/Android