tensorflow使用神经网络实现mnist分类


Posted in Python onSeptember 08, 2018

本文实例为大家分享了tensorflow神经网络实现mnist分类的具体代码,供大家参考,具体内容如下

只有两层的神经网络,直接上代码

#引入包
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
#引入input_data文件
from tensorflow.examples.tutorials.mnist import input_data
#读取文件
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot=True)

#定义第一个隐藏层和第二个隐藏层,输入层输出层
n_hidden_1 = 256
n_hidden_2 = 128
n_input = 784
n_classes = 10

#由于不知道输入图片个数,所以用placeholder
x = tf.placeholder("float",[None,n_input])
y = tf.placeholder("float",[None,n_classes])

stddev = 0.1

#定义权重
weights = {
    'w1':tf.Variable(tf.random_normal([n_input,n_hidden_1],stddev = stddev)),
    'w2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2],stddev=stddev)),
    'out':tf.Variable(tf.random_normal([n_hidden_2,n_classes],stddev=stddev))    
    }

#定义偏置
biases = {
    'b1':tf.Variable(tf.random_normal([n_hidden_1])),
    'b2':tf.Variable(tf.random_normal([n_hidden_2])),
    'out':tf.Variable(tf.random_normal([n_classes])), 
    }
print("Network is Ready")


#前向传播
def multilayer_perceptrin(_X,_weights,_biases):
  layer1 = tf.nn.sigmoid(tf.add(tf.matmul(_X,_weights['w1']),_biases['b1']))
  layer2 = tf.nn.sigmoid(tf.add(tf.matmul(layer1,_weights['w2']),_biases['b2']))
  return (tf.matmul(layer2,_weights['out'])+_biases['out'])

#定义优化函数,精准度等
pred = multilayer_perceptrin(x,weights,biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred,labels=y))
optm = tf.train.GradientDescentOptimizer(learning_rate = 0.001).minimize(cost)
corr = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accr = tf.reduce_mean(tf.cast(corr,"float"))
print("Functions is ready")

#定义超参数
training_epochs = 80
batch_size = 200
display_step = 4

#会话开始
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

#优化
for epoch in range(training_epochs):
  avg_cost=0.
  total_batch = int(mnist.train.num_examples/batch_size)

  for i in range(total_batch):
    batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    feeds = {x:batch_xs,y:batch_ys}
    sess.run(optm,feed_dict = feeds)
    avg_cost += sess.run(cost,feed_dict=feeds)
  avg_cost = avg_cost/total_batch

  if (epoch+1) % display_step ==0:
    print("Epoch:%03d/%03d cost:%.9f"%(epoch,training_epochs,avg_cost))
    feeds = {x:batch_xs,y:batch_ys}
    train_acc = sess.run(accr,feed_dict = feeds)
    print("Train accuracy:%.3f"%(train_acc))
    feeds = {x:mnist.test.images,y:mnist.test.labels}
    test_acc = sess.run(accr,feed_dict = feeds)
    print("Test accuracy:%.3f"%(test_acc))
print("Optimization Finished")

程序部分运行结果如下:

Train accuracy:0.605
Test accuracy:0.633
Epoch:071/080 cost:1.810029302
Train accuracy:0.600
Test accuracy:0.645
Epoch:075/080 cost:1.761531130
Train accuracy:0.690
Test accuracy:0.649
Epoch:079/080 cost:1.711757494
Train accuracy:0.640
Test accuracy:0.660
Optimization Finished

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实例分享:快速查找出被挂马的文件
Jun 08 Python
python实现从一组颜色中找出与给定颜色最接近颜色的方法
Mar 19 Python
详解Python中的多线程编程
Apr 09 Python
详解python3百度指数抓取实例
Dec 12 Python
Python2与Python3的区别实例分析
Apr 11 Python
python3利用Socket实现通信的方法示例
May 06 Python
python3的print()函数的用法图文讲解
Jul 16 Python
python word转pdf代码实例
Aug 16 Python
django ListView的使用 ListView中获取url中的参数值方式
Mar 27 Python
Python __slots__的使用方法
Nov 15 Python
python中time包实例详解
Feb 02 Python
Python 可迭代对象 iterable的具体使用
Aug 07 Python
Python unittest单元测试框架总结
Sep 08 #Python
tensorflow实现加载mnist数据集
Sep 08 #Python
使用tensorflow实现线性回归
Sep 08 #Python
Python  unittest单元测试框架的使用
Sep 08 #Python
tensorflow实现逻辑回归模型
Sep 08 #Python
Django实现表单验证
Sep 08 #Python
python实现排序算法解析
Sep 08 #Python
You might like
php断点续传之如何分割合并文件
2014/03/22 PHP
ThinkPHP关于session的操作方法汇总
2014/07/18 PHP
PHP实现过滤各种HTML标签
2015/05/17 PHP
PHP调用API接口实现天气查询功能的示例
2017/09/21 PHP
详细解读php的命名空间(二)
2018/02/21 PHP
javascript import css实例代码
2008/07/18 Javascript
javascript之typeof、instanceof操作符使用探讨
2013/05/19 Javascript
jQuery获得内容和属性示例代码
2014/01/16 Javascript
让alert不出现弹窗的两种方法
2014/05/18 Javascript
让checkbox不选中即将选中的checkbox不选中
2014/07/11 Javascript
angularjs自定义ng-model标签的属性
2016/01/21 Javascript
如何让一个json文件显示在表格里【实现代码】
2016/05/09 Javascript
jQuery实现限制文本框的输入长度
2017/01/11 Javascript
canvas 实现中国象棋
2017/02/17 Javascript
详解vue-cli + webpack 多页面实例应用
2017/04/25 Javascript
微信小程序 自定义消息提示框
2017/08/06 Javascript
Vue 拦截器对token过期处理方法
2018/01/23 Javascript
JavaScript基础心法 深浅拷贝(浅拷贝和深拷贝)
2018/03/05 Javascript
微信小程序全局变量功能与用法详解
2019/01/22 Javascript
python 设置输出图像的像素大小方法
2019/07/04 Python
如何验证python安装成功
2020/07/06 Python
css3动画事件—webkitAnimationEnd与计时器time事件
2013/01/31 HTML / CSS
物业管理求职自荐信
2013/09/25 职场文书
自荐信范文
2013/12/10 职场文书
优秀员工年终发言演讲稿
2014/01/01 职场文书
禁毒宣传工作方案
2014/05/23 职场文书
信用卡结清证明怎么写
2014/09/13 职场文书
预备党员群众路线教育实践活动思想汇报2014
2014/10/25 职场文书
优秀班主任材料
2014/12/16 职场文书
先进员工事迹材料
2014/12/20 职场文书
2015年设计师个人工作总结
2015/04/25 职场文书
英文投诉信格式
2015/07/03 职场文书
2015年校医个人工作总结
2015/07/24 职场文书
Java中常用解析工具jackson及fastjson的使用
2021/06/28 Java/Android
Python图像处理库PIL详细使用说明
2022/04/06 Python
nginx静态资源的服务器配置方法
2022/07/07 Servers