tensorflow实现逻辑回归模型


Posted in Python onSeptember 08, 2018

逻辑回归模型

逻辑回归是应用非常广泛的一个分类机器学习算法,它将数据拟合到一个logit函数(或者叫做logistic函数)中,从而能够完成对事件发生的概率进行预测。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
#下载好的mnist数据集存在F:/mnist/data/中
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot = True)
print(mnist.train.num_examples)
print(mnist.test.num_examples)

trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels

print(type(trainimg))
print(trainimg.shape,)
print(trainlabel.shape,)
print(testimg.shape,)
print(testlabel.shape,)

nsample = 5
randidx = np.random.randint(trainimg.shape[0],size = nsample)

for i in randidx:
  curr_img = np.reshape(trainimg[i,:],(28,28))
  curr_label = np.argmax(trainlabel[i,:])
  plt.matshow(curr_img,cmap=plt.get_cmap('gray'))
  plt.title(""+str(i)+"th Training Data"+"label is"+str(curr_label))
  print(""+str(i)+"th Training Data"+"label is"+str(curr_label))
  plt.show()


x = tf.placeholder("float",[None,784])
y = tf.placeholder("float",[None,10])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#
actv = tf.nn.softmax(tf.matmul(x,W)+b)
#计算损失
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))
#学习率
learning_rate = 0.01
#随机梯度下降
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

#求1位置索引值 对比预测值索引与label索引是否一样,一样返回True
pred = tf.equal(tf.argmax(actv,1),tf.argmax(y,1))
#tf.cast把True和false转换为float类型 0,1
#把所有预测结果加在一起求精度
accr = tf.reduce_mean(tf.cast(pred,"float"))
init = tf.global_variables_initializer()
"""
#测试代码 
sess = tf.InteractiveSession()
arr = np.array([[31,23,4,24,27,34],[18,3,25,4,5,6],[4,3,2,1,5,67]])
#返回数组的维数 2
print(tf.rank(arr).eval())
#返回数组的行列数 [3 6]
print(tf.shape(arr).eval())
#返回数组中每一列中最大元素的索引[0 0 1 0 0 2]
print(tf.argmax(arr,0).eval())
#返回数组中每一行中最大元素的索引[5 2 5]
print(tf.argmax(arr,1).eval()) 
J"""
#把所有样本迭代50次
training_epochs = 50
#每次迭代选择多少样本
batch_size = 100
display_step = 5

sess = tf.Session()
sess.run(init)

#循环迭代
for epoch in range(training_epochs):
  avg_cost = 0
  num_batch = int(mnist.train.num_examples/batch_size)
  for i in range(num_batch):
    batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    sess.run(optm,feed_dict = {x:batch_xs,y:batch_ys})
    feeds = {x:batch_xs,y:batch_ys}
    avg_cost += sess.run(cost,feed_dict = feeds)/num_batch

  if epoch % display_step ==0:
    feeds_train = {x:batch_xs,y:batch_ys}
    feeds_test = {x:mnist.test.images,y:mnist.test.labels}
    train_acc = sess.run(accr,feed_dict = feeds_train)
    test_acc = sess.run(accr,feed_dict = feeds_test)
    #每五个epoch打印一次信息
    print("Epoch:%03d/%03d cost:%.9f train_acc:%.3f test_acc: %.3f" %(epoch,training_epochs,avg_cost,train_acc,test_acc))

print("Done")

程序训练结果如下:

Epoch:000/050 cost:1.177228655 train_acc:0.800 test_acc: 0.855
Epoch:005/050 cost:0.440933891 train_acc:0.890 test_acc: 0.894
Epoch:010/050 cost:0.383387268 train_acc:0.930 test_acc: 0.905
Epoch:015/050 cost:0.357281335 train_acc:0.930 test_acc: 0.909
Epoch:020/050 cost:0.341473956 train_acc:0.890 test_acc: 0.913
Epoch:025/050 cost:0.330586549 train_acc:0.920 test_acc: 0.915
Epoch:030/050 cost:0.322370980 train_acc:0.870 test_acc: 0.916
Epoch:035/050 cost:0.315942993 train_acc:0.940 test_acc: 0.916
Epoch:040/050 cost:0.310728854 train_acc:0.890 test_acc: 0.917
Epoch:045/050 cost:0.306357428 train_acc:0.870 test_acc: 0.918
Done

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用实例解释Python中的继承和多态的概念
Apr 27 Python
Python计算已经过去多少个周末的方法
Jul 25 Python
利用Python中unittest实现简单的单元测试实例详解
Jan 09 Python
Python创建xml文件示例
Mar 22 Python
Python实现的径向基(RBF)神经网络示例
Feb 06 Python
python-docx修改已存在的Word文档的表格的字体格式方法
May 08 Python
Django rest framework工具包简单用法示例
Jul 20 Python
30秒学会30个超实用Python代码片段【收藏版】
Oct 15 Python
使用Python paramiko模块利用多线程实现ssh并发执行操作
Dec 05 Python
matplotlib 曲线图 和 折线图 plt.plot()实例
Apr 17 Python
使用Python内置模块与函数进行不同进制的数的转换
Apr 26 Python
python3访问字典里的值实例方法
Nov 18 Python
Django实现表单验证
Sep 08 #Python
python实现排序算法解析
Sep 08 #Python
TensorFlow实现Logistic回归
Sep 07 #Python
tensorflow实现简单逻辑回归
Sep 07 #Python
Tensorflow使用支持向量机拟合线性回归
Sep 07 #Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
You might like
Zend Studio 实用快捷键一览表(精心整理)
2013/08/10 PHP
PHP封装XML和JSON格式数据接口操作示例
2019/03/06 PHP
laravel执行php artisan migrate报错的解决方法
2019/10/09 PHP
js innerHTML 的一些问题的解决方法
2008/06/22 Javascript
Jquery css函数用法(判断标签是否拥有某属性)
2011/05/28 Javascript
同时使用n个window onload加载实例介绍
2013/04/25 Javascript
JavaScript立即执行函数的三种不同写法
2014/09/05 Javascript
浏览器缩放检测的js代码
2014/09/28 Javascript
初始Nodejs
2014/11/08 NodeJs
PHP中使用微秒计算脚本执行时间例子
2014/11/19 Javascript
node.js中的fs.mkdirSync方法使用说明
2014/12/17 Javascript
详解ECharts使用心得总结
2016/12/06 Javascript
搭建简单的nodejs http服务器详解
2017/03/09 NodeJs
深入理解Vue-cli搭建项目后的目录结构探秘
2017/07/13 Javascript
vue封装swiper代码实例解析
2019/10/08 Javascript
[55:26]DOTA2-DPC中国联赛 正赛 Aster vs LBZS BO3 第一场 2月23日
2021/03/11 DOTA
Python实现的计算器功能示例
2018/04/26 Python
对pandas中两种数据类型Series和DataFrame的区别详解
2018/11/12 Python
python利用7z批量解压rar的实现
2019/08/07 Python
Django错误:TypeError at / 'bool' object is not callable解决
2019/08/16 Python
Python中格式化字符串的四种实现
2020/05/26 Python
英国手机壳购买网站:Case Hut
2019/04/11 全球购物
巴西化妆品商店:Lojas Rede
2019/07/26 全球购物
New delete 与malloc free 的联系与区别
2013/02/04 面试题
师范生的个人求职信范文
2014/01/04 职场文书
时尚休闲吧创业计划书
2014/01/25 职场文书
经销商年会策划方案
2014/05/29 职场文书
和谐社区口号
2014/06/19 职场文书
企业标语大全
2014/07/01 职场文书
公司户外活动总结
2014/07/04 职场文书
2014年学习委员工作总结
2014/11/14 职场文书
房屋维修申请报告
2015/05/18 职场文书
公安机关起诉意见书
2015/05/20 职场文书
2016关于军训的心得体会
2016/01/11 职场文书
Angular CLI发布路径的配置项浅析
2021/03/29 Javascript
Java代码规范与质量检测插件SonarLint的使用
2022/08/05 Java/Android