tensorflow实现逻辑回归模型


Posted in Python onSeptember 08, 2018

逻辑回归模型

逻辑回归是应用非常广泛的一个分类机器学习算法,它将数据拟合到一个logit函数(或者叫做logistic函数)中,从而能够完成对事件发生的概率进行预测。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
#下载好的mnist数据集存在F:/mnist/data/中
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot = True)
print(mnist.train.num_examples)
print(mnist.test.num_examples)

trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels

print(type(trainimg))
print(trainimg.shape,)
print(trainlabel.shape,)
print(testimg.shape,)
print(testlabel.shape,)

nsample = 5
randidx = np.random.randint(trainimg.shape[0],size = nsample)

for i in randidx:
  curr_img = np.reshape(trainimg[i,:],(28,28))
  curr_label = np.argmax(trainlabel[i,:])
  plt.matshow(curr_img,cmap=plt.get_cmap('gray'))
  plt.title(""+str(i)+"th Training Data"+"label is"+str(curr_label))
  print(""+str(i)+"th Training Data"+"label is"+str(curr_label))
  plt.show()


x = tf.placeholder("float",[None,784])
y = tf.placeholder("float",[None,10])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#
actv = tf.nn.softmax(tf.matmul(x,W)+b)
#计算损失
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))
#学习率
learning_rate = 0.01
#随机梯度下降
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

#求1位置索引值 对比预测值索引与label索引是否一样,一样返回True
pred = tf.equal(tf.argmax(actv,1),tf.argmax(y,1))
#tf.cast把True和false转换为float类型 0,1
#把所有预测结果加在一起求精度
accr = tf.reduce_mean(tf.cast(pred,"float"))
init = tf.global_variables_initializer()
"""
#测试代码 
sess = tf.InteractiveSession()
arr = np.array([[31,23,4,24,27,34],[18,3,25,4,5,6],[4,3,2,1,5,67]])
#返回数组的维数 2
print(tf.rank(arr).eval())
#返回数组的行列数 [3 6]
print(tf.shape(arr).eval())
#返回数组中每一列中最大元素的索引[0 0 1 0 0 2]
print(tf.argmax(arr,0).eval())
#返回数组中每一行中最大元素的索引[5 2 5]
print(tf.argmax(arr,1).eval()) 
J"""
#把所有样本迭代50次
training_epochs = 50
#每次迭代选择多少样本
batch_size = 100
display_step = 5

sess = tf.Session()
sess.run(init)

#循环迭代
for epoch in range(training_epochs):
  avg_cost = 0
  num_batch = int(mnist.train.num_examples/batch_size)
  for i in range(num_batch):
    batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    sess.run(optm,feed_dict = {x:batch_xs,y:batch_ys})
    feeds = {x:batch_xs,y:batch_ys}
    avg_cost += sess.run(cost,feed_dict = feeds)/num_batch

  if epoch % display_step ==0:
    feeds_train = {x:batch_xs,y:batch_ys}
    feeds_test = {x:mnist.test.images,y:mnist.test.labels}
    train_acc = sess.run(accr,feed_dict = feeds_train)
    test_acc = sess.run(accr,feed_dict = feeds_test)
    #每五个epoch打印一次信息
    print("Epoch:%03d/%03d cost:%.9f train_acc:%.3f test_acc: %.3f" %(epoch,training_epochs,avg_cost,train_acc,test_acc))

print("Done")

程序训练结果如下:

Epoch:000/050 cost:1.177228655 train_acc:0.800 test_acc: 0.855
Epoch:005/050 cost:0.440933891 train_acc:0.890 test_acc: 0.894
Epoch:010/050 cost:0.383387268 train_acc:0.930 test_acc: 0.905
Epoch:015/050 cost:0.357281335 train_acc:0.930 test_acc: 0.909
Epoch:020/050 cost:0.341473956 train_acc:0.890 test_acc: 0.913
Epoch:025/050 cost:0.330586549 train_acc:0.920 test_acc: 0.915
Epoch:030/050 cost:0.322370980 train_acc:0.870 test_acc: 0.916
Epoch:035/050 cost:0.315942993 train_acc:0.940 test_acc: 0.916
Epoch:040/050 cost:0.310728854 train_acc:0.890 test_acc: 0.917
Epoch:045/050 cost:0.306357428 train_acc:0.870 test_acc: 0.918
Done

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python strip()函数 介绍
May 24 Python
天翼开放平台免费短信验证码接口使用实例
Dec 18 Python
python通过apply使用元祖和列表调用函数实例
May 26 Python
关于Python中Inf与Nan的判断问题详解
Feb 08 Python
Python调用系统底层API播放wav文件的方法
Aug 11 Python
微信跳一跳python自动代码解读1.0
Jan 12 Python
Python3之文件读写操作的实例讲解
Jan 23 Python
Pandas之排序函数sort_values()的实现
Jul 09 Python
Python中拆分字符串的操作方法
Jul 23 Python
python使用 __init__初始化操作简单示例
Sep 26 Python
Python 依赖库太多了该如何管理
Nov 08 Python
python 解决flask 图片在线浏览或者直接下载的问题
Jan 09 Python
Django实现表单验证
Sep 08 #Python
python实现排序算法解析
Sep 08 #Python
TensorFlow实现Logistic回归
Sep 07 #Python
tensorflow实现简单逻辑回归
Sep 07 #Python
Tensorflow使用支持向量机拟合线性回归
Sep 07 #Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
You might like
php一句话cmdshell新型 (非一句话木马)
2009/04/18 PHP
在PHP中检查PHP文件是否有语法错误的方法
2009/12/23 PHP
PHP中串行化用法示例
2016/11/16 PHP
PHP常用函数总结(180多个)
2016/12/25 PHP
laravel中短信发送验证码的实现方法
2018/04/25 PHP
javascript与asp.net(c#)互相调用方法
2009/12/13 Javascript
基于Jquery的温度计动画效果
2010/06/18 Javascript
Javascript匿名函数的一种应用 代码封装
2010/06/27 Javascript
Jquery实现视频播放页面的关灯开灯效果
2013/05/27 Javascript
js调用图片隐藏&显示实现代码
2013/09/13 Javascript
JS 仿腾讯发表微博的效果代码
2013/12/25 Javascript
将查询条件的input、select清空
2014/01/14 Javascript
JavaScript Sort 的一个错误用法示例
2015/03/20 Javascript
开启Javascript中apply、call、bind的用法之旅模式
2015/10/28 Javascript
JavaScript前端开发之实现二进制读写操作
2015/11/04 Javascript
Bootstrap3学习笔记(三)之表格
2016/05/20 Javascript
jQuery插件HighCharts绘制简单2D柱状图效果示例【附demo源码】
2017/03/21 jQuery
基于jQuery实现瀑布流页面
2017/04/11 jQuery
Node.js进阶之核心模块https入门
2018/05/23 Javascript
Mint UI组件库CheckList使用及踩坑总结
2018/12/20 Javascript
vue中使用 pako.js 解密 gzip加密字符串的方法
2019/06/10 Javascript
nodejs简单抓包工具使用详解
2019/08/23 NodeJs
[05:49]DOTA2-DPC中国联赛 正赛 Elephant vs LBZS 选手采访
2021/03/11 DOTA
利用Python开发实现简单的记事本
2016/11/15 Python
python时间日期函数与利用pandas进行时间序列处理详解
2018/03/13 Python
python文件选择对话框的操作方法
2019/06/27 Python
python实现while循环打印星星的四种形状
2019/11/23 Python
python 实现提取log文件中的关键句子,并进行统计分析
2019/12/24 Python
Python利用matplotlib绘制散点图的新手教程
2020/11/05 Python
Holland & Barrett爱尔兰:英国领先的健康零售商
2019/03/31 全球购物
介绍一下游标
2012/01/10 面试题
2014年作风建设心得体会
2014/10/22 职场文书
行风评议整改报告
2014/11/06 职场文书
入党积极分子个人总结
2015/03/02 职场文书
SSM项目使用拦截器实现登录验证功能
2022/01/22 Java/Android
日元符号 ¥
2022/02/17 杂记