A Detailed Guide to Implementing Logistic Regression with TensorFlow


Posted in Python on May 02, 2018

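This article implements logistic regression to predict the probability of low birth weight. The code uses the TensorFlow 1.x graph API (tf.Session, tf.placeholder).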
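
To make the model concrete before the full script: logistic regression passes a linear combination of the features through the sigmoid function to obtain a probability between 0 and 1. The NumPy sketch below illustrates only that forward pass. The feature values, the weights `A`, and the bias `b` are made-up numbers for illustration, not values from the birth weight data.

```python
import numpy as np

def sigmoid(z):
    # Map any real number to the open interval (0, 1)
    return 1.0 / (1.0 + np.exp(-z))

def predict_proba(x, A, b):
    # x: (n_samples, n_features), A: (n_features, 1), b: scalar bias
    return sigmoid(x @ A + b)

# Hypothetical inputs: two samples with two features each
x = np.array([[0.0, 0.0],
              [2.0, 0.0]])
A = np.array([[1.0],
              [-0.5]])
b = 0.0

probs = predict_proba(x, A, b)
print(probs)
```

A probability above 0.5 would be classified as 1 (low birth weight) and anything below as 0, which is what rounding the sigmoid output does in the script.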

# Logistic Regression
#----------------------------------
#
# This script shows how to use TensorFlow to
# solve logistic regression.
# y = sigmoid(Ax + b)
#
# We will use the low birth weight data, specifically:
# y = 0 or 1 = low birth weight
# x = demographic and medical history data

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import requests
from tensorflow.python.framework import ops
import os.path
import csv


ops.reset_default_graph()

# Create a session to run the graph
sess = tf.Session()

###
# Obtain and prepare data for modeling
###

# name of data file
birth_weight_file = 'birth_weight.csv'

# download data and create data file if file does not exist in current directory
if not os.path.exists(birth_weight_file):
  birthdata_url = 'https://github.com/nfmcclure/tensorflow_cookbook/raw/master/01_Introduction/07_Working_with_Data_Sources/birthweight_data/birthweight.dat'
  birth_file = requests.get(birthdata_url)
  birth_data = birth_file.text.splitlines()
  birth_header = birth_data[0].split('\t')
  birth_data = [[float(x) for x in y.split('\t') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]
  # newline='' prevents the csv module from writing blank rows on Windows
  with open(birth_weight_file, "w", newline='') as f:
    writer = csv.writer(f)
    writer.writerow(birth_header)
    writer.writerows(birth_data)

# read birth weight data into memory
birth_data = []
with open(birth_weight_file, newline='') as csvfile:
  csv_reader = csv.reader(csvfile)
  birth_header = next(csv_reader)
  for row in csv_reader:
    birth_data.append(row)

birth_data = [[float(x) for x in row] for row in birth_data]

# Pull out target variable
y_vals = np.array([x[0] for x in birth_data])
# Pull out predictor variables (not id, not target, and not birthweight)
x_vals = np.array([x[1:8] for x in birth_data])

# set for reproducible results
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)

# Split data into training and test sets (80%/20%)
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

# Normalize by column (min-max norm)
# Scaling all features to the [0, 1] range (min-max scaling) helps logistic regression converge
def normalize_cols(m):
  col_max = m.max(axis=0)
  col_min = m.min(axis=0)
  return (m-col_min) / (col_max - col_min)

x_vals_train = np.nan_to_num(normalize_cols(x_vals_train))
x_vals_test = np.nan_to_num(normalize_cols(x_vals_test))

###
# Define TensorFlow computational graph
###

# Declare batch size
batch_size = 25

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 7], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for logistic regression
A = tf.Variable(tf.random_normal(shape=[7,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function (Cross Entropy loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=model_output, labels=y_target))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

###
# Train model
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Actual Prediction
# Besides the loss, we also track the classifier's accuracy on the training and test sets,
# so we define prediction and accuracy ops
prediction = tf.round(tf.sigmoid(model_output))
predictions_correct = tf.cast(tf.equal(prediction, y_target), tf.float32)
accuracy = tf.reduce_mean(predictions_correct)

# Training loop
# Iterate over training steps, recording the loss and accuracy at each step
loss_vec = []
train_acc = []
test_acc = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = x_vals_train[rand_index]
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  temp_acc_train = sess.run(accuracy, feed_dict={x_data: x_vals_train, y_target: np.transpose([y_vals_train])})
  train_acc.append(temp_acc_train)
  temp_acc_test = sess.run(accuracy, feed_dict={x_data: x_vals_test, y_target: np.transpose([y_vals_test])})
  test_acc.append(temp_acc_test)
  if (i+1)%300==0:
    print('Loss = ' + str(temp_loss))


###
# Display model performance
###

# Plot the loss and the accuracy
plt.plot(loss_vec, 'k-')
plt.title('Cross Entropy Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Cross Entropy Loss')
plt.show()

# Plot train and test accuracy
plt.plot(train_acc, 'k-', label='Train Set Accuracy')
plt.plot(test_acc, 'r--', label='Test Set Accuracy')
plt.title('Train and Test Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()
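
One detail in the script above is worth noting: `normalize_cols` is applied to the training and test sets separately, so each set is scaled by its own column minimum and maximum. A common alternative is to compute the minimum and maximum on the training set only and reuse them for the test set. The sketch below shows that variant on toy arrays. `fit_min_max` and `apply_min_max` are hypothetical helper names, not part of the original script.

```python
import numpy as np

def fit_min_max(train):
    # Column-wise statistics taken from the training data only
    return train.min(axis=0), train.max(axis=0)

def apply_min_max(m, col_min, col_max):
    span = col_max - col_min
    # Constant columns have zero span; divide by 1 there so they map to 0
    safe_span = np.where(span == 0, 1.0, span)
    return (m - col_min) / safe_span

# Toy data (made up): the third column is constant
train = np.array([[1.0, 10.0, 5.0],
                  [3.0, 20.0, 5.0],
                  [5.0, 30.0, 5.0]])
test = np.array([[2.0, 40.0, 5.0]])

col_min, col_max = fit_min_max(train)
train_scaled = apply_min_max(train, col_min, col_max)
test_scaled = apply_min_max(test, col_min, col_max)
print(train_scaled)
print(test_scaled)
```

With this approach a test value can fall outside [0, 1] when it lies outside the range seen in training, which is expected and generally harmless for logistic regression.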

Output of the training script:

Loss = 0.845124
Loss = 0.658061
Loss = 0.471852
Loss = 0.643469
Loss = 0.672077
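
Each printed value is the mean sigmoid cross-entropy on the current random batch of 25 samples, which is why the numbers do not decrease monotonically. As a rough illustration of what `tf.nn.sigmoid_cross_entropy_with_logits` computes, here is a NumPy sketch using the numerically stable form max(x, 0) - x*z + log(1 + exp(-|x|)) for logits x and labels z. The function name `stable_sigmoid_cross_entropy` and the sample logits are assumptions made for this example.

```python
import numpy as np

def stable_sigmoid_cross_entropy(logits, labels):
    # Per-example loss; equivalent to -z*log(sigmoid(x)) - (1-z)*log(1-sigmoid(x))
    # but written so that exp() never receives a large positive argument
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels, dtype=float)
    return (np.maximum(logits, 0.0)
            - logits * labels
            + np.log1p(np.exp(-np.abs(logits))))

# Made-up logits (the model output before the sigmoid) and 0/1 labels
logits = np.array([[0.0], [2.0], [-3.0]])
labels = np.array([[1.0], [1.0], [0.0]])

per_example = stable_sigmoid_cross_entropy(logits, labels)
mean_loss = per_example.mean()  # analogous to tf.reduce_mean over the batch
print(per_example.ravel(), mean_loss)
```

Passing logits rather than probabilities is also why `model_output` in the script has no sigmoid applied before it reaches the loss.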

Figure: cross-entropy loss over 1500 iterations
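
Because the loss is recorded on a different random batch at every step, the curve tends to be noisy. If a smoother view is wanted, a simple moving average over `loss_vec` is one option. The sketch below is an assumption about how one might do this, not part of the original script. `moving_average` is a hypothetical helper, and for brevity it is applied to the five printed loss values rather than to all 1500 recorded ones.

```python
import numpy as np

def moving_average(values, window):
    # Average each run of `window` consecutive values;
    # the result has len(values) - window + 1 entries
    values = np.asarray(values, dtype=float)
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode='valid')

# The five losses printed by the script, used here only as sample input
batch_losses = [0.845124, 0.658061, 0.471852, 0.643469, 0.672077]
smoothed = moving_average(batch_losses, window=3)
print(smoothed)
```

In the script itself, one could plot `moving_average(loss_vec, 50)` alongside the raw curve; the window size of 50 is an arbitrary choice.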

Figure: training and test set accuracy over 1500 iterations
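
The accuracy curves come from rounding the sigmoid output to 0 or 1 and comparing the result with the labels. A minimal NumPy equivalent of that computation, with made-up logits and labels, might look like the following. `accuracy_from_logits` is a hypothetical name used only for this example.

```python
import numpy as np

def accuracy_from_logits(logits, labels):
    # Convert logits to probabilities, round to a 0/1 prediction,
    # then take the fraction of predictions that match the labels
    probs = 1.0 / (1.0 + np.exp(-logits))
    preds = np.round(probs)
    return float(np.mean(preds == labels))

# Made-up model outputs and true labels for four samples
logits = np.array([[2.0], [-1.0], [0.5], [-3.0]])
labels = np.array([[1.0], [0.0], [0.0], [0.0]])

acc = accuracy_from_logits(logits, labels)
print(acc)
```

Rounding the probability at 0.5 is the same as checking whether the logit is positive, so the decision boundary is the set of points where the linear model output equals zero.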

That is all for this article. We hope it helps with your studies, and we hope you will continue to support 三水点靠木.
