tensorflow使用神经网络实现mnist分类


Posted in Python onSeptember 08, 2018

本文实例为大家分享了tensorflow神经网络实现mnist分类的具体代码,供大家参考,具体内容如下

只有两层的神经网络,直接上代码

#引入包
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
#引入input_data文件
from tensorflow.examples.tutorials.mnist import input_data
#读取文件
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot=True)

#定义第一个隐藏层和第二个隐藏层,输入层输出层
n_hidden_1 = 256
n_hidden_2 = 128
n_input = 784
n_classes = 10

#由于不知道输入图片个数,所以用placeholder
x = tf.placeholder("float",[None,n_input])
y = tf.placeholder("float",[None,n_classes])

stddev = 0.1

#定义权重
weights = {
    'w1':tf.Variable(tf.random_normal([n_input,n_hidden_1],stddev = stddev)),
    'w2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2],stddev=stddev)),
    'out':tf.Variable(tf.random_normal([n_hidden_2,n_classes],stddev=stddev))    
    }

#定义偏置
biases = {
    'b1':tf.Variable(tf.random_normal([n_hidden_1])),
    'b2':tf.Variable(tf.random_normal([n_hidden_2])),
    'out':tf.Variable(tf.random_normal([n_classes])), 
    }
print("Network is Ready")


#前向传播
def multilayer_perceptrin(_X,_weights,_biases):
  layer1 = tf.nn.sigmoid(tf.add(tf.matmul(_X,_weights['w1']),_biases['b1']))
  layer2 = tf.nn.sigmoid(tf.add(tf.matmul(layer1,_weights['w2']),_biases['b2']))
  return (tf.matmul(layer2,_weights['out'])+_biases['out'])

#定义优化函数,精准度等
pred = multilayer_perceptrin(x,weights,biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred,labels=y))
optm = tf.train.GradientDescentOptimizer(learning_rate = 0.001).minimize(cost)
corr = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accr = tf.reduce_mean(tf.cast(corr,"float"))
print("Functions is ready")

#定义超参数
training_epochs = 80
batch_size = 200
display_step = 4

#会话开始
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

#优化
for epoch in range(training_epochs):
  avg_cost=0.
  total_batch = int(mnist.train.num_examples/batch_size)

  for i in range(total_batch):
    batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    feeds = {x:batch_xs,y:batch_ys}
    sess.run(optm,feed_dict = feeds)
    avg_cost += sess.run(cost,feed_dict=feeds)
  avg_cost = avg_cost/total_batch

  if (epoch+1) % display_step ==0:
    print("Epoch:%03d/%03d cost:%.9f"%(epoch,training_epochs,avg_cost))
    feeds = {x:batch_xs,y:batch_ys}
    train_acc = sess.run(accr,feed_dict = feeds)
    print("Train accuracy:%.3f"%(train_acc))
    feeds = {x:mnist.test.images,y:mnist.test.labels}
    test_acc = sess.run(accr,feed_dict = feeds)
    print("Test accuracy:%.3f"%(test_acc))
print("Optimization Finished")

程序部分运行结果如下:

Train accuracy:0.605
Test accuracy:0.633
Epoch:071/080 cost:1.810029302
Train accuracy:0.600
Test accuracy:0.645
Epoch:075/080 cost:1.761531130
Train accuracy:0.690
Test accuracy:0.649
Epoch:079/080 cost:1.711757494
Train accuracy:0.640
Test accuracy:0.660
Optimization Finished

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python生成pdf文件的方法
Aug 04 Python
用Python实现一个简单的多线程TCP服务器的教程
May 05 Python
Python中的命令行参数解析工具之docopt详解
Mar 27 Python
Python数据分析之双色球中蓝红球分析统计示例
Feb 03 Python
Ubuntu下升级 python3.7.1流程备忘(推荐)
Dec 10 Python
Python3删除排序数组中重复项的方法分析
Jan 31 Python
Python通过cv2读取多个USB摄像头
Aug 28 Python
python字典排序的方法
Oct 12 Python
在django中自定义字段Field详解
Dec 03 Python
Python 余弦相似度与皮尔逊相关系数 计算实例
Dec 23 Python
Jupyter加载文件的实现方法
Apr 14 Python
Python进程池与进程锁之语法学习
Apr 11 Python
Python unittest单元测试框架总结
Sep 08 #Python
tensorflow实现加载mnist数据集
Sep 08 #Python
使用tensorflow实现线性回归
Sep 08 #Python
Python  unittest单元测试框架的使用
Sep 08 #Python
tensorflow实现逻辑回归模型
Sep 08 #Python
Django实现表单验证
Sep 08 #Python
python实现排序算法解析
Sep 08 #Python
You might like
评分9.0以上的动画电影,剧情除了经典还很燃
2020/03/04 日漫
咖啡语言
2021/03/03 咖啡文化
PHP获取MSN好友列表类的实现代码
2013/06/23 PHP
分享下页面关键字抓取www.icbase.com站点代码(带asp.net参数的)
2014/01/30 PHP
静态html文件执行php语句的方法(推荐)
2016/11/21 PHP
Jquery 组合form元素为json格式,asp.net反序列化
2009/07/09 Javascript
动态载入/删除/更新外部 JavaScript/Css 文件的代码
2010/07/03 Javascript
理解Javascript_11_constructor实现原理
2010/10/18 Javascript
js控制表单奇偶行样式的简单方法
2013/07/31 Javascript
JavaScript判断FileUpload控件上传文件类型
2015/09/28 Javascript
jQuery中的Deferred和promise 的区别
2016/04/03 Javascript
完美解决node.js中使用https请求报CERT_UNTRUSTED的问题
2017/01/08 Javascript
JS控件bootstrap suggest plugin使用方法详解
2017/03/25 Javascript
shiro授权的实现原理
2017/09/21 Javascript
浅谈react-router HashRouter和BrowserRouter的使用
2017/12/29 Javascript
用 js 写一个 js 解释器过程详解
2019/08/02 Javascript
antd-日历组件,前后禁止选择,只能选中间一部分的实例
2020/10/29 Javascript
Python 流程控制实例代码
2009/09/25 Python
Python RuntimeError: thread.__init__() not called解决方法
2015/04/28 Python
python结合API实现即时天气信息
2016/01/19 Python
Python利用operator模块实现对象的多级排序详解
2017/05/09 Python
Python实现的三层BP神经网络算法示例
2018/02/07 Python
Python设计模式之抽象工厂模式原理与用法详解
2019/01/15 Python
Python Excel处理库openpyxl使用详解
2019/05/09 Python
python程序快速缩进多行代码方法总结
2019/06/23 Python
python机器学习实现决策树
2019/11/11 Python
python实现将json多行数据传入到mysql中使用
2019/12/31 Python
浅谈Python线程的同步互斥与死锁
2020/03/22 Python
浅谈HTML5 FileReader分布读取文件以及其方法简介
2017/11/09 HTML / CSS
娇韵诗加拿大官网:Clarins加拿大
2017/11/20 全球购物
服务员自我评价
2014/01/25 职场文书
幼儿园小班植树节活动方案
2014/03/04 职场文书
教师节促销方案
2014/03/22 职场文书
大学组织委员竞选稿
2015/11/21 职场文书
职工的安全责任书范文!
2019/07/02 职场文书
HTML5 语义化标签(移动端必备)
2021/08/23 HTML / CSS