tensorflow实现softma识别MNIST


Posted in Python onMarch 12, 2018

识别MNIST已经成了深度学习的hello world,所以每次例程基本都会用到这个数据集,这个数据集在tensorflow内部用着很好的封装,因此可以方便地使用。

这次我们用tensorflow搭建一个softmax多分类器,和之前搭建线性回归差不多,第一步是通过确定变量建立图模型,然后确定误差函数,最后调用优化器优化。

误差函数与线性回归不同,这里因为是多分类问题,所以使用了交叉熵。

另外,有一点值得注意的是,这里构建模型时我试图想拆分多个函数,但是后来发现这样做难度很大,因为图是在规定变量就已经定义好的,不能随意拆分,也不能当做变量传来传去,因此需要将他们写在一起。

代码如下:

#encoding=utf-8 
__author__ = 'freedom' 
import tensorflow as tf 
 
def loadMNIST(): 
 from tensorflow.examples.tutorials.mnist import input_data 
 mnist = input_data.read_data_sets('MNIST_data',one_hot=True) 
 return mnist 
 
def softmax(mnist,rate=0.01,batchSize=50,epoch=20): 
 n = 784 # 向量的维度数目 
 m = None # 样本数,这里可以获取,也可以不获取 
 c = 10 # 类别数目 
 
 x = tf.placeholder(tf.float32,[m,n]) 
 y = tf.placeholder(tf.float32,[m,c]) 
 
 w = tf.Variable(tf.zeros([n,c])) 
 b = tf.Variable(tf.zeros([c])) 
 
 pred= tf.nn.softmax(tf.matmul(x,w)+b) 
 loss = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) 
 opt = tf.train.GradientDescentOptimizer(rate).minimize(loss) 
 
 init = tf.initialize_all_variables() 
 
 sess = tf.Session() 
 sess.run(init) 
 for index in range(epoch): 
  avgLoss = 0 
  batchNum = int(mnist.train.num_examples/batchSize) 
  for batch in range(batchNum): 
   batch_x,batch_y = mnist.train.next_batch(batchSize) 
   _,Loss = sess.run([opt,loss],{x:batch_x,y:batch_y}) 
   avgLoss += Loss 
  avgLoss /= batchNum 
  print 'every epoch average loss is ',avgLoss 
 
 right = tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) 
 accuracy = tf.reduce_mean(tf.cast(right,tf.float32)) 
 print 'Accracy is ',sess.run(accuracy,({x:mnist.test.images,y:mnist.test.labels})) 
 
 
if __name__ == "__main__": 
 mnist = loadMNIST() 
 softmax(mnist)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中使用sys模板和logging模块获取行号和函数名的方法
Apr 15 Python
pandas 小数位数 精度的处理方法
Jun 09 Python
详解python中的json和字典dict
Jun 22 Python
python实现一个简单的udp通信的示例代码
Feb 01 Python
python中aioysql(异步操作MySQL)的方法
Apr 11 Python
python Tkinter的图片刷新实例
Jun 14 Python
django框架模型层功能、组成与用法分析
Jul 30 Python
Python实现TCP通信的示例代码
Sep 09 Python
Python大数据之使用lxml库解析html网页文件示例
Nov 16 Python
基于python SMTP实现自动发送邮件教程解析
Jun 02 Python
Python并发请求下限制QPS(每秒查询率)的实现代码
Jun 05 Python
使用python向MongoDB插入时间字段的操作
May 18 Python
wxpython实现图书管理系统
Mar 12 #Python
人生苦短我用python python如何快速入门?
Mar 12 #Python
tensorflow实现KNN识别MNIST
Mar 12 #Python
Python操作MySQL模拟银行转账
Mar 12 #Python
python3 图片referer防盗链的实现方法
Mar 12 #Python
tensorflow构建BP神经网络的方法
Mar 12 #Python
Python管理Windows服务小脚本
Mar 12 #Python
You might like
smarty+adodb+部分自定义类的php开发模式
2006/12/31 PHP
使用PHP获取汉字的拼音(全部与首字母)
2013/06/27 PHP
php从字符串创建函数的方法
2015/03/16 PHP
PHP基于cookie与session统计网站访问量并输出显示的方法
2016/01/15 PHP
PHP实现163邮箱自动发送邮件
2016/03/29 PHP
php获取flash尺寸详细数据的方法
2016/11/12 PHP
Laravel 中使用 Vue.js 实现基于 Ajax 的表单提交错误验证操作
2017/06/30 PHP
php封装实现钉钉机器人报警接口的示例代码
2020/08/08 PHP
prototype.js的Ajax对象
2006/09/23 Javascript
jquery操作cookie插件分享
2014/01/14 Javascript
用Jquery.load载入页面实现局部刷新
2014/01/22 Javascript
JQuery为页面Dom元素绑定事件及解除绑定方法
2014/04/23 Javascript
详解vue跨组件通信的几种方法
2017/06/15 Javascript
对于js垃圾回收机制的理解
2017/09/14 Javascript
jQuery实现的表格前端排序功能示例
2017/09/18 jQuery
javascript中神奇的 Date对象小结
2017/10/12 Javascript
JS中判断某个字符串是否包含另一个字符串的五种方法
2018/05/03 Javascript
基于vue+echarts 数据可视化大屏展示的方法示例
2020/03/09 Javascript
全面解析Vue中的$nextTick
2020/12/24 Vue.js
[02:14]完美“圣”典2016风云人物:xiao8专访
2016/12/01 DOTA
[49:43]VG vs FNATIC 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/17 DOTA
pip安装时ReadTimeoutError的解决方法
2018/06/12 Python
python安装twisted的问题解析
2018/08/21 Python
Python利用sqlacodegen自动生成ORM实体类示例
2019/06/04 Python
Python创建临时文件和文件夹
2020/08/05 Python
Zatchels官网:英国剑桥包品牌
2021/01/12 全球购物
String和StringBuffer的区别
2015/08/13 面试题
幼儿运动会邀请函
2014/01/17 职场文书
创业计划书的写作技巧及要点
2014/01/31 职场文书
弘扬雷锋精神演讲稿
2014/05/10 职场文书
优秀共青团员事迹材料
2014/12/25 职场文书
证婚人致辞精选
2015/07/28 职场文书
大学生安全教育心得体会
2016/01/15 职场文书
一篇文章学会Vue中间件管道
2021/06/20 Vue.js
vmware虚拟机打不开vmx文件怎么办 ?vmware虚拟机vmx文件打开方法
2022/04/08 数码科技
Springboot-cli 开发脚手架,权限认证,附demo演示
2022/04/28 Java/Android