Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】


Posted in Python onDecember 19, 2019

本文实例讲述了Python tensorflow实现mnist手写数字识别。分享给大家供大家参考,具体如下:

非卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
weight = tf.Variable(tf.ones([784,10]))
bias = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(x_data ,weight) + bias)
#Y_model = tf.nn.sigmoid(tf.matmul(x_data ,weight) + bias)
'''
weight1 = tf.Variable(tf.ones([784,256]))
bias1 = tf.Variable(tf.ones([256]))
Y_model1 = tf.nn.softmax(tf.matmul(x_data ,weight1) + bias1)
weight1 = tf.Variable(tf.ones([256,10]))
bias1 = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(Y_model1 ,weight1) + bias1)
'''
y_data = tf.placeholder("float32", [None, 10])
loss = tf.reduce_sum(tf.pow((y_data - Y_model), 2 ))#92%-93%
#loss = tf.reduce_sum(tf.square(y_data - Y_model)) #90%-91%
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init) # reset values to wrong
for i in range(100000):
  batch_xs, batch_ys = mnist_data.train.next_batch(50)
  sess.run(train, feed_dict = {x_data: batch_xs, y_data: batch_ys})
  if i%50==0:
    correct_predict = tf.equal(tf.arg_max(Y_model,1),tf.argmax(y_data,1))
    accurate = tf.reduce_mean(tf.cast(correct_predict,"float"))
    print(sess.run(accurate,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
x_image = tf.reshape(x_data, [-1,28,28,1])
w_conv = tf.Variable(tf.ones([5,5,1,32])) #weight
b_conv = tf.Variable(tf.ones([32]))    #bias
h_conv = tf.nn.relu(tf.nn.conv2d(x_image , w_conv,strides=[1,1,1,1],padding='SAME')+ b_conv)
h_pool = tf.nn.max_pool(h_conv,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
w_fc = tf.Variable(tf.ones([14*14*32,1024]))
b_fc = tf.Variable(tf.ones([1024]))
h_pool_flat = tf.reshape(h_pool,[-1,14*14*32])
h_fc = tf.nn.relu(tf.matmul(h_pool_flat,w_fc) +b_fc)
W_fc = w_fc = tf.Variable(tf.ones([1024,10]))
B_fc = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(h_fc,W_fc) +B_fc)
y_data = tf.placeholder("float32",[None,10])
loss = -tf.reduce_sum(y_data * tf.log(Y_model))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
  batch_xs,batch_ys =mnist_data.train.next_batch(5)
  sess.run(train_step,feed_dict={x_data:batch_xs,y_data:batch_ys})
  if i%50==0:
    correct_prediction = tf.equal(tf.argmax(Y_model,1),tf.argmax(y_data,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
    print(sess.run(accuracy,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python实现在字符串中查找子字符串的方法
Jul 11 Python
Python 3.x 连接数据库示例(pymysql 方式)
Jan 19 Python
Python 私有函数的实例详解
Sep 11 Python
Python爬虫之正则表达式基本用法实例分析
Aug 08 Python
pycharm远程开发项目的实现步骤
Jan 20 Python
Pycharm运行加载文本出现错误的解决方法
Jun 27 Python
django如何自己创建一个中间件
Jul 24 Python
对django中foreignkey的简单使用详解
Jul 28 Python
Python字典推导式将cookie字符串转化为字典解析
Aug 10 Python
Pytorch Tensor 输出为txt和mat格式方式
Jan 03 Python
PyCharm Anaconda配置PyQt5开发环境及创建项目的教程详解
Mar 24 Python
Python Django搭建文件下载服务器的实现
May 10 Python
Python: 传递列表副本方式
Dec 19 #Python
python内置模块collections知识点总结
Dec 19 #Python
Python操作redis和mongoDB的方法
Dec 19 #Python
Python 实现Serial 与STM32J进行串口通讯
Dec 18 #Python
实现Python与STM32通信方式
Dec 18 #Python
利用pandas将非数值数据转换成数值的方式
Dec 18 #Python
python 浅谈serial与stm32通信的编码问题
Dec 18 #Python
You might like
可定制的PHP缩略图生成程式(需要GD库支持)
2007/03/06 PHP
php采集文章中的图片获取替换到本地(实现代码)
2013/07/08 PHP
使用php测试硬盘写入速度示例
2014/01/27 PHP
浅谈PHP正则表达式中修饰符/i, /is, /s, /isU
2014/10/21 PHP
PHP读取zip文件的方法示例
2016/11/17 PHP
Gambit vs CL BO3 第一场 2.13
2021/03/10 DOTA
JS array 数组详解
2009/03/22 Javascript
SeaJS 与 RequireJS 的差异对比
2014/12/08 Javascript
原生js开发的日历插件
2017/02/04 Javascript
vue初尝试--项目结构(推荐)
2018/01/30 Javascript
javascript数据结构之多叉树经典操作示例【创建、添加、遍历、移除等】
2018/08/01 Javascript
浅谈Vue数据响应
2018/11/05 Javascript
详解webpack引入第三方库的方式以及注意事项
2019/01/15 Javascript
js使用cookie实现记住用户名功能示例
2019/06/13 Javascript
Vue之封装公用变量以及实现方式
2020/07/31 Javascript
详解Vue3.0 + TypeScript + Vite初体验
2021/02/22 Vue.js
[50:02]完美世界DOTA2联赛循环赛 Magma vs IO BO2第一场 11.01
2020/11/02 DOTA
Python  连接字符串(join %)
2008/09/06 Python
python获取Linux下文件版本信息、公司名和产品名的方法
2014/10/05 Python
Python找出文件中使用率最高的汉字实例详解
2015/06/03 Python
利用Python实现kNN算法的代码
2019/08/16 Python
python sorted方法和列表使用解析
2019/11/18 Python
Python多线程threading创建及使用方法解析
2020/06/17 Python
HTML5 CSS3实现一个精美VCD包装盒个性幻灯片案例
2014/06/16 HTML / CSS
canvas绘制图片drawImage使用方法
2020/09/15 HTML / CSS
比利时网上药店: Drogisterij.net
2017/03/17 全球购物
德国家用电器购物网站:Premiumshop24
2019/08/22 全球购物
橄榄树药房:OLIVEDA
2019/09/01 全球购物
财务担保书范文
2014/04/02 职场文书
服装设计专业自荐信
2014/06/17 职场文书
心得体会的写法
2014/09/05 职场文书
校园安全广播稿范文
2014/09/25 职场文书
委托公证书格式
2015/01/26 职场文书
投资申请报告
2015/05/19 职场文书
2016新年问候语大全
2015/11/11 职场文书
【HBU】数据库第四周 单表查询
2021/04/05 SQL Server