Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】


Posted in Python onDecember 19, 2019

本文实例讲述了Python tensorflow实现mnist手写数字识别。分享给大家供大家参考,具体如下:

非卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
weight = tf.Variable(tf.ones([784,10]))
bias = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(x_data ,weight) + bias)
#Y_model = tf.nn.sigmoid(tf.matmul(x_data ,weight) + bias)
'''
weight1 = tf.Variable(tf.ones([784,256]))
bias1 = tf.Variable(tf.ones([256]))
Y_model1 = tf.nn.softmax(tf.matmul(x_data ,weight1) + bias1)
weight1 = tf.Variable(tf.ones([256,10]))
bias1 = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(Y_model1 ,weight1) + bias1)
'''
y_data = tf.placeholder("float32", [None, 10])
loss = tf.reduce_sum(tf.pow((y_data - Y_model), 2 ))#92%-93%
#loss = tf.reduce_sum(tf.square(y_data - Y_model)) #90%-91%
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init) # reset values to wrong
for i in range(100000):
  batch_xs, batch_ys = mnist_data.train.next_batch(50)
  sess.run(train, feed_dict = {x_data: batch_xs, y_data: batch_ys})
  if i%50==0:
    correct_predict = tf.equal(tf.arg_max(Y_model,1),tf.argmax(y_data,1))
    accurate = tf.reduce_mean(tf.cast(correct_predict,"float"))
    print(sess.run(accurate,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
x_image = tf.reshape(x_data, [-1,28,28,1])
w_conv = tf.Variable(tf.ones([5,5,1,32])) #weight
b_conv = tf.Variable(tf.ones([32]))    #bias
h_conv = tf.nn.relu(tf.nn.conv2d(x_image , w_conv,strides=[1,1,1,1],padding='SAME')+ b_conv)
h_pool = tf.nn.max_pool(h_conv,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
w_fc = tf.Variable(tf.ones([14*14*32,1024]))
b_fc = tf.Variable(tf.ones([1024]))
h_pool_flat = tf.reshape(h_pool,[-1,14*14*32])
h_fc = tf.nn.relu(tf.matmul(h_pool_flat,w_fc) +b_fc)
W_fc = w_fc = tf.Variable(tf.ones([1024,10]))
B_fc = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(h_fc,W_fc) +B_fc)
y_data = tf.placeholder("float32",[None,10])
loss = -tf.reduce_sum(y_data * tf.log(Y_model))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
  batch_xs,batch_ys =mnist_data.train.next_batch(5)
  sess.run(train_step,feed_dict={x_data:batch_xs,y_data:batch_ys})
  if i%50==0:
    correct_prediction = tf.equal(tf.argmax(Y_model,1),tf.argmax(y_data,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
    print(sess.run(accuracy,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
用Python输出一个杨辉三角的例子
Jun 13 Python
spyder常用快捷键(分享)
Jul 19 Python
Python3使用PyQt5制作简单的画板/手写板实例
Oct 19 Python
浅谈pyqt5在QMainWindow中布局的问题
Jun 21 Python
python 将字符串中的数字相加求和的实现
Jul 18 Python
python requests证书问题解决
Sep 05 Python
python英语单词测试小程序代码实例
Sep 09 Python
Python 脚本拉取 Docker 镜像问题
Nov 10 Python
Python3 使用map()批量的转换数据类型,如str转float的实现
Nov 29 Python
基于keras输出中间层结果的2种实现方式
Jan 24 Python
详解Python中string模块除去Str还剩下什么
Nov 30 Python
python dir函数快速掌握用法技巧
Dec 09 Python
Python: 传递列表副本方式
Dec 19 #Python
python内置模块collections知识点总结
Dec 19 #Python
Python操作redis和mongoDB的方法
Dec 19 #Python
Python 实现Serial 与STM32J进行串口通讯
Dec 18 #Python
实现Python与STM32通信方式
Dec 18 #Python
利用pandas将非数值数据转换成数值的方式
Dec 18 #Python
python 浅谈serial与stm32通信的编码问题
Dec 18 #Python
You might like
Symfony2 session用法实例分析
2016/02/04 PHP
php封装的验证码类分享
2017/02/26 PHP
php多进程模拟并发事务产生的问题小结
2018/12/07 PHP
分享几种好用的PHP自定义加密函数(可逆/不可逆)
2020/09/15 PHP
javascript中callee与caller的用法和应用场景
2010/12/08 Javascript
JavaScript子窗口ModalDialog中操作父窗口对像
2012/12/11 Javascript
JS随机生成不重复数据的实例方法
2013/07/17 Javascript
JavaScript显示表单内元素数量的方法
2015/04/02 Javascript
浅谈Javascript的静态属性和原型属性
2015/05/07 Javascript
Treegrid的动态加载实例代码
2016/04/29 Javascript
批量下载对路网图片并生成html的实现方法
2016/06/07 Javascript
详解vue2.0组件通信各种情况总结与实例分析
2017/03/22 Javascript
Vue.js render方法使用详解
2017/04/05 Javascript
vue实现简单表格组件实例详解
2017/04/16 Javascript
jquery引入外部CDN 加载失败则引入本地jq库
2018/05/23 jQuery
nodejs更新package.json中的dependencies依赖到最新版本的方法
2018/10/10 NodeJs
利用Dectorator分模块存储Vuex状态的实现
2019/02/05 Javascript
nodejs 递归拷贝、读取目录下所有文件和目录
2019/07/18 NodeJs
在Python中使用NLTK库实现对词干的提取的教程
2015/04/08 Python
Python使用os模块和fileinput模块来操作文件目录
2016/01/19 Python
Python随机数用法实例详解【基于random模块】
2017/04/18 Python
Python之re操作方法(详解)
2017/06/14 Python
Python Threading 线程/互斥锁/死锁/GIL锁
2019/07/21 Python
Python3列表List入门知识附实例
2020/02/09 Python
python 绘制场景热力图的示例
2020/09/23 Python
Python实现京东抢秒杀功能
2021/01/25 Python
浅谈HTML5 defer和async的区别
2016/06/07 HTML / CSS
英国最大的纸工艺品商店:CraftStash
2018/12/01 全球购物
愚人节活动策划方案
2014/03/11 职场文书
群众路线教育实践活动学习心得体会
2014/10/30 职场文书
大学四年个人总结
2015/03/03 职场文书
个人借条范本
2015/05/25 职场文书
少年犯观后感
2015/06/11 职场文书
新娘父亲婚礼致辞
2015/07/27 职场文书
python制作图形界面的2048游戏, 基于tkinter
2021/04/06 Python
Apache Linkis 中间件架构及快速安装步骤
2022/03/16 Servers