Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】


Posted in Python onDecember 19, 2019

本文实例讲述了Python tensorflow实现mnist手写数字识别。分享给大家供大家参考,具体如下:

非卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
weight = tf.Variable(tf.ones([784,10]))
bias = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(x_data ,weight) + bias)
#Y_model = tf.nn.sigmoid(tf.matmul(x_data ,weight) + bias)
'''
weight1 = tf.Variable(tf.ones([784,256]))
bias1 = tf.Variable(tf.ones([256]))
Y_model1 = tf.nn.softmax(tf.matmul(x_data ,weight1) + bias1)
weight1 = tf.Variable(tf.ones([256,10]))
bias1 = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(Y_model1 ,weight1) + bias1)
'''
y_data = tf.placeholder("float32", [None, 10])
loss = tf.reduce_sum(tf.pow((y_data - Y_model), 2 ))#92%-93%
#loss = tf.reduce_sum(tf.square(y_data - Y_model)) #90%-91%
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init) # reset values to wrong
for i in range(100000):
  batch_xs, batch_ys = mnist_data.train.next_batch(50)
  sess.run(train, feed_dict = {x_data: batch_xs, y_data: batch_ys})
  if i%50==0:
    correct_predict = tf.equal(tf.arg_max(Y_model,1),tf.argmax(y_data,1))
    accurate = tf.reduce_mean(tf.cast(correct_predict,"float"))
    print(sess.run(accurate,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
x_image = tf.reshape(x_data, [-1,28,28,1])
w_conv = tf.Variable(tf.ones([5,5,1,32])) #weight
b_conv = tf.Variable(tf.ones([32]))    #bias
h_conv = tf.nn.relu(tf.nn.conv2d(x_image , w_conv,strides=[1,1,1,1],padding='SAME')+ b_conv)
h_pool = tf.nn.max_pool(h_conv,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
w_fc = tf.Variable(tf.ones([14*14*32,1024]))
b_fc = tf.Variable(tf.ones([1024]))
h_pool_flat = tf.reshape(h_pool,[-1,14*14*32])
h_fc = tf.nn.relu(tf.matmul(h_pool_flat,w_fc) +b_fc)
W_fc = w_fc = tf.Variable(tf.ones([1024,10]))
B_fc = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(h_fc,W_fc) +B_fc)
y_data = tf.placeholder("float32",[None,10])
loss = -tf.reduce_sum(y_data * tf.log(Y_model))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
  batch_xs,batch_ys =mnist_data.train.next_batch(5)
  sess.run(train_step,feed_dict={x_data:batch_xs,y_data:batch_ys})
  if i%50==0:
    correct_prediction = tf.equal(tf.argmax(Y_model,1),tf.argmax(y_data,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
    print(sess.run(accuracy,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
netbeans7安装python插件的方法图解
Dec 24 Python
简单介绍Python中的JSON使用
Apr 28 Python
Python 获取当前所在目录的方法详解
Aug 02 Python
解决python nohup linux 后台运行输出的问题
May 11 Python
python对list中的每个元素进行某种操作的方法
Jun 29 Python
Python3 修改默认环境的方法
Feb 16 Python
Python中的引用知识点总结
May 20 Python
python操作excel让工作自动化
Aug 09 Python
Django 拆分model和view的实现方法
Aug 16 Python
python爬虫 基于requests模块的get请求实现详解
Aug 20 Python
Python cookie的保存与读取、SSL讲解
Feb 17 Python
python实现双链表
May 25 Python
Python: 传递列表副本方式
Dec 19 #Python
python内置模块collections知识点总结
Dec 19 #Python
Python操作redis和mongoDB的方法
Dec 19 #Python
Python 实现Serial 与STM32J进行串口通讯
Dec 18 #Python
实现Python与STM32通信方式
Dec 18 #Python
利用pandas将非数值数据转换成数值的方式
Dec 18 #Python
python 浅谈serial与stm32通信的编码问题
Dec 18 #Python
You might like
php minixml详解
2008/07/19 PHP
PHP的serialize序列化数据以及JSON格式化数据分析
2015/10/10 PHP
php中序列化与反序列化详解
2017/02/13 PHP
Aster vs KG BO3 第三场2.18
2021/03/10 DOTA
基于jQuery实现仿淘宝套餐选择插件
2015/03/04 Javascript
使用plupload自定义参数实现多文件上传
2016/07/19 Javascript
基于Javascript倒计时效果
2016/12/22 Javascript
微信小程序 ES6Promise.all批量上传文件实现代码
2017/04/14 Javascript
通过 JS 判断页面是否有滚动条的实现方法
2018/04/05 Javascript
Vue中的Props(不可变状态)
2018/09/29 Javascript
详解javascript设计模式三:代理模式
2019/03/25 Javascript
原生js实现购物车功能
2020/09/23 Javascript
js实现鼠标拖曳效果
2020/12/30 Javascript
微信小程序抽奖组件的使用步骤
2021/01/11 Javascript
python url 参数修改方法
2018/12/26 Python
python利用百度云接口实现车牌识别的示例
2020/02/21 Python
Python unittest单元测试框架实现参数化
2020/04/29 Python
python打开文件的方式有哪些
2020/06/29 Python
pandas针对excel处理的实现
2021/01/15 Python
用python制作个音乐下载器
2021/01/30 Python
pandas统计重复值次数的方法实现
2021/02/20 Python
波兰最大的宠物用品网上商店:FERA.PL
2019/08/11 全球购物
公司人力资源的自我评价
2014/01/02 职场文书
酒店总经理欢迎词
2014/01/15 职场文书
关于保护环境的标语
2014/06/09 职场文书
中华魂放飞梦想演讲稿
2014/08/26 职场文书
党的群众路线批评与自我批评发言稿
2014/10/16 职场文书
讲座开场白台词和结束语
2015/05/29 职场文书
2015年女工委工作总结
2015/07/27 职场文书
反四风问题学习心得体会
2016/01/22 职场文书
Golang: 内建容器的用法
2021/05/05 Golang
Python基础之条件语句详解
2021/06/16 Python
Linux系统下安装PHP7.3版本
2021/06/26 PHP
Java面试题冲刺第十七天--基础篇3
2021/08/07 面试题
《群青的幻想曲》京力秋树角色PV公开
2022/04/08 日漫
i5-10400f处理相当于i7多少水平
2022/04/19 数码科技