Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】


Posted in Python onDecember 19, 2019

本文实例讲述了Python tensorflow实现mnist手写数字识别。分享给大家供大家参考,具体如下:

非卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
weight = tf.Variable(tf.ones([784,10]))
bias = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(x_data ,weight) + bias)
#Y_model = tf.nn.sigmoid(tf.matmul(x_data ,weight) + bias)
'''
weight1 = tf.Variable(tf.ones([784,256]))
bias1 = tf.Variable(tf.ones([256]))
Y_model1 = tf.nn.softmax(tf.matmul(x_data ,weight1) + bias1)
weight1 = tf.Variable(tf.ones([256,10]))
bias1 = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(Y_model1 ,weight1) + bias1)
'''
y_data = tf.placeholder("float32", [None, 10])
loss = tf.reduce_sum(tf.pow((y_data - Y_model), 2 ))#92%-93%
#loss = tf.reduce_sum(tf.square(y_data - Y_model)) #90%-91%
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init) # reset values to wrong
for i in range(100000):
  batch_xs, batch_ys = mnist_data.train.next_batch(50)
  sess.run(train, feed_dict = {x_data: batch_xs, y_data: batch_ys})
  if i%50==0:
    correct_predict = tf.equal(tf.arg_max(Y_model,1),tf.argmax(y_data,1))
    accurate = tf.reduce_mean(tf.cast(correct_predict,"float"))
    print(sess.run(accurate,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
x_image = tf.reshape(x_data, [-1,28,28,1])
w_conv = tf.Variable(tf.ones([5,5,1,32])) #weight
b_conv = tf.Variable(tf.ones([32]))    #bias
h_conv = tf.nn.relu(tf.nn.conv2d(x_image , w_conv,strides=[1,1,1,1],padding='SAME')+ b_conv)
h_pool = tf.nn.max_pool(h_conv,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
w_fc = tf.Variable(tf.ones([14*14*32,1024]))
b_fc = tf.Variable(tf.ones([1024]))
h_pool_flat = tf.reshape(h_pool,[-1,14*14*32])
h_fc = tf.nn.relu(tf.matmul(h_pool_flat,w_fc) +b_fc)
W_fc = w_fc = tf.Variable(tf.ones([1024,10]))
B_fc = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(h_fc,W_fc) +B_fc)
y_data = tf.placeholder("float32",[None,10])
loss = -tf.reduce_sum(y_data * tf.log(Y_model))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
  batch_xs,batch_ys =mnist_data.train.next_batch(5)
  sess.run(train_step,feed_dict={x_data:batch_xs,y_data:batch_ys})
  if i%50==0:
    correct_prediction = tf.equal(tf.argmax(Y_model,1),tf.argmax(y_data,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
    print(sess.run(accuracy,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python 图片验证码代码分享
Jul 04 Python
Python简单调用MySQL存储过程并获得返回值的方法
Jul 20 Python
Python的自动化部署模块Fabric的安装及使用指南
Jan 19 Python
Python 中开发pattern的string模板(template) 实例详解
Apr 01 Python
使用Python实现简单的服务器功能
Aug 25 Python
python分析作业提交情况
Nov 22 Python
python 读取DICOM头文件的实例
May 07 Python
PyCharm代码格式调整方法
May 23 Python
Django实现简单网页弹出警告代码
Nov 15 Python
Python面向对象编程基础实例分析
Jan 17 Python
ansible-playbook实现自动部署KVM及安装python3的详细教程
May 11 Python
用python实现一个简单的验证码
Dec 09 Python
Python: 传递列表副本方式
Dec 19 #Python
python内置模块collections知识点总结
Dec 19 #Python
Python操作redis和mongoDB的方法
Dec 19 #Python
Python 实现Serial 与STM32J进行串口通讯
Dec 18 #Python
实现Python与STM32通信方式
Dec 18 #Python
利用pandas将非数值数据转换成数值的方式
Dec 18 #Python
python 浅谈serial与stm32通信的编码问题
Dec 18 #Python
You might like
PHP的FTP学习(四)
2006/10/09 PHP
PHP计算数组中值的和与乘积的方法(array_sum与array_product函数)
2016/04/01 PHP
Code:loadScript( )加载js的功能函数
2007/02/02 Javascript
javascript面向对象包装类Class封装类库剖析
2013/01/24 Javascript
javascript 中__proto__和prototype详解
2014/11/25 Javascript
jQuery+CSS3实现树叶飘落特效
2015/02/01 Javascript
jquery实现可自动判断位置的弹出层效果代码
2015/10/12 Javascript
详解js私有作用域中创建特权方法
2016/01/25 Javascript
利用jQuery中的ajax分页实现代码
2016/02/25 Javascript
jQuery中使用animate自定义动画的方法
2016/05/29 Javascript
JQuery和PHP结合实现动态进度条上传显示
2016/11/23 Javascript
浅析js的模块化编写 require.js
2016/12/07 Javascript
jQuery在header中设置请求信息的方法
2017/03/06 Javascript
jQuery图片瀑布流的简单实现代码
2017/03/15 Javascript
iscroll动态加载数据完美解决方法
2017/07/18 Javascript
ES6学习教程之块级作用域详解
2017/10/09 Javascript
利用JavaScript的%做隔行换色的实例
2017/11/25 Javascript
让axios发送表单请求形式的键值对post数据的实例
2018/08/11 Javascript
详解如何在Vue里建立长按指令
2018/08/20 Javascript
详解小程序输入框闪烁及重影BUG解决方案
2018/08/31 Javascript
Vue中Quill富文本编辑器的使用教程
2018/09/21 Javascript
javascript中this的用法实践分析
2019/07/29 Javascript
[01:04:09]DOTA2-DPC中国联赛 正赛 iG vs VG BO3 第二场 2月2日
2021/03/11 DOTA
以一段代码为实例快速入门Python2.7
2015/03/31 Python
Python selenium文件上传方法汇总
2020/11/19 Python
使用Tensorflow实现可视化中间层和卷积层
2020/01/24 Python
Django mysqlclient安装和使用详解
2020/09/17 Python
CSS3使用border-radius属性制作圆角
2014/12/22 HTML / CSS
美国著名的品牌折扣店:Burlington
2017/06/08 全球购物
巴西最大的玩具连锁店:Ri Happy
2020/06/17 全球购物
会计与审计专业大专生求职信
2013/10/03 职场文书
庆元旦文艺演出主持词
2014/03/27 职场文书
小学数学教研活动总结
2014/07/01 职场文书
公证委托书标准格式
2014/09/11 职场文书
停车位租赁协议书
2014/09/24 职场文书
校长一岗双责责任书
2015/05/09 职场文书