Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】


Posted in Python onDecember 19, 2019

本文实例讲述了Python tensorflow实现mnist手写数字识别。分享给大家供大家参考,具体如下:

非卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
weight = tf.Variable(tf.ones([784,10]))
bias = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(x_data ,weight) + bias)
#Y_model = tf.nn.sigmoid(tf.matmul(x_data ,weight) + bias)
'''
weight1 = tf.Variable(tf.ones([784,256]))
bias1 = tf.Variable(tf.ones([256]))
Y_model1 = tf.nn.softmax(tf.matmul(x_data ,weight1) + bias1)
weight1 = tf.Variable(tf.ones([256,10]))
bias1 = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(Y_model1 ,weight1) + bias1)
'''
y_data = tf.placeholder("float32", [None, 10])
loss = tf.reduce_sum(tf.pow((y_data - Y_model), 2 ))#92%-93%
#loss = tf.reduce_sum(tf.square(y_data - Y_model)) #90%-91%
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init) # reset values to wrong
for i in range(100000):
  batch_xs, batch_ys = mnist_data.train.next_batch(50)
  sess.run(train, feed_dict = {x_data: batch_xs, y_data: batch_ys})
  if i%50==0:
    correct_predict = tf.equal(tf.arg_max(Y_model,1),tf.argmax(y_data,1))
    accurate = tf.reduce_mean(tf.cast(correct_predict,"float"))
    print(sess.run(accurate,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
x_image = tf.reshape(x_data, [-1,28,28,1])
w_conv = tf.Variable(tf.ones([5,5,1,32])) #weight
b_conv = tf.Variable(tf.ones([32]))    #bias
h_conv = tf.nn.relu(tf.nn.conv2d(x_image , w_conv,strides=[1,1,1,1],padding='SAME')+ b_conv)
h_pool = tf.nn.max_pool(h_conv,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
w_fc = tf.Variable(tf.ones([14*14*32,1024]))
b_fc = tf.Variable(tf.ones([1024]))
h_pool_flat = tf.reshape(h_pool,[-1,14*14*32])
h_fc = tf.nn.relu(tf.matmul(h_pool_flat,w_fc) +b_fc)
W_fc = w_fc = tf.Variable(tf.ones([1024,10]))
B_fc = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(h_fc,W_fc) +B_fc)
y_data = tf.placeholder("float32",[None,10])
loss = -tf.reduce_sum(y_data * tf.log(Y_model))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
  batch_xs,batch_ys =mnist_data.train.next_batch(5)
  sess.run(train_step,feed_dict={x_data:batch_xs,y_data:batch_ys})
  if i%50==0:
    correct_prediction = tf.equal(tf.argmax(Y_model,1),tf.argmax(y_data,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
    print(sess.run(accuracy,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
使用Python脚本来控制Windows Azure的简单教程
Apr 16 Python
Scrapy爬虫实例讲解_校花网
Oct 23 Python
python实现n个数中选出m个数的方法
Nov 13 Python
浅谈Python基础—判断和循环
Mar 22 Python
wxPython+Matplotlib绘制折线图表
Nov 19 Python
windows下python安装pip方法详解
Feb 10 Python
Python3搭建http服务器的实现代码
Feb 11 Python
python和pywin32实现窗口查找、遍历和点击的示例代码
Apr 01 Python
Pygame框架实现飞机大战
Aug 07 Python
python利用opencv实现颜色检测
Feb 23 Python
pytorch Dataset,DataLoader产生自定义的训练数据案例
Mar 03 Python
Python3.10的一些新特性原理分析
Sep 15 Python
Python: 传递列表副本方式
Dec 19 #Python
python内置模块collections知识点总结
Dec 19 #Python
Python操作redis和mongoDB的方法
Dec 19 #Python
Python 实现Serial 与STM32J进行串口通讯
Dec 18 #Python
实现Python与STM32通信方式
Dec 18 #Python
利用pandas将非数值数据转换成数值的方式
Dec 18 #Python
python 浅谈serial与stm32通信的编码问题
Dec 18 #Python
You might like
域名查询代码公布
2006/10/09 PHP
《APMServ 5.1.2》使用图解
2006/10/23 PHP
php仿QQ验证码的实例分析
2013/07/01 PHP
基于thinkPHP实现的微信自定义分享功能示例
2016/09/23 PHP
php 广告点击统计代码(php+mysql)
2018/02/21 PHP
jQuery仿Flash上下翻动的中英文导航菜单实例
2015/03/10 Javascript
浅谈javascript的分号的使用
2015/05/12 Javascript
基于dropdown.js实现的两款美观大气的二级导航菜单
2015/09/02 Javascript
JS遍历页面所有对象属性及实现方法
2016/08/01 Javascript
JavaScript学习笔记整理_setTimeout的应用
2016/09/19 Javascript
js异步上传多张图片插件的使用方法
2018/10/22 Javascript
vue中的适配px2rem示例代码
2018/11/19 Javascript
Vue CL3 配置路径别名详解
2019/05/30 Javascript
JQuery实现简单的复选框树形结构图示例【附源码下载】
2019/07/16 jQuery
Vue全局loading及错误提示的思路与实现
2019/08/09 Javascript
Python 包含汉字的文件读写之每行末尾加上特定字符
2016/12/12 Python
Flask框架使用DBUtils模块连接数据库操作示例
2018/07/20 Python
Python实现购物评论文本情感分析操作【基于中文文本挖掘库snownlp】
2018/08/07 Python
Python自定义一个异常类的方法
2019/06/27 Python
用python求一个数组的和与平均值的实现方法
2019/06/29 Python
pytorch中的embedding词向量的使用方法
2019/08/18 Python
Python调用Redis的示例代码
2020/11/24 Python
草莓网化妆品澳大利亚站:Strawberrynet AU
2017/12/18 全球购物
荷兰音乐会和音乐剧门票订购网站:Topticketshop
2019/08/27 全球购物
为什么需要版本控制?
2013/08/08 面试题
高中生学习生活的自我评价
2013/11/27 职场文书
毕业典礼主持词大全
2014/03/26 职场文书
铁路安全事故反思
2014/04/26 职场文书
学校节能减排倡议书
2014/05/16 职场文书
政治学求职信
2014/06/03 职场文书
机械工程及自动化专业求职信
2014/09/03 职场文书
大学拉赞助协议书范文
2014/09/26 职场文书
2014年班级工作总结
2014/11/14 职场文书
论语读书笔记
2015/06/26 职场文书
庆七一晚会主持词
2015/06/30 职场文书
Java+swing实现抖音上的表白程序详解
2022/06/25 Java/Android