Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】


Posted in Python onDecember 19, 2019

本文实例讲述了Python tensorflow实现mnist手写数字识别。分享给大家供大家参考,具体如下:

非卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
weight = tf.Variable(tf.ones([784,10]))
bias = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(x_data ,weight) + bias)
#Y_model = tf.nn.sigmoid(tf.matmul(x_data ,weight) + bias)
'''
weight1 = tf.Variable(tf.ones([784,256]))
bias1 = tf.Variable(tf.ones([256]))
Y_model1 = tf.nn.softmax(tf.matmul(x_data ,weight1) + bias1)
weight1 = tf.Variable(tf.ones([256,10]))
bias1 = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(Y_model1 ,weight1) + bias1)
'''
y_data = tf.placeholder("float32", [None, 10])
loss = tf.reduce_sum(tf.pow((y_data - Y_model), 2 ))#92%-93%
#loss = tf.reduce_sum(tf.square(y_data - Y_model)) #90%-91%
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init) # reset values to wrong
for i in range(100000):
  batch_xs, batch_ys = mnist_data.train.next_batch(50)
  sess.run(train, feed_dict = {x_data: batch_xs, y_data: batch_ys})
  if i%50==0:
    correct_predict = tf.equal(tf.arg_max(Y_model,1),tf.argmax(y_data,1))
    accurate = tf.reduce_mean(tf.cast(correct_predict,"float"))
    print(sess.run(accurate,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
x_image = tf.reshape(x_data, [-1,28,28,1])
w_conv = tf.Variable(tf.ones([5,5,1,32])) #weight
b_conv = tf.Variable(tf.ones([32]))    #bias
h_conv = tf.nn.relu(tf.nn.conv2d(x_image , w_conv,strides=[1,1,1,1],padding='SAME')+ b_conv)
h_pool = tf.nn.max_pool(h_conv,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
w_fc = tf.Variable(tf.ones([14*14*32,1024]))
b_fc = tf.Variable(tf.ones([1024]))
h_pool_flat = tf.reshape(h_pool,[-1,14*14*32])
h_fc = tf.nn.relu(tf.matmul(h_pool_flat,w_fc) +b_fc)
W_fc = w_fc = tf.Variable(tf.ones([1024,10]))
B_fc = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(h_fc,W_fc) +B_fc)
y_data = tf.placeholder("float32",[None,10])
loss = -tf.reduce_sum(y_data * tf.log(Y_model))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
  batch_xs,batch_ys =mnist_data.train.next_batch(5)
  sess.run(train_step,feed_dict={x_data:batch_xs,y_data:batch_ys})
  if i%50==0:
    correct_prediction = tf.equal(tf.argmax(Y_model,1),tf.argmax(y_data,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
    print(sess.run(accuracy,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python对字典进行排序实例
Sep 25 Python
浅谈Python的Django框架中的缓存控制
Jul 24 Python
python使用正则表达式匹配字符串开头并打印示例
Jan 11 Python
Python下实现的RSA加密/解密及签名/验证功能示例
Jul 17 Python
Python实现删除排序数组中重复项的两种方法示例
Jan 31 Python
python协程gevent案例 爬取斗鱼图片过程解析
Aug 27 Python
django自带调试服务器的使用详解
Aug 29 Python
python 实现绘制整齐的表格
Nov 18 Python
Python爬虫JSON及JSONPath运行原理详解
Jun 04 Python
Pytest之测试命名规则的使用
Apr 16 Python
Python opencv缺陷检测的实现及问题解决
Apr 24 Python
Python中常见的导入方式总结
May 06 Python
Python: 传递列表副本方式
Dec 19 #Python
python内置模块collections知识点总结
Dec 19 #Python
Python操作redis和mongoDB的方法
Dec 19 #Python
Python 实现Serial 与STM32J进行串口通讯
Dec 18 #Python
实现Python与STM32通信方式
Dec 18 #Python
利用pandas将非数值数据转换成数值的方式
Dec 18 #Python
python 浅谈serial与stm32通信的编码问题
Dec 18 #Python
You might like
PHP设置一边执行一边输出结果的代码
2013/09/30 PHP
PHP处理Ajax请求与Ajax跨域问题
2017/02/13 PHP
PHP实现更改hosts文件的方法示例
2017/08/08 PHP
一起来写段JS drag拖动代码
2010/12/09 Javascript
基于jQuery的自动完成插件
2011/02/03 Javascript
jquery 定位input元素的几种方法小结
2013/07/28 Javascript
windows8.1+iis8.5下安装node.js开发环境
2014/12/12 Javascript
jQuery+CSS实现的网页二级下滑菜单效果
2015/08/25 Javascript
Javascript 字符串模板的简单实现
2016/02/13 Javascript
JS判断非空至少输入两个字符的简单实现方法
2017/06/23 Javascript
原生JS+HTML5实现的可调节写字板功能示例
2018/08/30 Javascript
JS复杂判断的更优雅写法代码详解
2018/11/07 Javascript
vue.js 2.0实现简单分页效果
2019/07/29 Javascript
浅谈vue异步数据影响页面渲染
2019/10/29 Javascript
从0搭建vue-cli4脚手架
2020/06/17 Javascript
微信小程序实现发微博功能的示例代码
2020/06/24 Javascript
基于vue项目设置resolves.alias: '@'路径并适配webstorm
2020/12/02 Vue.js
[01:51]2014DOTA2国际邀请赛 这个赛场没有失败者VGTi5再见
2014/07/23 DOTA
windows下安装python paramiko模块的代码
2013/02/10 Python
把大数据数字口语化(python与js)两种实现
2013/02/21 Python
Python 常用的安装Module方式汇总
2017/05/06 Python
Python贪心算法实例小结
2018/04/22 Python
JSON文件及Python对JSON文件的读写操作
2018/10/07 Python
在Python中pandas.DataFrame重置索引名称的实例
2018/11/06 Python
Django ManyToManyField 跨越中间表查询的方法
2018/12/18 Python
解决Python selenium get页面很慢时的问题
2019/01/30 Python
在Pandas中处理NaN值的方法
2019/06/25 Python
python selenium 查找隐藏元素 自动播放视频功能
2019/07/24 Python
详解Django CAS 解决方案
2019/10/30 Python
Python3爬虫带上cookie的实例代码
2020/07/28 Python
Python自动登录QQ的实现示例
2020/08/28 Python
美国在线健康和美容市场:Pharmapacks
2018/12/05 全球购物
英国和爱尔兰最大的地毯零售商:Kukoon
2018/12/17 全球购物
优秀员工个人的自我评价
2013/11/29 职场文书
项目投资合作意向书
2014/07/29 职场文书
mysql连接查询中and与where的区别浅析
2021/07/01 MySQL