Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】


Posted in Python onDecember 19, 2019

本文实例讲述了Python tensorflow实现mnist手写数字识别。分享给大家供大家参考,具体如下:

非卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
weight = tf.Variable(tf.ones([784,10]))
bias = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(x_data ,weight) + bias)
#Y_model = tf.nn.sigmoid(tf.matmul(x_data ,weight) + bias)
'''
weight1 = tf.Variable(tf.ones([784,256]))
bias1 = tf.Variable(tf.ones([256]))
Y_model1 = tf.nn.softmax(tf.matmul(x_data ,weight1) + bias1)
weight1 = tf.Variable(tf.ones([256,10]))
bias1 = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(Y_model1 ,weight1) + bias1)
'''
y_data = tf.placeholder("float32", [None, 10])
loss = tf.reduce_sum(tf.pow((y_data - Y_model), 2 ))#92%-93%
#loss = tf.reduce_sum(tf.square(y_data - Y_model)) #90%-91%
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init) # reset values to wrong
for i in range(100000):
  batch_xs, batch_ys = mnist_data.train.next_batch(50)
  sess.run(train, feed_dict = {x_data: batch_xs, y_data: batch_ys})
  if i%50==0:
    correct_predict = tf.equal(tf.arg_max(Y_model,1),tf.argmax(y_data,1))
    accurate = tf.reduce_mean(tf.cast(correct_predict,"float"))
    print(sess.run(accurate,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
x_image = tf.reshape(x_data, [-1,28,28,1])
w_conv = tf.Variable(tf.ones([5,5,1,32])) #weight
b_conv = tf.Variable(tf.ones([32]))    #bias
h_conv = tf.nn.relu(tf.nn.conv2d(x_image , w_conv,strides=[1,1,1,1],padding='SAME')+ b_conv)
h_pool = tf.nn.max_pool(h_conv,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
w_fc = tf.Variable(tf.ones([14*14*32,1024]))
b_fc = tf.Variable(tf.ones([1024]))
h_pool_flat = tf.reshape(h_pool,[-1,14*14*32])
h_fc = tf.nn.relu(tf.matmul(h_pool_flat,w_fc) +b_fc)
W_fc = w_fc = tf.Variable(tf.ones([1024,10]))
B_fc = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(h_fc,W_fc) +B_fc)
y_data = tf.placeholder("float32",[None,10])
loss = -tf.reduce_sum(y_data * tf.log(Y_model))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
  batch_xs,batch_ys =mnist_data.train.next_batch(5)
  sess.run(train_step,feed_dict={x_data:batch_xs,y_data:batch_ys})
  if i%50==0:
    correct_prediction = tf.equal(tf.argmax(Y_model,1),tf.argmax(y_data,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
    print(sess.run(accuracy,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python2.x版本中cmp()方法的使用教程
May 14 Python
Python实现批量下载文件
May 17 Python
分析用Python脚本关闭文件操作的机制
Jun 28 Python
Python中格式化format()方法详解
Apr 01 Python
Python学生信息管理系统修改版
Mar 13 Python
Python 限制线程的最大数量的方法(Semaphore)
Feb 22 Python
Python Django切换MySQL数据库实例详解
Jul 16 Python
kafka监控获取指定topic的消息总量示例
Dec 23 Python
Tensorflow 实现分批量读取数据
Jan 04 Python
pytorch载入预训练模型后,实现训练指定层
Jan 06 Python
详解python如何引用包package
Jun 07 Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 Python
Python: 传递列表副本方式
Dec 19 #Python
python内置模块collections知识点总结
Dec 19 #Python
Python操作redis和mongoDB的方法
Dec 19 #Python
Python 实现Serial 与STM32J进行串口通讯
Dec 18 #Python
实现Python与STM32通信方式
Dec 18 #Python
利用pandas将非数值数据转换成数值的方式
Dec 18 #Python
python 浅谈serial与stm32通信的编码问题
Dec 18 #Python
You might like
使用无限生命期Session的方法
2006/10/09 PHP
php 获取远程网页内容的函数
2009/09/08 PHP
使ecshop模板中可引用常量的实现方法
2011/06/02 PHP
php curl获取网页内容(IPV6下超时)的解决办法
2013/07/16 PHP
为Plesk PHP7启用Oracle OCI8扩展方法总结
2019/03/29 PHP
浅谈PHP之ThinkPHP框架使用详解
2020/07/21 PHP
jquery中常用的函数和属性详细解析
2014/03/07 Javascript
Node.js的MongoDB驱动Mongoose基本使用教程
2016/03/01 Javascript
JS实现上传图片的三种方法并实现预览图片功能
2017/07/14 Javascript
使用jQuery实现购物车结算功能
2017/08/15 jQuery
mescroll.js上拉加载下拉刷新组件使用详解
2017/11/13 Javascript
vue 本地环境跨域请求proxyTable的方法
2018/09/19 Javascript
vue组件中的样式属性scoped实例详解
2018/10/30 Javascript
使用Vue实现简单计算器
2020/02/25 Javascript
jQuery实现鼠标拖动图片功能
2021/03/04 jQuery
[02:45]2016年中国刀塔全程回顾,完美“圣”典即将上演
2016/12/15 DOTA
教你如何在Django 1.6中正确使用 Signal
2014/06/22 Python
为python设置socket代理的方法
2015/01/14 Python
Python中的深拷贝和浅拷贝详解
2015/06/03 Python
TensorFlow如何实现反向传播
2018/02/06 Python
pandas中的ExcelWriter和ExcelFile的实现方法
2020/04/24 Python
如何在Windows中安装多个python解释器
2020/06/16 Python
Python实现京东抢秒杀功能
2021/01/25 Python
纯CSS3实现扇形动画菜单(简化版)实例源码
2017/01/17 HTML / CSS
html5指南-5.使用web storage存储键值对的数据
2013/01/07 HTML / CSS
html5 Canvas画图教程(2)—画直线与设置线条的样式如颜色/端点/交汇点
2013/01/09 HTML / CSS
Book Depository欧盟:一家领先的国际图书零售商
2019/05/21 全球购物
华为消费者德国官方网站:HUAWEI德国
2020/11/03 全球购物
环境工程求职简历的自我评价范文
2013/10/24 职场文书
师范院校学生自荐信范文
2013/12/27 职场文书
回门宴新郎答谢词
2014/01/12 职场文书
国庆促销活动总结
2014/08/29 职场文书
初二学生评语大全
2014/12/26 职场文书
项目经理岗位职责
2015/01/31 职场文书
学校党支部承诺书
2015/04/30 职场文书
python turtle绘图
2022/05/04 Python