Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】


Posted in Python onDecember 19, 2019

本文实例讲述了Python tensorflow实现mnist手写数字识别。分享给大家供大家参考,具体如下:

非卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
weight = tf.Variable(tf.ones([784,10]))
bias = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(x_data ,weight) + bias)
#Y_model = tf.nn.sigmoid(tf.matmul(x_data ,weight) + bias)
'''
weight1 = tf.Variable(tf.ones([784,256]))
bias1 = tf.Variable(tf.ones([256]))
Y_model1 = tf.nn.softmax(tf.matmul(x_data ,weight1) + bias1)
weight1 = tf.Variable(tf.ones([256,10]))
bias1 = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(Y_model1 ,weight1) + bias1)
'''
y_data = tf.placeholder("float32", [None, 10])
loss = tf.reduce_sum(tf.pow((y_data - Y_model), 2 ))#92%-93%
#loss = tf.reduce_sum(tf.square(y_data - Y_model)) #90%-91%
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init) # reset values to wrong
for i in range(100000):
  batch_xs, batch_ys = mnist_data.train.next_batch(50)
  sess.run(train, feed_dict = {x_data: batch_xs, y_data: batch_ys})
  if i%50==0:
    correct_predict = tf.equal(tf.arg_max(Y_model,1),tf.argmax(y_data,1))
    accurate = tf.reduce_mean(tf.cast(correct_predict,"float"))
    print(sess.run(accurate,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
x_image = tf.reshape(x_data, [-1,28,28,1])
w_conv = tf.Variable(tf.ones([5,5,1,32])) #weight
b_conv = tf.Variable(tf.ones([32]))    #bias
h_conv = tf.nn.relu(tf.nn.conv2d(x_image , w_conv,strides=[1,1,1,1],padding='SAME')+ b_conv)
h_pool = tf.nn.max_pool(h_conv,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
w_fc = tf.Variable(tf.ones([14*14*32,1024]))
b_fc = tf.Variable(tf.ones([1024]))
h_pool_flat = tf.reshape(h_pool,[-1,14*14*32])
h_fc = tf.nn.relu(tf.matmul(h_pool_flat,w_fc) +b_fc)
W_fc = w_fc = tf.Variable(tf.ones([1024,10]))
B_fc = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(h_fc,W_fc) +B_fc)
y_data = tf.placeholder("float32",[None,10])
loss = -tf.reduce_sum(y_data * tf.log(Y_model))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
  batch_xs,batch_ys =mnist_data.train.next_batch(5)
  sess.run(train_step,feed_dict={x_data:batch_xs,y_data:batch_ys})
  if i%50==0:
    correct_prediction = tf.equal(tf.argmax(Y_model,1),tf.argmax(y_data,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
    print(sess.run(accuracy,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python守护线程用法实例
Jun 23 Python
Python3 循环语句(for、while、break、range等)
Nov 20 Python
python-str,list,set间的转换实例
Jun 27 Python
Python 字符串转换为整形和浮点类型的方法
Jul 17 Python
python看某个模块的版本方法
Oct 16 Python
Django页面数据的缓存与使用的具体方法
Apr 23 Python
Python实现计算文件MD5和SHA1的方法示例
Jun 11 Python
Django--权限Permissions的例子
Aug 28 Python
Python 开发工具PyCharm安装教程图文详解(新手必看)
Feb 28 Python
Python HTTP下载文件并显示下载进度条功能的实现
Apr 02 Python
django 实现简单的插入视频
Apr 07 Python
PyCharm设置Ipython交互环境和宏快捷键进行数据分析图文详解
Apr 23 Python
Python: 传递列表副本方式
Dec 19 #Python
python内置模块collections知识点总结
Dec 19 #Python
Python操作redis和mongoDB的方法
Dec 19 #Python
Python 实现Serial 与STM32J进行串口通讯
Dec 18 #Python
实现Python与STM32通信方式
Dec 18 #Python
利用pandas将非数值数据转换成数值的方式
Dec 18 #Python
python 浅谈serial与stm32通信的编码问题
Dec 18 #Python
You might like
php笔记之:有规律大文件的读取与写入的分析
2013/04/26 PHP
PHP 循环删除无限分类子节点的实现代码
2013/06/21 PHP
php生成短网址示例
2014/05/05 PHP
微信公众号开发之微信公共平台消息回复类实例
2014/11/14 PHP
php+mysql+ajax实现单表多字段多关键词查询的方法
2017/04/15 PHP
Laravel框架中Blade模板的用法示例
2017/08/30 PHP
thinkphp5实现微信扫码支付
2019/12/23 PHP
Yii框架小部件(Widgets)用法实例详解
2020/05/15 PHP
XHTML-Strict 内允许出现的标签
2006/12/11 Javascript
javascript实现的动态添加表单元素input,button等(appendChild)
2007/11/24 Javascript
javascript 火狐(firefox)不显示本地图片问题解决
2008/07/05 Javascript
基于Jquery制作的幻灯片图集效果打包下载
2011/02/12 Javascript
模拟select的代码
2011/10/19 Javascript
JS实现简单易用的手机端浮动窗口显示效果
2016/09/07 Javascript
原生JS实现图片左右轮播
2016/12/30 Javascript
jQuery实现的分页功能示例
2017/01/22 Javascript
js获取元素下的第一级子元素的方法(推荐)
2017/03/05 Javascript
vue router学习之动态路由和嵌套路由详解
2017/09/21 Javascript
vue工程全局设置ajax的等待动效的方法
2019/02/22 Javascript
[原创]微信小程序获取网络类型的方法示例
2019/03/01 Javascript
封装微信小程序http拦截器过程解析
2019/08/13 Javascript
Vue 封装防刷新考试倒计时组件的实现
2020/06/05 Javascript
vue 在methods中调用mounted的实现操作
2020/08/07 Javascript
[50:21]Liquid vs Winstrike 2018国际邀请赛小组赛BO2 第二场
2018/08/19 DOTA
mac系统安装Python3初体验
2018/01/02 Python
Python数据可视化:泊松分布详解
2019/12/07 Python
Python序列化pickle模块使用详解
2020/03/05 Python
小 200 行 Python 代码制作一个换脸程序
2020/05/12 Python
去除python中的字符串空格的简单方法
2020/12/22 Python
简单介绍CSS3中Media Query的使用
2015/07/07 HTML / CSS
餐饮加盟计划书
2014/01/10 职场文书
军校本科大学生自我评价
2014/01/14 职场文书
保密工作责任书
2014/04/16 职场文书
2014年会计主管工作总结
2014/12/20 职场文书
java如何实现socket连接方法封装
2021/09/25 Java/Android
怎么禁用Windows 11快照布局? win11不使用快照布局的技巧
2021/11/21 数码科技