tensorflow实现softma识别MNIST


Posted in Python onMarch 12, 2018

识别MNIST已经成了深度学习的hello world,所以每次例程基本都会用到这个数据集,这个数据集在tensorflow内部用着很好的封装,因此可以方便地使用。

这次我们用tensorflow搭建一个softmax多分类器,和之前搭建线性回归差不多,第一步是通过确定变量建立图模型,然后确定误差函数,最后调用优化器优化。

误差函数与线性回归不同,这里因为是多分类问题,所以使用了交叉熵。

另外,有一点值得注意的是,这里构建模型时我试图想拆分多个函数,但是后来发现这样做难度很大,因为图是在规定变量就已经定义好的,不能随意拆分,也不能当做变量传来传去,因此需要将他们写在一起。

代码如下:

#encoding=utf-8 
__author__ = 'freedom' 
import tensorflow as tf 
 
def loadMNIST(): 
 from tensorflow.examples.tutorials.mnist import input_data 
 mnist = input_data.read_data_sets('MNIST_data',one_hot=True) 
 return mnist 
 
def softmax(mnist,rate=0.01,batchSize=50,epoch=20): 
 n = 784 # 向量的维度数目 
 m = None # 样本数,这里可以获取,也可以不获取 
 c = 10 # 类别数目 
 
 x = tf.placeholder(tf.float32,[m,n]) 
 y = tf.placeholder(tf.float32,[m,c]) 
 
 w = tf.Variable(tf.zeros([n,c])) 
 b = tf.Variable(tf.zeros([c])) 
 
 pred= tf.nn.softmax(tf.matmul(x,w)+b) 
 loss = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) 
 opt = tf.train.GradientDescentOptimizer(rate).minimize(loss) 
 
 init = tf.initialize_all_variables() 
 
 sess = tf.Session() 
 sess.run(init) 
 for index in range(epoch): 
  avgLoss = 0 
  batchNum = int(mnist.train.num_examples/batchSize) 
  for batch in range(batchNum): 
   batch_x,batch_y = mnist.train.next_batch(batchSize) 
   _,Loss = sess.run([opt,loss],{x:batch_x,y:batch_y}) 
   avgLoss += Loss 
  avgLoss /= batchNum 
  print 'every epoch average loss is ',avgLoss 
 
 right = tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) 
 accuracy = tf.reduce_mean(tf.cast(right,tf.float32)) 
 print 'Accracy is ',sess.run(accuracy,({x:mnist.test.images,y:mnist.test.labels})) 
 
 
if __name__ == "__main__": 
 mnist = loadMNIST() 
 softmax(mnist)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
9种python web 程序的部署方式小结
Jun 30 Python
Python中计算三角函数之cos()方法的使用简介
May 15 Python
Python安装第三方库的3种方法
Jun 21 Python
python使用matplotlib绘制折线图教程
Feb 08 Python
彻底理解Python中的yield关键字
Apr 01 Python
Python 用matplotlib画以时间日期为x轴的图像
Aug 06 Python
django-rest-swagger对API接口注释的方法
Aug 29 Python
python 利用已有Ner模型进行数据清洗合并代码
Dec 24 Python
解决pycharm debug时界面下方不出现step等按钮及变量值的问题
Jun 09 Python
Python调用C语言程序方法解析
Jul 07 Python
Numpy中的数组搜索中np.where方法详细介绍
Jan 08 Python
python字典的元素访问实例详解
Jul 21 Python
wxpython实现图书管理系统
Mar 12 #Python
人生苦短我用python python如何快速入门?
Mar 12 #Python
tensorflow实现KNN识别MNIST
Mar 12 #Python
Python操作MySQL模拟银行转账
Mar 12 #Python
python3 图片referer防盗链的实现方法
Mar 12 #Python
tensorflow构建BP神经网络的方法
Mar 12 #Python
Python管理Windows服务小脚本
Mar 12 #Python
You might like
php 批量查询搜狗sogou代码分享
2015/05/17 PHP
十个迅速提升JQuery性能让你的JQuery跑得更快
2012/12/10 Javascript
jquery购物车实时结算特效实现思路
2013/09/23 Javascript
javascript模拟订火车票和退票示例
2014/04/24 Javascript
javascript进行数组追加方法小结
2014/06/16 Javascript
JavaScript bold方法入门实例(把指定文字显示为粗体)
2014/10/17 Javascript
使用 js+正则表达式为关键词添加链接
2014/11/11 Javascript
JS实现的4种数字千位符格式化方法分享
2015/03/02 Javascript
深入理解JavaScript系列(33):设计模式之策略模式详解
2015/03/03 Javascript
JavaScript动态添加style节点的方法
2015/06/09 Javascript
轻松使用jQuery双向select控件Bootstrap Dual Listbox
2015/12/13 Javascript
学习JavaScript设计模式之装饰者模式
2016/01/19 Javascript
微信小程序 动态绑定事件并实现事件修改样式
2017/04/13 Javascript
Vue 2.0的数据依赖实现原理代码简析
2017/07/10 Javascript
监听angularJs列表数据是否渲染完毕的方法示例
2018/11/07 Javascript
Cocos2d实现刮刮卡效果
2018/12/20 Javascript
如何在Angular应用中创建包含组件方法示例
2019/03/23 Javascript
JS 遍历 json 和 JQuery 遍历json操作完整示例
2019/11/11 jQuery
Vue 嵌套路由使用总结(推荐)
2020/01/13 Javascript
解决vue的router组件component在import时不能使用变量问题
2020/07/26 Javascript
python并发编程之线程实例解析
2017/12/27 Python
便捷提取python导入包的属性方法
2018/10/15 Python
Python Django框架模板渲染功能示例
2019/11/08 Python
利用Python的sympy包求解一元三次方程示例
2019/11/22 Python
使用Django和Postgres进行全文搜索的实例代码
2020/02/13 Python
SCDKey德国:全球领先的数字游戏市场
2019/04/09 全球购物
什么是lambda函数
2013/09/17 面试题
医院护士求职自荐信格式
2013/09/21 职场文书
护士进修自我鉴定
2014/02/07 职场文书
物流业务员岗位职责
2014/02/08 职场文书
企业办公室主任岗位职责
2014/02/19 职场文书
读群众路线心得体会
2014/03/07 职场文书
2014年干部培训工作总结
2014/12/17 职场文书
信仰纪录片观后感
2015/06/08 职场文书
2016年社区六一儿童节活动总结
2016/04/06 职场文书
Javascript设计模式之原型模式详细
2021/10/05 Javascript