tensorflow使用神经网络实现mnist分类


Posted in Python onSeptember 08, 2018

本文实例为大家分享了tensorflow神经网络实现mnist分类的具体代码,供大家参考,具体内容如下

只有两层的神经网络,直接上代码

#引入包
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
#引入input_data文件
from tensorflow.examples.tutorials.mnist import input_data
#读取文件
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot=True)

#定义第一个隐藏层和第二个隐藏层,输入层输出层
n_hidden_1 = 256
n_hidden_2 = 128
n_input = 784
n_classes = 10

#由于不知道输入图片个数,所以用placeholder
x = tf.placeholder("float",[None,n_input])
y = tf.placeholder("float",[None,n_classes])

stddev = 0.1

#定义权重
weights = {
    'w1':tf.Variable(tf.random_normal([n_input,n_hidden_1],stddev = stddev)),
    'w2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2],stddev=stddev)),
    'out':tf.Variable(tf.random_normal([n_hidden_2,n_classes],stddev=stddev))    
    }

#定义偏置
biases = {
    'b1':tf.Variable(tf.random_normal([n_hidden_1])),
    'b2':tf.Variable(tf.random_normal([n_hidden_2])),
    'out':tf.Variable(tf.random_normal([n_classes])), 
    }
print("Network is Ready")


#前向传播
def multilayer_perceptrin(_X,_weights,_biases):
  layer1 = tf.nn.sigmoid(tf.add(tf.matmul(_X,_weights['w1']),_biases['b1']))
  layer2 = tf.nn.sigmoid(tf.add(tf.matmul(layer1,_weights['w2']),_biases['b2']))
  return (tf.matmul(layer2,_weights['out'])+_biases['out'])

#定义优化函数,精准度等
pred = multilayer_perceptrin(x,weights,biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred,labels=y))
optm = tf.train.GradientDescentOptimizer(learning_rate = 0.001).minimize(cost)
corr = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accr = tf.reduce_mean(tf.cast(corr,"float"))
print("Functions is ready")

#定义超参数
training_epochs = 80
batch_size = 200
display_step = 4

#会话开始
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

#优化
for epoch in range(training_epochs):
  avg_cost=0.
  total_batch = int(mnist.train.num_examples/batch_size)

  for i in range(total_batch):
    batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    feeds = {x:batch_xs,y:batch_ys}
    sess.run(optm,feed_dict = feeds)
    avg_cost += sess.run(cost,feed_dict=feeds)
  avg_cost = avg_cost/total_batch

  if (epoch+1) % display_step ==0:
    print("Epoch:%03d/%03d cost:%.9f"%(epoch,training_epochs,avg_cost))
    feeds = {x:batch_xs,y:batch_ys}
    train_acc = sess.run(accr,feed_dict = feeds)
    print("Train accuracy:%.3f"%(train_acc))
    feeds = {x:mnist.test.images,y:mnist.test.labels}
    test_acc = sess.run(accr,feed_dict = feeds)
    print("Test accuracy:%.3f"%(test_acc))
print("Optimization Finished")

程序部分运行结果如下:

Train accuracy:0.605
Test accuracy:0.633
Epoch:071/080 cost:1.810029302
Train accuracy:0.600
Test accuracy:0.645
Epoch:075/080 cost:1.761531130
Train accuracy:0.690
Test accuracy:0.649
Epoch:079/080 cost:1.711757494
Train accuracy:0.640
Test accuracy:0.660
Optimization Finished

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
举例讲解Django中数据模型访问外键值的方法
Jul 21 Python
Python中urllib+urllib2+cookielib模块编写爬虫实战
Jan 20 Python
深入讲解Python函数中参数的使用及默认参数的陷阱
Mar 13 Python
python让列表倒序输出的实例
Jun 25 Python
Python3.4学习笔记之常用操作符,条件分支和循环用法示例
Mar 01 Python
详解Python读取yaml文件多层菜单
Mar 23 Python
Django框架使用mysql视图操作示例
May 15 Python
python3在同一行内输入n个数并用列表保存的例子
Jul 20 Python
Python Opencv提取图片中某种颜色组成的图形的方法
Sep 19 Python
如何使用python进行pdf文件分割
Nov 11 Python
使用python matplotlib 画图导入到word中如何保证分辨率
Apr 16 Python
Pycharm打开已有项目配置python环境的方法
Jul 03 Python
Python unittest单元测试框架总结
Sep 08 #Python
tensorflow实现加载mnist数据集
Sep 08 #Python
使用tensorflow实现线性回归
Sep 08 #Python
Python  unittest单元测试框架的使用
Sep 08 #Python
tensorflow实现逻辑回归模型
Sep 08 #Python
Django实现表单验证
Sep 08 #Python
python实现排序算法解析
Sep 08 #Python
You might like
php学习之数据类型之间的转换代码
2011/05/29 PHP
php中用于检测一个地理IP地址是否可用的代码
2012/02/19 PHP
解析php file_exists无效的解决办法
2013/06/26 PHP
Yii2实现同时搜索多个字段的方法
2016/08/10 PHP
PHP手机号码及邮箱正则表达式实例解析
2020/07/11 PHP
jquery下checked取值问题的解决方法
2012/08/09 Javascript
异步动态加载js与css文件的js代码
2013/09/15 Javascript
使用js完成节点的增删改复制等的操作
2014/01/02 Javascript
Jquery.Form 异步提交表单的简单实例
2014/03/03 Javascript
js中substring和substr的定义和用法
2014/05/05 Javascript
JS实现淡蓝色简洁竖向Tab点击切换效果
2015/10/06 Javascript
基于BootStrap栅格栏系统完成网站底部版权信息区
2016/12/23 Javascript
vue检测对象和数组的变化分析
2018/06/30 Javascript
layer扩展打开/关闭动画的方法
2019/09/23 Javascript
vue 解决移动端弹出键盘导致页面fixed布局错乱的问题
2019/11/06 Javascript
vue 获取元素额外生成的data-v-xxx操作
2020/09/09 Javascript
忘记ftp密码使用python ftplib库暴力破解密码的方法示例
2014/01/22 Python
Python中针对函数处理的特殊方法
2014/03/06 Python
Python学习笔记之迭代器和生成器用法实例详解
2019/08/08 Python
详解numpy矩阵的创建与数据类型
2019/10/18 Python
tensorflow模型保存、加载之变量重命名实例
2020/01/21 Python
tensorflow:指定gpu 限制使用量百分比,设置最小使用量的实现
2020/02/06 Python
使用SQLAlchemy操作数据库表过程解析
2020/06/10 Python
python 密码学示例——理解哈希(Hash)算法
2020/09/21 Python
HTML5 3D书本翻页动画的实现示例
2019/08/28 HTML / CSS
土耳其玩具商店:Toyzz Shop
2019/08/02 全球购物
电子商务专员岗位职责
2013/12/11 职场文书
五年级科学教学反思
2014/02/05 职场文书
大堂副理的岗位职责范文
2014/02/17 职场文书
共筑中国梦演讲稿
2014/04/23 职场文书
数学教研活动总结
2014/07/02 职场文书
大学生个人总结范文
2015/02/15 职场文书
财务出纳岗位职责
2015/03/31 职场文书
惹女朋友生气检讨书
2015/05/06 职场文书
2015暑期爱心支教策划书
2015/07/14 职场文书
劳动模范获奖感言
2015/07/31 职场文书