tensorflow实现逻辑回归模型


Posted in Python onSeptember 08, 2018

逻辑回归模型

逻辑回归是应用非常广泛的一个分类机器学习算法,它将数据拟合到一个logit函数(或者叫做logistic函数)中,从而能够完成对事件发生的概率进行预测。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
#下载好的mnist数据集存在F:/mnist/data/中
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot = True)
print(mnist.train.num_examples)
print(mnist.test.num_examples)

trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels

print(type(trainimg))
print(trainimg.shape,)
print(trainlabel.shape,)
print(testimg.shape,)
print(testlabel.shape,)

nsample = 5
randidx = np.random.randint(trainimg.shape[0],size = nsample)

for i in randidx:
  curr_img = np.reshape(trainimg[i,:],(28,28))
  curr_label = np.argmax(trainlabel[i,:])
  plt.matshow(curr_img,cmap=plt.get_cmap('gray'))
  plt.title(""+str(i)+"th Training Data"+"label is"+str(curr_label))
  print(""+str(i)+"th Training Data"+"label is"+str(curr_label))
  plt.show()


x = tf.placeholder("float",[None,784])
y = tf.placeholder("float",[None,10])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#
actv = tf.nn.softmax(tf.matmul(x,W)+b)
#计算损失
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))
#学习率
learning_rate = 0.01
#随机梯度下降
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

#求1位置索引值 对比预测值索引与label索引是否一样,一样返回True
pred = tf.equal(tf.argmax(actv,1),tf.argmax(y,1))
#tf.cast把True和false转换为float类型 0,1
#把所有预测结果加在一起求精度
accr = tf.reduce_mean(tf.cast(pred,"float"))
init = tf.global_variables_initializer()
"""
#测试代码 
sess = tf.InteractiveSession()
arr = np.array([[31,23,4,24,27,34],[18,3,25,4,5,6],[4,3,2,1,5,67]])
#返回数组的维数 2
print(tf.rank(arr).eval())
#返回数组的行列数 [3 6]
print(tf.shape(arr).eval())
#返回数组中每一列中最大元素的索引[0 0 1 0 0 2]
print(tf.argmax(arr,0).eval())
#返回数组中每一行中最大元素的索引[5 2 5]
print(tf.argmax(arr,1).eval()) 
J"""
#把所有样本迭代50次
training_epochs = 50
#每次迭代选择多少样本
batch_size = 100
display_step = 5

sess = tf.Session()
sess.run(init)

#循环迭代
for epoch in range(training_epochs):
  avg_cost = 0
  num_batch = int(mnist.train.num_examples/batch_size)
  for i in range(num_batch):
    batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    sess.run(optm,feed_dict = {x:batch_xs,y:batch_ys})
    feeds = {x:batch_xs,y:batch_ys}
    avg_cost += sess.run(cost,feed_dict = feeds)/num_batch

  if epoch % display_step ==0:
    feeds_train = {x:batch_xs,y:batch_ys}
    feeds_test = {x:mnist.test.images,y:mnist.test.labels}
    train_acc = sess.run(accr,feed_dict = feeds_train)
    test_acc = sess.run(accr,feed_dict = feeds_test)
    #每五个epoch打印一次信息
    print("Epoch:%03d/%03d cost:%.9f train_acc:%.3f test_acc: %.3f" %(epoch,training_epochs,avg_cost,train_acc,test_acc))

print("Done")

程序训练结果如下:

Epoch:000/050 cost:1.177228655 train_acc:0.800 test_acc: 0.855
Epoch:005/050 cost:0.440933891 train_acc:0.890 test_acc: 0.894
Epoch:010/050 cost:0.383387268 train_acc:0.930 test_acc: 0.905
Epoch:015/050 cost:0.357281335 train_acc:0.930 test_acc: 0.909
Epoch:020/050 cost:0.341473956 train_acc:0.890 test_acc: 0.913
Epoch:025/050 cost:0.330586549 train_acc:0.920 test_acc: 0.915
Epoch:030/050 cost:0.322370980 train_acc:0.870 test_acc: 0.916
Epoch:035/050 cost:0.315942993 train_acc:0.940 test_acc: 0.916
Epoch:040/050 cost:0.310728854 train_acc:0.890 test_acc: 0.917
Epoch:045/050 cost:0.306357428 train_acc:0.870 test_acc: 0.918
Done

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python搭建虚拟环境的步骤详解
Sep 27 Python
Python实现的下载网页源码功能示例
Jun 13 Python
python+pyqt实现右下角弹出框
Oct 26 Python
Python各类图像库的图片读写方式总结(推荐)
Feb 23 Python
Python使用matplotlib绘制余弦的散点图示例
Mar 14 Python
python解决js文件utf-8编码乱码问题(推荐)
May 02 Python
十分钟利用Python制作属于你自己的个性logo
May 07 Python
在双python下设置python3为默认的方法
Oct 31 Python
使用Python向DataFrame中指定位置添加一列或多列的方法
Jan 29 Python
基于opencv实现简单画板功能
Aug 02 Python
Python 保存加载mat格式文件的示例代码
Aug 04 Python
python常量折叠基础知识点讲解
Feb 28 Python
Django实现表单验证
Sep 08 #Python
python实现排序算法解析
Sep 08 #Python
TensorFlow实现Logistic回归
Sep 07 #Python
tensorflow实现简单逻辑回归
Sep 07 #Python
Tensorflow使用支持向量机拟合线性回归
Sep 07 #Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
You might like
用PHP 4.2书写安全的脚本
2006/10/09 PHP
通过文字传递创建的图形按钮
2006/10/09 PHP
Google韩国首页图标动画效果
2007/08/26 Javascript
javascript html 静态页面传参数
2009/04/10 Javascript
jquery.ui.progressbar 中文文档
2009/11/26 Javascript
javascript面向对象编程(一) 实例代码
2010/06/25 Javascript
33个优秀的 jQuery 图片展示插件分享
2012/03/14 Javascript
javascript实现checkbox复选框实例代码
2016/01/10 Javascript
JavaScript 经典实例日常收集整理(常用经典)
2016/03/30 Javascript
两种js监听滚轮事件的实现方法
2016/05/13 Javascript
移动适配的几种方案(三种方案)
2016/11/25 Javascript
详解angularjs 学习之 scope作用域
2018/01/15 Javascript
JS限制输入框输入的实现代码
2018/07/02 Javascript
ExtJs使用自定义插件动态保存表头配置(隐藏或显示)
2018/09/25 Javascript
layui表格设计以及数据初始化详解
2019/10/26 Javascript
js实现带搜索功能的下拉框
2020/01/11 Javascript
vue实现计算器功能
2020/02/22 Javascript
Vue 按照创建时间和当前时间显示操作(刚刚,几小时前,几天前)
2020/09/10 Javascript
python结合API实现即时天气信息
2016/01/19 Python
Python的自动化部署模块Fabric的安装及使用指南
2016/01/19 Python
Python输出汉字字库及将文字转换为图片的方法
2016/06/04 Python
Python用sndhdr模块识别音频格式详解
2018/01/11 Python
Python利用splinter实现浏览器自动化操作方法
2018/05/11 Python
python 处理数字,把大于上限的数字置零实现方法
2019/01/28 Python
ubuntu上安装python的实例方法
2019/09/30 Python
Python Selenium 设置元素等待的三种方式
2020/03/18 Python
python 实现Harris角点检测算法
2020/12/11 Python
css3设置box-pack和box-align让div里面的元素垂直居中
2014/09/01 HTML / CSS
阿拉伯世界最大的电子商务网站:Souq沙特阿拉伯
2016/10/28 全球购物
用C或者C++语言实现SOCKET通信
2015/02/24 面试题
建筑设计学生的自我评价
2014/01/16 职场文书
学校教研活动总结
2014/07/02 职场文书
装饰公司活动策划方案
2014/08/23 职场文书
师德师风自查材料
2014/10/14 职场文书
毕业生班级鉴定评语
2015/01/04 职场文书
Pycharm 如何设置HTML文件自动补全代码或标签
2021/05/21 Python