tensorflow使用神经网络实现mnist分类


Posted in Python onSeptember 08, 2018

本文实例为大家分享了tensorflow神经网络实现mnist分类的具体代码,供大家参考,具体内容如下

只有两层的神经网络,直接上代码

#引入包
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
#引入input_data文件
from tensorflow.examples.tutorials.mnist import input_data
#读取文件
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot=True)

#定义第一个隐藏层和第二个隐藏层,输入层输出层
n_hidden_1 = 256
n_hidden_2 = 128
n_input = 784
n_classes = 10

#由于不知道输入图片个数,所以用placeholder
x = tf.placeholder("float",[None,n_input])
y = tf.placeholder("float",[None,n_classes])

stddev = 0.1

#定义权重
weights = {
    'w1':tf.Variable(tf.random_normal([n_input,n_hidden_1],stddev = stddev)),
    'w2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2],stddev=stddev)),
    'out':tf.Variable(tf.random_normal([n_hidden_2,n_classes],stddev=stddev))    
    }

#定义偏置
biases = {
    'b1':tf.Variable(tf.random_normal([n_hidden_1])),
    'b2':tf.Variable(tf.random_normal([n_hidden_2])),
    'out':tf.Variable(tf.random_normal([n_classes])), 
    }
print("Network is Ready")


#前向传播
def multilayer_perceptrin(_X,_weights,_biases):
  layer1 = tf.nn.sigmoid(tf.add(tf.matmul(_X,_weights['w1']),_biases['b1']))
  layer2 = tf.nn.sigmoid(tf.add(tf.matmul(layer1,_weights['w2']),_biases['b2']))
  return (tf.matmul(layer2,_weights['out'])+_biases['out'])

#定义优化函数,精准度等
pred = multilayer_perceptrin(x,weights,biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred,labels=y))
optm = tf.train.GradientDescentOptimizer(learning_rate = 0.001).minimize(cost)
corr = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accr = tf.reduce_mean(tf.cast(corr,"float"))
print("Functions is ready")

#定义超参数
training_epochs = 80
batch_size = 200
display_step = 4

#会话开始
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

#优化
for epoch in range(training_epochs):
  avg_cost=0.
  total_batch = int(mnist.train.num_examples/batch_size)

  for i in range(total_batch):
    batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    feeds = {x:batch_xs,y:batch_ys}
    sess.run(optm,feed_dict = feeds)
    avg_cost += sess.run(cost,feed_dict=feeds)
  avg_cost = avg_cost/total_batch

  if (epoch+1) % display_step ==0:
    print("Epoch:%03d/%03d cost:%.9f"%(epoch,training_epochs,avg_cost))
    feeds = {x:batch_xs,y:batch_ys}
    train_acc = sess.run(accr,feed_dict = feeds)
    print("Train accuracy:%.3f"%(train_acc))
    feeds = {x:mnist.test.images,y:mnist.test.labels}
    test_acc = sess.run(accr,feed_dict = feeds)
    print("Test accuracy:%.3f"%(test_acc))
print("Optimization Finished")

程序部分运行结果如下:

Train accuracy:0.605
Test accuracy:0.633
Epoch:071/080 cost:1.810029302
Train accuracy:0.600
Test accuracy:0.645
Epoch:075/080 cost:1.761531130
Train accuracy:0.690
Test accuracy:0.649
Epoch:079/080 cost:1.711757494
Train accuracy:0.640
Test accuracy:0.660
Optimization Finished

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python与sqlite3实现解密chrome cookie实例代码
Jan 20 Python
在CMD命令行中运行python脚本的方法
May 12 Python
pandas分别写入excel的不同sheet方法
Dec 11 Python
使用Python创建简单的HTTP服务器的方法步骤
Apr 26 Python
Python完全识别验证码自动登录实例详解
Nov 24 Python
python爬虫开发之使用Python爬虫库requests多线程抓取猫眼电影TOP100实例
Mar 10 Python
Python更换pip源方法过程解析
May 19 Python
python怎么自定义捕获错误
Jun 29 Python
Jupyter安装链接aconda实现过程图解
Nov 02 Python
安装不同版本的tensorflow与models方法实现
Feb 20 Python
python3.9之你应该知道的新特性详解
Apr 29 Python
PYTHON InceptionV3模型的复现详解
May 06 Python
Python unittest单元测试框架总结
Sep 08 #Python
tensorflow实现加载mnist数据集
Sep 08 #Python
使用tensorflow实现线性回归
Sep 08 #Python
Python  unittest单元测试框架的使用
Sep 08 #Python
tensorflow实现逻辑回归模型
Sep 08 #Python
Django实现表单验证
Sep 08 #Python
python实现排序算法解析
Sep 08 #Python
You might like
PHP图片处理之图片旋转和图片翻转实例
2014/11/19 PHP
PHP实现获取并生成数据库字典的方法
2016/05/04 PHP
PHP面向对象之事务脚本模式(详解)
2017/06/07 PHP
PHP中命名空间的使用例子
2019/03/22 PHP
html数组字符串拼接的最快方法
2009/09/16 Javascript
jQuery ajax BUG:object doesn't support this property or method
2010/07/06 Javascript
javascript开发技术大全-第1章javascript概述
2011/07/03 Javascript
javascript自适应宽度的瀑布流实现思路
2013/02/20 Javascript
关于jQuery参考实例 1.0 jQuery的哲学
2013/04/07 Javascript
jQuery编写设置和获取颜色的插件
2017/01/09 Javascript
JavaScript生成简单等差数列
2017/11/28 Javascript
p5.js 毕达哥拉斯树的实现代码
2018/03/23 Javascript
详解NodeJs开发微信公众号
2018/05/25 NodeJs
bootstrap自定义样式之bootstrap实现侧边导航栏功能
2018/09/10 Javascript
Vue+element 解决浏览器自动填充记住的账号密码问题
2019/06/11 Javascript
Vue实现点击当前行变色
2020/12/14 Vue.js
[01:01:23]完美世界DOTA2联赛PWL S2 Forest vs FTD.C 第一场 11.26
2020/11/30 DOTA
python实现html转ubb代码(html2ubb)
2014/07/03 Python
python PIL模块与随机生成中文验证码
2016/02/27 Python
对python中数组的del,remove,pop区别详解
2018/11/07 Python
使用selenium和pyquery爬取京东商品列表过程解析
2019/08/15 Python
Python cookie的保存与读取、SSL讲解
2020/02/17 Python
浅谈pytorch中的BN层的注意事项
2020/06/23 Python
Python读取多列数据以及用matplotlib制作图表方法实例
2020/09/23 Python
Python+logging输出到屏幕将log日志写入文件
2020/11/11 Python
如何用Python提取10000份log中的产品信息
2021/01/14 Python
理工大学毕业生自荐信
2013/11/01 职场文书
父亲生日宴会答谢词
2014/01/10 职场文书
学生打架检讨书
2014/10/20 职场文书
试用期辞职信范文
2015/03/02 职场文书
工作感想范文
2015/08/07 职场文书
合同补充协议书
2016/03/24 职场文书
ROS系统将python包编译为可执行文件的简单步骤
2021/07/25 Python
图文详解nginx日志切割的实现
2022/01/18 Servers
nginx 配置缓存
2022/05/11 Servers
openEuler 搭建java开发环境的详细过程
2022/06/10 Servers