Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】


Posted in Python onDecember 19, 2019

本文实例讲述了Python tensorflow实现mnist手写数字识别。分享给大家供大家参考,具体如下:

非卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
weight = tf.Variable(tf.ones([784,10]))
bias = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(x_data ,weight) + bias)
#Y_model = tf.nn.sigmoid(tf.matmul(x_data ,weight) + bias)
'''
weight1 = tf.Variable(tf.ones([784,256]))
bias1 = tf.Variable(tf.ones([256]))
Y_model1 = tf.nn.softmax(tf.matmul(x_data ,weight1) + bias1)
weight1 = tf.Variable(tf.ones([256,10]))
bias1 = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(Y_model1 ,weight1) + bias1)
'''
y_data = tf.placeholder("float32", [None, 10])
loss = tf.reduce_sum(tf.pow((y_data - Y_model), 2 ))#92%-93%
#loss = tf.reduce_sum(tf.square(y_data - Y_model)) #90%-91%
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init) # reset values to wrong
for i in range(100000):
  batch_xs, batch_ys = mnist_data.train.next_batch(50)
  sess.run(train, feed_dict = {x_data: batch_xs, y_data: batch_ys})
  if i%50==0:
    correct_predict = tf.equal(tf.arg_max(Y_model,1),tf.argmax(y_data,1))
    accurate = tf.reduce_mean(tf.cast(correct_predict,"float"))
    print(sess.run(accurate,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
x_image = tf.reshape(x_data, [-1,28,28,1])
w_conv = tf.Variable(tf.ones([5,5,1,32])) #weight
b_conv = tf.Variable(tf.ones([32]))    #bias
h_conv = tf.nn.relu(tf.nn.conv2d(x_image , w_conv,strides=[1,1,1,1],padding='SAME')+ b_conv)
h_pool = tf.nn.max_pool(h_conv,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
w_fc = tf.Variable(tf.ones([14*14*32,1024]))
b_fc = tf.Variable(tf.ones([1024]))
h_pool_flat = tf.reshape(h_pool,[-1,14*14*32])
h_fc = tf.nn.relu(tf.matmul(h_pool_flat,w_fc) +b_fc)
W_fc = w_fc = tf.Variable(tf.ones([1024,10]))
B_fc = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(h_fc,W_fc) +B_fc)
y_data = tf.placeholder("float32",[None,10])
loss = -tf.reduce_sum(y_data * tf.log(Y_model))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
  batch_xs,batch_ys =mnist_data.train.next_batch(5)
  sess.run(train_step,feed_dict={x_data:batch_xs,y_data:batch_ys})
  if i%50==0:
    correct_prediction = tf.equal(tf.argmax(Y_model,1),tf.argmax(y_data,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
    print(sess.run(accuracy,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
由Python运算π的值深入Python中科学计算的实现
Apr 17 Python
python爬虫爬取某站上海租房图片
Feb 04 Python
python中返回矩阵的行列方法
Apr 04 Python
Python pandas.DataFrame调整列顺序及修改index名的方法
Jun 21 Python
Django对数据库进行添加与更新的例子
Jul 12 Python
实现Python与STM32通信方式
Dec 18 Python
Pytorch实现神经网络的分类方式
Jan 08 Python
Python实现汇率转换操作
May 03 Python
什么是python的列表推导式
May 26 Python
django序列化时使用外键的真实值操作
Jul 15 Python
python进行二次方程式计算的实例讲解
Dec 06 Python
pandas数值排序的实现实例
Jul 25 Python
Python: 传递列表副本方式
Dec 19 #Python
python内置模块collections知识点总结
Dec 19 #Python
Python操作redis和mongoDB的方法
Dec 19 #Python
Python 实现Serial 与STM32J进行串口通讯
Dec 18 #Python
实现Python与STM32通信方式
Dec 18 #Python
利用pandas将非数值数据转换成数值的方式
Dec 18 #Python
python 浅谈serial与stm32通信的编码问题
Dec 18 #Python
You might like
也谈截取首页新闻 - 范例
2006/10/09 PHP
深思 PHP 数组遍历的差异(array_diff 的实现)
2008/03/23 PHP
深入PHP数据加密详解
2013/06/18 PHP
简单谈谈PHP vs Node.js
2015/07/17 PHP
PHP实现小程序批量通知推送
2018/11/27 PHP
W3C Group的JavaScript1.8 新特性介绍
2009/05/19 Javascript
10个基于jQuery或JavaScript的WYSIWYG 编辑器整理
2010/05/06 Javascript
Javascript 实现图片无缝滚动
2014/12/19 Javascript
Node.js开发者必须了解的4个JS要点
2016/02/21 Javascript
精通JavaScript的this关键字
2020/05/28 Javascript
类似于QQ的右滑删除效果的实现方法
2016/10/16 Javascript
js输入框使用正则表达式校验输入内容的实例
2017/02/12 Javascript
Bootstrap媒体对象学习使用
2017/03/07 Javascript
vue router 配置路由的方法
2018/07/26 Javascript
vue2过滤器模糊查询方法
2018/09/16 Javascript
ES6入门教程之Array.from()方法
2019/03/23 Javascript
Vue项目中使用jquery的简单方法
2019/05/16 jQuery
layui table数据修改的回显方法
2019/09/04 Javascript
原生javascript运动函数的封装示例【匀速、抛物线、多属性的运动等】
2020/02/23 Javascript
ant-design表单处理和常用方法及自定义验证操作
2020/10/27 Javascript
JavaScript实现网页tab栏效果制作
2020/11/20 Javascript
[01:41]DOTA2 2015国际邀请赛中国区预选赛第三日战报
2015/05/28 DOTA
[44:50]2018DOTA2亚洲邀请赛 4.1 小组赛 A组 TNC vs VG
2018/04/02 DOTA
Python利用ansible分发处理任务
2015/08/04 Python
实例讲解Python中global语句下全局变量的值的修改
2016/06/16 Python
pygame游戏之旅 添加键盘按键的方法
2018/11/20 Python
python实现三维拟合的方法
2018/12/29 Python
Python实现名片管理系统
2020/02/14 Python
python实现拼图小游戏
2020/02/22 Python
HTML5中判断横屏竖屏的方法(移动端)
2016/08/04 HTML / CSS
中专毕业生自我鉴定
2014/02/02 职场文书
优秀高中生事迹材料
2014/02/11 职场文书
高中生学期学习自我评价
2014/02/24 职场文书
安全目标管理责任书
2014/07/25 职场文书
运动会加油稿
2015/07/22 职场文书
Spring Data JPA框架自定义Repository接口
2022/04/28 Java/Android