tensorflow实现逻辑回归模型


Posted in Python onSeptember 08, 2018

逻辑回归模型

逻辑回归是应用非常广泛的一个分类机器学习算法,它将数据拟合到一个logit函数(或者叫做logistic函数)中,从而能够完成对事件发生的概率进行预测。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
#下载好的mnist数据集存在F:/mnist/data/中
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot = True)
print(mnist.train.num_examples)
print(mnist.test.num_examples)

trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels

print(type(trainimg))
print(trainimg.shape,)
print(trainlabel.shape,)
print(testimg.shape,)
print(testlabel.shape,)

nsample = 5
randidx = np.random.randint(trainimg.shape[0],size = nsample)

for i in randidx:
  curr_img = np.reshape(trainimg[i,:],(28,28))
  curr_label = np.argmax(trainlabel[i,:])
  plt.matshow(curr_img,cmap=plt.get_cmap('gray'))
  plt.title(""+str(i)+"th Training Data"+"label is"+str(curr_label))
  print(""+str(i)+"th Training Data"+"label is"+str(curr_label))
  plt.show()


x = tf.placeholder("float",[None,784])
y = tf.placeholder("float",[None,10])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#
actv = tf.nn.softmax(tf.matmul(x,W)+b)
#计算损失
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))
#学习率
learning_rate = 0.01
#随机梯度下降
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

#求1位置索引值 对比预测值索引与label索引是否一样,一样返回True
pred = tf.equal(tf.argmax(actv,1),tf.argmax(y,1))
#tf.cast把True和false转换为float类型 0,1
#把所有预测结果加在一起求精度
accr = tf.reduce_mean(tf.cast(pred,"float"))
init = tf.global_variables_initializer()
"""
#测试代码 
sess = tf.InteractiveSession()
arr = np.array([[31,23,4,24,27,34],[18,3,25,4,5,6],[4,3,2,1,5,67]])
#返回数组的维数 2
print(tf.rank(arr).eval())
#返回数组的行列数 [3 6]
print(tf.shape(arr).eval())
#返回数组中每一列中最大元素的索引[0 0 1 0 0 2]
print(tf.argmax(arr,0).eval())
#返回数组中每一行中最大元素的索引[5 2 5]
print(tf.argmax(arr,1).eval()) 
J"""
#把所有样本迭代50次
training_epochs = 50
#每次迭代选择多少样本
batch_size = 100
display_step = 5

sess = tf.Session()
sess.run(init)

#循环迭代
for epoch in range(training_epochs):
  avg_cost = 0
  num_batch = int(mnist.train.num_examples/batch_size)
  for i in range(num_batch):
    batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    sess.run(optm,feed_dict = {x:batch_xs,y:batch_ys})
    feeds = {x:batch_xs,y:batch_ys}
    avg_cost += sess.run(cost,feed_dict = feeds)/num_batch

  if epoch % display_step ==0:
    feeds_train = {x:batch_xs,y:batch_ys}
    feeds_test = {x:mnist.test.images,y:mnist.test.labels}
    train_acc = sess.run(accr,feed_dict = feeds_train)
    test_acc = sess.run(accr,feed_dict = feeds_test)
    #每五个epoch打印一次信息
    print("Epoch:%03d/%03d cost:%.9f train_acc:%.3f test_acc: %.3f" %(epoch,training_epochs,avg_cost,train_acc,test_acc))

print("Done")

程序训练结果如下:

Epoch:000/050 cost:1.177228655 train_acc:0.800 test_acc: 0.855
Epoch:005/050 cost:0.440933891 train_acc:0.890 test_acc: 0.894
Epoch:010/050 cost:0.383387268 train_acc:0.930 test_acc: 0.905
Epoch:015/050 cost:0.357281335 train_acc:0.930 test_acc: 0.909
Epoch:020/050 cost:0.341473956 train_acc:0.890 test_acc: 0.913
Epoch:025/050 cost:0.330586549 train_acc:0.920 test_acc: 0.915
Epoch:030/050 cost:0.322370980 train_acc:0.870 test_acc: 0.916
Epoch:035/050 cost:0.315942993 train_acc:0.940 test_acc: 0.916
Epoch:040/050 cost:0.310728854 train_acc:0.890 test_acc: 0.917
Epoch:045/050 cost:0.306357428 train_acc:0.870 test_acc: 0.918
Done

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 切片和range()用法说明
Mar 24 Python
Python使用py2exe打包程序介绍
Nov 20 Python
Python中Random和Math模块学习笔记
May 18 Python
Python爬虫辅助利器PyQuery模块的安装使用攻略
Apr 24 Python
python验证码识别的实例详解
Sep 09 Python
Python基于Pymssql模块实现连接SQL Server数据库的方法详解
Jul 20 Python
用python实现的线程池实例代码
Jan 06 Python
python 移除字符串尾部的数字方法
Jul 17 Python
Python3实现计算两个数组的交集算法示例
Apr 03 Python
Python 函数list&read&seek详解
Aug 28 Python
Python3+Selenium+Chrome实现自动填写WPS表单
Feb 12 Python
Python 实现劳拉游戏的实例代码(四连环、重力四子棋)
Mar 03 Python
Django实现表单验证
Sep 08 #Python
python实现排序算法解析
Sep 08 #Python
TensorFlow实现Logistic回归
Sep 07 #Python
tensorflow实现简单逻辑回归
Sep 07 #Python
Tensorflow使用支持向量机拟合线性回归
Sep 07 #Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
You might like
令PHP初学者头疼十四条问题大总结
2008/11/12 PHP
PHP最常用的ini函数分析 针对PHP.ini配置文件
2010/04/22 PHP
PHP人民币金额数字转中文大写的函数代码
2013/02/27 PHP
如何利用PHP执行.SQL文件
2013/07/05 PHP
PHP文件及文件夹操作之创建、删除、移动、复制
2016/07/13 PHP
laravel框架之数据库查出来的对象实现转化为数组
2019/10/23 PHP
IE6,IE7下js动态加载图片不显示错误
2010/07/17 Javascript
js string 转 int 注意的问题小结
2013/08/15 Javascript
JS测试显示屏分辨率以及屏幕尺寸的方法
2013/11/22 Javascript
js闭包的用途详解
2014/11/09 Javascript
小巧强大的jquery layer弹窗弹层插件
2015/12/06 Javascript
Node.js利用js-xlsx处理Excel文件的方法详解
2017/07/05 Javascript
Three.js利用性能插件stats实现性能监听的方法
2017/09/25 Javascript
Vue 框架之键盘事件、健值修饰符、双向数据绑定
2018/11/14 Javascript
vue项目配置 webpack-obfuscator 进行代码加密混淆的实现
2021/02/26 Vue.js
[03:37]2016完美“圣”典 风云人物:Mikasa专访
2016/12/07 DOTA
Windows下为Python安装Matplotlib模块
2015/11/06 Python
Python数据结构与算法之使用队列解决小猫钓鱼问题
2017/12/14 Python
Python使用Scrapy爬虫框架全站爬取图片并保存本地的实现代码
2018/03/04 Python
pygame游戏之旅 计算游戏中躲过的障碍数量
2018/11/20 Python
对Python 语音识别框架详解
2018/12/24 Python
linux查找当前python解释器的位置方法
2019/02/20 Python
Python 调用 Outlook 发送邮件过程解析
2019/08/08 Python
python实现ftp文件传输功能
2020/03/20 Python
浅谈keras中Dropout在预测过程中是否仍要起作用
2020/07/09 Python
详解HTML5中的元素与元素
2015/08/17 HTML / CSS
MONNIER Frères英国官网:源自巴黎女士奢侈品配饰电商平台
2018/12/06 全球购物
为什么要使用servlet
2016/01/17 面试题
网页设计个人找工作求职信
2013/11/28 职场文书
小学安全教育材料
2014/02/17 职场文书
居委会个人对照检查材料思想汇报
2014/09/29 职场文书
二手车交易协议书标准版
2014/11/16 职场文书
优秀少先队辅导员事迹材料
2014/12/24 职场文书
中学生社区服务活动报告
2015/02/05 职场文书
2015年安全月活动总结
2015/03/26 职场文书
为什么 Nginx 比 Apache 更牛逼
2021/03/31 Servers