tensorflow实现逻辑回归模型


Posted in Python onSeptember 08, 2018

逻辑回归模型

逻辑回归是应用非常广泛的一个分类机器学习算法,它将数据拟合到一个logit函数(或者叫做logistic函数)中,从而能够完成对事件发生的概率进行预测。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
#下载好的mnist数据集存在F:/mnist/data/中
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot = True)
print(mnist.train.num_examples)
print(mnist.test.num_examples)

trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels

print(type(trainimg))
print(trainimg.shape,)
print(trainlabel.shape,)
print(testimg.shape,)
print(testlabel.shape,)

nsample = 5
randidx = np.random.randint(trainimg.shape[0],size = nsample)

for i in randidx:
  curr_img = np.reshape(trainimg[i,:],(28,28))
  curr_label = np.argmax(trainlabel[i,:])
  plt.matshow(curr_img,cmap=plt.get_cmap('gray'))
  plt.title(""+str(i)+"th Training Data"+"label is"+str(curr_label))
  print(""+str(i)+"th Training Data"+"label is"+str(curr_label))
  plt.show()


x = tf.placeholder("float",[None,784])
y = tf.placeholder("float",[None,10])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#
actv = tf.nn.softmax(tf.matmul(x,W)+b)
#计算损失
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))
#学习率
learning_rate = 0.01
#随机梯度下降
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

#求1位置索引值 对比预测值索引与label索引是否一样,一样返回True
pred = tf.equal(tf.argmax(actv,1),tf.argmax(y,1))
#tf.cast把True和false转换为float类型 0,1
#把所有预测结果加在一起求精度
accr = tf.reduce_mean(tf.cast(pred,"float"))
init = tf.global_variables_initializer()
"""
#测试代码 
sess = tf.InteractiveSession()
arr = np.array([[31,23,4,24,27,34],[18,3,25,4,5,6],[4,3,2,1,5,67]])
#返回数组的维数 2
print(tf.rank(arr).eval())
#返回数组的行列数 [3 6]
print(tf.shape(arr).eval())
#返回数组中每一列中最大元素的索引[0 0 1 0 0 2]
print(tf.argmax(arr,0).eval())
#返回数组中每一行中最大元素的索引[5 2 5]
print(tf.argmax(arr,1).eval()) 
J"""
#把所有样本迭代50次
training_epochs = 50
#每次迭代选择多少样本
batch_size = 100
display_step = 5

sess = tf.Session()
sess.run(init)

#循环迭代
for epoch in range(training_epochs):
  avg_cost = 0
  num_batch = int(mnist.train.num_examples/batch_size)
  for i in range(num_batch):
    batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    sess.run(optm,feed_dict = {x:batch_xs,y:batch_ys})
    feeds = {x:batch_xs,y:batch_ys}
    avg_cost += sess.run(cost,feed_dict = feeds)/num_batch

  if epoch % display_step ==0:
    feeds_train = {x:batch_xs,y:batch_ys}
    feeds_test = {x:mnist.test.images,y:mnist.test.labels}
    train_acc = sess.run(accr,feed_dict = feeds_train)
    test_acc = sess.run(accr,feed_dict = feeds_test)
    #每五个epoch打印一次信息
    print("Epoch:%03d/%03d cost:%.9f train_acc:%.3f test_acc: %.3f" %(epoch,training_epochs,avg_cost,train_acc,test_acc))

print("Done")

程序训练结果如下:

Epoch:000/050 cost:1.177228655 train_acc:0.800 test_acc: 0.855
Epoch:005/050 cost:0.440933891 train_acc:0.890 test_acc: 0.894
Epoch:010/050 cost:0.383387268 train_acc:0.930 test_acc: 0.905
Epoch:015/050 cost:0.357281335 train_acc:0.930 test_acc: 0.909
Epoch:020/050 cost:0.341473956 train_acc:0.890 test_acc: 0.913
Epoch:025/050 cost:0.330586549 train_acc:0.920 test_acc: 0.915
Epoch:030/050 cost:0.322370980 train_acc:0.870 test_acc: 0.916
Epoch:035/050 cost:0.315942993 train_acc:0.940 test_acc: 0.916
Epoch:040/050 cost:0.310728854 train_acc:0.890 test_acc: 0.917
Epoch:045/050 cost:0.306357428 train_acc:0.870 test_acc: 0.918
Done

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python合并多个装饰器小技巧
Apr 28 Python
PyCharm使用教程之搭建Python开发环境
Jun 07 Python
python 添加用户设置密码并发邮件给root用户
Jul 25 Python
python扫描proxy并获取可用代理ip的实例
Aug 07 Python
python中模块的__all__属性详解
Oct 26 Python
python 使用 requests 模块发送http请求 的方法
Dec 09 Python
对Python中画图时候的线类型详解
Jul 07 Python
Python3enumrate和range对比及示例详解
Jul 13 Python
Django MEDIA的配置及用法详解
Jul 25 Python
PyQt5基本控件使用之消息弹出、用户输入、文件对话框的使用方法
Aug 06 Python
python中sys模块是做什么用的
Aug 16 Python
python3从网络摄像机解析mjpeg http流的示例
Nov 13 Python
Django实现表单验证
Sep 08 #Python
python实现排序算法解析
Sep 08 #Python
TensorFlow实现Logistic回归
Sep 07 #Python
tensorflow实现简单逻辑回归
Sep 07 #Python
Tensorflow使用支持向量机拟合线性回归
Sep 07 #Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
You might like
php pki加密技术(openssl)详解
2013/07/01 PHP
php获取远程文件内容的函数
2015/11/02 PHP
PHP数组实际占用内存大小原理解析
2020/12/11 PHP
BOOM vs RR BO5 第一场 2.14
2021/03/10 DOTA
javascript prototype,executing,context,closure
2008/12/24 Javascript
jquery实现盒子下拉效果示例代码
2013/09/12 Javascript
JS打字效果的动态菜单代码分享
2015/08/21 Javascript
详解axios在node.js中的post使用
2017/04/27 Javascript
Angular2整合其他插件的方法
2018/01/20 Javascript
angular4强制刷新视图的方法
2018/10/09 Javascript
微信小程序环境下将文件上传到OSS的方法步骤
2019/05/31 Javascript
layui的select联动实现代码
2019/09/28 Javascript
element-ui 文件上传修改文件名的方法示例
2019/11/05 Javascript
JavaScript实现雪花飘落效果
2020/12/27 Javascript
[54:57]DOTA2-DPC中国联赛定级赛 Aster vs DLG BO3第二场 1月8日
2021/03/11 DOTA
Python实现统计英文单词个数及字符串分割代码
2015/05/28 Python
对Python的Django框架中的项目进行单元测试的方法
2016/04/11 Python
对numpy中array和asarray的区别详解
2018/04/17 Python
python批量修改图片大小的方法
2018/07/24 Python
python实现对输入的密文加密
2019/03/20 Python
Falsk 与 Django 过滤器的使用与区别详解
2019/06/04 Python
Python利用神经网络解决非线性回归问题实例详解
2019/07/19 Python
Python测试Kafka集群(pykafka)实例
2019/12/23 Python
Python面向对象中类(class)的简单理解与用法分析
2020/02/21 Python
python str字符串转uuid实例
2020/03/03 Python
基于DOM+CSS3实现OrgChart组织结构图插件
2016/03/02 HTML / CSS
耐克美国官网:Nike.com
2016/08/01 全球购物
英国潮流网站:END.(全球免邮)
2017/01/16 全球购物
美国在线乐器和设备商店:Musician’s Friend
2018/07/06 全球购物
整改落实自查报告
2014/11/05 职场文书
财务审计整改报告
2014/11/06 职场文书
2014年宣传部个人工作总结
2014/12/06 职场文书
2014酒店客房部工作总结
2014/12/16 职场文书
地道战观后感
2015/06/04 职场文书
一看就懂的MySQL的聚簇索引及聚簇索引是如何长高的
2021/05/25 MySQL
浅谈Python3中datetime不同时区转换介绍与踩坑
2021/08/02 Python