Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】


Posted in Python onDecember 19, 2019

本文实例讲述了Python tensorflow实现mnist手写数字识别。分享给大家供大家参考,具体如下:

非卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
weight = tf.Variable(tf.ones([784,10]))
bias = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(x_data ,weight) + bias)
#Y_model = tf.nn.sigmoid(tf.matmul(x_data ,weight) + bias)
'''
weight1 = tf.Variable(tf.ones([784,256]))
bias1 = tf.Variable(tf.ones([256]))
Y_model1 = tf.nn.softmax(tf.matmul(x_data ,weight1) + bias1)
weight1 = tf.Variable(tf.ones([256,10]))
bias1 = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(Y_model1 ,weight1) + bias1)
'''
y_data = tf.placeholder("float32", [None, 10])
loss = tf.reduce_sum(tf.pow((y_data - Y_model), 2 ))#92%-93%
#loss = tf.reduce_sum(tf.square(y_data - Y_model)) #90%-91%
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init) # reset values to wrong
for i in range(100000):
  batch_xs, batch_ys = mnist_data.train.next_batch(50)
  sess.run(train, feed_dict = {x_data: batch_xs, y_data: batch_ys})
  if i%50==0:
    correct_predict = tf.equal(tf.arg_max(Y_model,1),tf.argmax(y_data,1))
    accurate = tf.reduce_mean(tf.cast(correct_predict,"float"))
    print(sess.run(accurate,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
x_image = tf.reshape(x_data, [-1,28,28,1])
w_conv = tf.Variable(tf.ones([5,5,1,32])) #weight
b_conv = tf.Variable(tf.ones([32]))    #bias
h_conv = tf.nn.relu(tf.nn.conv2d(x_image , w_conv,strides=[1,1,1,1],padding='SAME')+ b_conv)
h_pool = tf.nn.max_pool(h_conv,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
w_fc = tf.Variable(tf.ones([14*14*32,1024]))
b_fc = tf.Variable(tf.ones([1024]))
h_pool_flat = tf.reshape(h_pool,[-1,14*14*32])
h_fc = tf.nn.relu(tf.matmul(h_pool_flat,w_fc) +b_fc)
W_fc = w_fc = tf.Variable(tf.ones([1024,10]))
B_fc = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(h_fc,W_fc) +B_fc)
y_data = tf.placeholder("float32",[None,10])
loss = -tf.reduce_sum(y_data * tf.log(Y_model))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
  batch_xs,batch_ys =mnist_data.train.next_batch(5)
  sess.run(train_step,feed_dict={x_data:batch_xs,y_data:batch_ys})
  if i%50==0:
    correct_prediction = tf.equal(tf.argmax(Y_model,1),tf.argmax(y_data,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
    print(sess.run(accuracy,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python计算N天之后日期的方法
Mar 31 Python
Python3实现的爬虫爬取数据并存入mysql数据库操作示例
Jun 06 Python
基于Python开发chrome插件的方法分析
Jul 07 Python
python 获取页面表格数据存放到csv中的方法
Dec 26 Python
docker django无法访问redis容器的解决方法
Aug 21 Python
python中@property和property函数常见使用方法示例
Oct 21 Python
在Python中实现函数重载的示例代码
Dec 12 Python
pycharm快捷键汇总
Feb 14 Python
使用 Python ssh 远程登陆服务器的最佳方案
Mar 06 Python
Python ckeditor富文本编辑器代码实例解析
Jun 22 Python
Python 列表推导式需要注意的地方
Oct 23 Python
django如何自定义manage.py管理命令
Apr 27 Python
Python: 传递列表副本方式
Dec 19 #Python
python内置模块collections知识点总结
Dec 19 #Python
Python操作redis和mongoDB的方法
Dec 19 #Python
Python 实现Serial 与STM32J进行串口通讯
Dec 18 #Python
实现Python与STM32通信方式
Dec 18 #Python
利用pandas将非数值数据转换成数值的方式
Dec 18 #Python
python 浅谈serial与stm32通信的编码问题
Dec 18 #Python
You might like
用Flash图形化数据(二)
2006/10/09 PHP
PHP获取远程图片并保存到本地的方法
2015/05/12 PHP
javascript获得CheckBoxList选中的数量
2009/10/27 Javascript
javascript中typeof的使用示例
2013/12/19 Javascript
js 获取、清空input type="file"的值(示例代码)
2013/12/24 Javascript
node.js中的fs.rmdirSync方法使用说明
2014/12/16 Javascript
原生js实现addClass,removeClass,hasClass方法
2016/04/27 Javascript
jQuery模拟实现的select点击选择效果【附demo源码下载】
2016/11/09 Javascript
JavaScript日期选择功能示例
2017/01/16 Javascript
React Native基础入门之调试React Native应用的一小步
2018/07/02 Javascript
更改BootStrap popover的默认样式及popover简单用法
2018/09/13 Javascript
详解json串反转义(消除反斜杠)
2019/08/12 Javascript
layui关闭弹窗后刷新主页面和当前更改项的例子
2019/09/06 Javascript
vue选项卡切换登录方式小案例
2019/09/27 Javascript
python逐行读取文件内容的三种方法
2014/01/20 Python
Python的Urllib库的基本使用教程
2015/04/30 Python
python实现批量改文件名称的方法
2015/05/25 Python
Python中int()函数的用法浅析
2017/10/17 Python
浅谈Python Opencv中gamma变换的使用详解
2018/04/02 Python
python 顺时针打印矩阵的超简洁代码
2018/11/14 Python
python生成以及打开json、csv和txt文件的实例
2018/11/16 Python
浅谈python在提示符下使用open打开文件失败的原因及解决方法
2018/11/30 Python
PyTorch在Windows环境搭建的方法步骤
2020/05/12 Python
matlab、python中矩阵的互相导入导出方式
2020/06/01 Python
Pytorch 解决自定义子Module .cuda() tensor失败的问题
2020/06/23 Python
python使用建议技巧分享(三)
2020/08/18 Python
如何创建一个Flask项目并进行简单配置
2020/11/18 Python
详解background属性的8个属性值(面试题)
2020/11/02 HTML / CSS
欧洲第一的摇滚和金属乐队服装网站:EMP
2017/10/26 全球购物
Skyscanner英国:苏格兰的全球三大领先航班搜索服务之一
2017/11/09 全球购物
世界上最好的野生海鲜和有机食品:Vital Choice
2020/01/16 全球购物
简单英文演讲稿
2014/01/01 职场文书
卫生安全检查制度
2014/02/04 职场文书
初中生毕业评语
2014/12/29 职场文书
2015年圣诞节寄语
2015/08/17 职场文书
python3实现无权最短路径的方法
2021/05/12 Python