tensorflow实现softma识别MNIST


Posted in Python onMarch 12, 2018

识别MNIST已经成了深度学习的hello world,所以每次例程基本都会用到这个数据集,这个数据集在tensorflow内部用着很好的封装,因此可以方便地使用。

这次我们用tensorflow搭建一个softmax多分类器,和之前搭建线性回归差不多,第一步是通过确定变量建立图模型,然后确定误差函数,最后调用优化器优化。

误差函数与线性回归不同,这里因为是多分类问题,所以使用了交叉熵。

另外,有一点值得注意的是,这里构建模型时我试图想拆分多个函数,但是后来发现这样做难度很大,因为图是在规定变量就已经定义好的,不能随意拆分,也不能当做变量传来传去,因此需要将他们写在一起。

代码如下:

#encoding=utf-8 
__author__ = 'freedom' 
import tensorflow as tf 
 
def loadMNIST(): 
 from tensorflow.examples.tutorials.mnist import input_data 
 mnist = input_data.read_data_sets('MNIST_data',one_hot=True) 
 return mnist 
 
def softmax(mnist,rate=0.01,batchSize=50,epoch=20): 
 n = 784 # 向量的维度数目 
 m = None # 样本数,这里可以获取,也可以不获取 
 c = 10 # 类别数目 
 
 x = tf.placeholder(tf.float32,[m,n]) 
 y = tf.placeholder(tf.float32,[m,c]) 
 
 w = tf.Variable(tf.zeros([n,c])) 
 b = tf.Variable(tf.zeros([c])) 
 
 pred= tf.nn.softmax(tf.matmul(x,w)+b) 
 loss = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) 
 opt = tf.train.GradientDescentOptimizer(rate).minimize(loss) 
 
 init = tf.initialize_all_variables() 
 
 sess = tf.Session() 
 sess.run(init) 
 for index in range(epoch): 
  avgLoss = 0 
  batchNum = int(mnist.train.num_examples/batchSize) 
  for batch in range(batchNum): 
   batch_x,batch_y = mnist.train.next_batch(batchSize) 
   _,Loss = sess.run([opt,loss],{x:batch_x,y:batch_y}) 
   avgLoss += Loss 
  avgLoss /= batchNum 
  print 'every epoch average loss is ',avgLoss 
 
 right = tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) 
 accuracy = tf.reduce_mean(tf.cast(right,tf.float32)) 
 print 'Accracy is ',sess.run(accuracy,({x:mnist.test.images,y:mnist.test.labels})) 
 
 
if __name__ == "__main__": 
 mnist = loadMNIST() 
 softmax(mnist)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
编写Python脚本来获取mp3文件tag信息的教程
May 04 Python
python中json格式数据输出的简单实现方法
Oct 31 Python
Python2.7基于淘宝接口获取IP地址所在地理位置的方法【测试可用】
Jun 07 Python
Python爬虫实现百度图片自动下载
Feb 04 Python
python3监控CentOS磁盘空间脚本
Jun 21 Python
python版本单链表实现代码
Sep 28 Python
对Python w和w+权限的区别详解
Jan 23 Python
python2使用bs4爬取腾讯社招过程解析
Aug 14 Python
python 两个数据库postgresql对比
Oct 21 Python
浅析python 动态库m.so.1.0错误问题
May 09 Python
keras 获取某层的输入/输出 tensor 尺寸操作
Jun 10 Python
Python requests及aiohttp速度对比代码实例
Jul 16 Python
wxpython实现图书管理系统
Mar 12 #Python
人生苦短我用python python如何快速入门?
Mar 12 #Python
tensorflow实现KNN识别MNIST
Mar 12 #Python
Python操作MySQL模拟银行转账
Mar 12 #Python
python3 图片referer防盗链的实现方法
Mar 12 #Python
tensorflow构建BP神经网络的方法
Mar 12 #Python
Python管理Windows服务小脚本
Mar 12 #Python
You might like
PHP漏洞全解(详细介绍)
2012/11/13 PHP
PHP中substr_count()函数获取子字符串出现次数的方法
2016/01/07 PHP
thinkPHP自动验证机制详解
2016/12/05 PHP
js+CSS 图片等比缩小并垂直居中实现代码
2008/12/01 Javascript
实现变速回到顶部的JavaScript代码
2011/05/09 Javascript
读jQuery之七 判断点击了鼠标哪个键的代码
2011/06/21 Javascript
jQuery表单获取和失去焦点输入框提示效果的实例代码
2013/08/01 Javascript
二叉树的非递归后序遍历算法实例详解
2014/02/07 Javascript
使用jquery.upload.js实现异步上传示例代码
2014/07/29 Javascript
JScript中的条件注释详解
2015/04/24 Javascript
jQuery实现带有洗牌效果的动画分页实例
2015/08/31 Javascript
jQuery简单验证上传文件大小及类型的方法
2016/06/02 Javascript
AngularJS 整理一些优化的小技巧
2016/08/18 Javascript
浅谈js数据类型判断与数组判断
2016/08/29 Javascript
基于JavaScript实现鼠标箭头移动图片跟着移动
2016/08/30 Javascript
详解vue之页面缓存问题(基于2.0)
2017/01/10 Javascript
详解vue跨组件通信的几种方法
2017/06/15 Javascript
Layui Form 自定义验证的实例代码
2019/09/14 Javascript
微信小程序新闻网站详情页实例代码
2020/01/10 Javascript
[02:00]DOTA2英雄COSPLAY闹市街头巡游助威2015国际邀请赛
2015/08/02 DOTA
python进程类subprocess的一些操作方法例子
2014/11/22 Python
Unicode和Python的中文处理
2017/03/19 Python
你真的了解Python的random模块吗?
2017/12/12 Python
python3使用smtplib实现发送邮件功能
2018/05/22 Python
pycharm 主题theme设置调整仿sublime的方法
2018/05/23 Python
wxPython的安装与使用教程
2018/08/31 Python
python设置环境变量的作用和实例
2019/07/09 Python
python解析yaml文件过程详解
2019/08/30 Python
Django使用消息提示简单的弹出个对话框实例
2019/11/15 Python
Python新建项目自动添加介绍和utf-8编码的方法
2020/12/26 Python
京东奢侈品:全球奢侈品牌
2018/03/17 全球购物
公司行政经理岗位职责
2013/12/24 职场文书
MySQL 角色(role)功能介绍
2021/04/24 MySQL
教你使用pyinstaller打包Python教程
2021/05/27 Python
使用numpy实现矩阵的翻转(flip)与旋转
2021/06/03 Python
javascript中Set、Map、WeakSet、WeakMap区别
2022/12/24 Javascript