tensorflow实现softma识别MNIST


Posted in Python onMarch 12, 2018

识别MNIST已经成了深度学习的hello world,所以每次例程基本都会用到这个数据集,这个数据集在tensorflow内部用着很好的封装,因此可以方便地使用。

这次我们用tensorflow搭建一个softmax多分类器,和之前搭建线性回归差不多,第一步是通过确定变量建立图模型,然后确定误差函数,最后调用优化器优化。

误差函数与线性回归不同,这里因为是多分类问题,所以使用了交叉熵。

另外,有一点值得注意的是,这里构建模型时我试图想拆分多个函数,但是后来发现这样做难度很大,因为图是在规定变量就已经定义好的,不能随意拆分,也不能当做变量传来传去,因此需要将他们写在一起。

代码如下:

#encoding=utf-8 
__author__ = 'freedom' 
import tensorflow as tf 
 
def loadMNIST(): 
 from tensorflow.examples.tutorials.mnist import input_data 
 mnist = input_data.read_data_sets('MNIST_data',one_hot=True) 
 return mnist 
 
def softmax(mnist,rate=0.01,batchSize=50,epoch=20): 
 n = 784 # 向量的维度数目 
 m = None # 样本数,这里可以获取,也可以不获取 
 c = 10 # 类别数目 
 
 x = tf.placeholder(tf.float32,[m,n]) 
 y = tf.placeholder(tf.float32,[m,c]) 
 
 w = tf.Variable(tf.zeros([n,c])) 
 b = tf.Variable(tf.zeros([c])) 
 
 pred= tf.nn.softmax(tf.matmul(x,w)+b) 
 loss = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) 
 opt = tf.train.GradientDescentOptimizer(rate).minimize(loss) 
 
 init = tf.initialize_all_variables() 
 
 sess = tf.Session() 
 sess.run(init) 
 for index in range(epoch): 
  avgLoss = 0 
  batchNum = int(mnist.train.num_examples/batchSize) 
  for batch in range(batchNum): 
   batch_x,batch_y = mnist.train.next_batch(batchSize) 
   _,Loss = sess.run([opt,loss],{x:batch_x,y:batch_y}) 
   avgLoss += Loss 
  avgLoss /= batchNum 
  print 'every epoch average loss is ',avgLoss 
 
 right = tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) 
 accuracy = tf.reduce_mean(tf.cast(right,tf.float32)) 
 print 'Accracy is ',sess.run(accuracy,({x:mnist.test.images,y:mnist.test.labels})) 
 
 
if __name__ == "__main__": 
 mnist = loadMNIST() 
 softmax(mnist)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Django中编写模版节点及注册标签的方法
Jul 20 Python
轻松理解Python 中的 descriptor
Sep 15 Python
Python基于递归实现电话号码映射功能示例
Apr 13 Python
Python可变和不可变、类的私有属性实例分析
May 31 Python
python 实现将多条曲线画在一幅图上的方法
Jul 07 Python
python scipy卷积运算的实现方法
Sep 16 Python
Windows平台Python编程必会模块之pywin32介绍
Oct 01 Python
使用Python实现画一个中国地图
Nov 23 Python
python GUI库图形界面开发之PyQt5 UI主线程与耗时线程分离详细方法实例
Feb 26 Python
Django框架获取form表单数据方式总结
Apr 22 Python
python3通过udp实现组播数据的发送和接收操作
May 05 Python
详解Python中的GIL(全局解释器锁)详解及解决GIL的几种方案
Jan 29 Python
wxpython实现图书管理系统
Mar 12 #Python
人生苦短我用python python如何快速入门?
Mar 12 #Python
tensorflow实现KNN识别MNIST
Mar 12 #Python
Python操作MySQL模拟银行转账
Mar 12 #Python
python3 图片referer防盗链的实现方法
Mar 12 #Python
tensorflow构建BP神经网络的方法
Mar 12 #Python
Python管理Windows服务小脚本
Mar 12 #Python
You might like
JS实现php的伪分页
2008/05/25 PHP
ThinkPHP字符串函数及常用函数汇总
2014/07/18 PHP
ThinkPHP函数详解之M方法和R方法
2015/09/10 PHP
浅析php设计模式之数据对象映射模式
2016/03/03 PHP
php opendir()列出目录下所有文件的实例代码
2016/10/02 PHP
php输出控制函数和输出函数生成静态页面
2019/06/27 PHP
在js中单选框和复选框获取值的方式
2009/11/06 Javascript
jQuery 取值、赋值的基本方法整理
2014/03/31 Javascript
javascript中Array数组的迭代方法实例分析
2015/02/04 Javascript
js闭包引起的事件注册问题介绍
2016/03/29 Javascript
JavaScript 随机验证码的生成实例代码
2016/09/22 Javascript
手机软键盘弹出时影响布局的解决方法
2016/12/15 Javascript
使用JS 插件qrcode.js生成二维码功能
2017/02/20 Javascript
vue打包的时候自动将px转成rem的操作方法
2018/06/20 Javascript
bootstrap中的导航条实例代码详解
2019/05/20 Javascript
layer 关闭指定弹出层的例子
2019/09/25 Javascript
推荐几个不错的console调试技巧实现
2019/12/20 Javascript
python搭建简易服务器分析与实现
2012/12/15 Python
Python字符串特性及常用字符串方法的简单笔记
2016/01/04 Python
python生成器与迭代器详解
2019/01/01 Python
Tensorflow限制CPU个数实例
2020/02/06 Python
python中def是做什么的
2020/06/10 Python
python suds访问webservice服务实现
2020/06/26 Python
PyTorch实现重写/改写Dataset并载入Dataloader
2020/07/14 Python
Python趣味入门教程之循环语句while
2020/08/26 Python
DJI美国:消费类无人机领域的领导者
2018/04/27 全球购物
印尼在线旅游门户网站:NusaTrip
2019/11/01 全球购物
杠杆的科学教学反思
2014/01/10 职场文书
自我介绍演讲稿
2014/01/15 职场文书
优秀企业获奖感言
2014/02/01 职场文书
开学典礼主持词
2014/03/19 职场文书
2015感人爱情寄语
2015/02/26 职场文书
植物园观后感
2015/06/11 职场文书
导游词之临安白水涧
2019/11/05 职场文书
mysql批量新增和存储的方法实例
2021/04/07 MySQL
PYTHON 使用 Pandas 删除某列指定值所在的行
2022/04/28 Python