详解用TensorFlow实现逻辑回归算法


Posted in Python onMay 02, 2018

本文将实现逻辑回归算法,预测低出生体重的概率。

# Logistic Regression
# 逻辑回归
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve logistic regression.
# y = sigmoid(Ax + b)
#
# We will use the low birth weight data, specifically:
# y = 0 or 1 = low birth weight
# x = demographic and medical history data

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import requests
from tensorflow.python.framework import ops
import os.path
import csv


ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Obtain and prepare data for modeling
###

# name of data file
birth_weight_file = 'birth_weight.csv'

# download data and create data file if file does not exist in current directory
if not os.path.exists(birth_weight_file):
  birthdata_url = 'https://github.com/nfmcclure/tensorflow_cookbook/raw/master/01_Introduction/07_Working_with_Data_Sources/birthweight_data/birthweight.dat'
  birth_file = requests.get(birthdata_url)
  birth_data = birth_file.text.split('\r\n')
  birth_header = birth_data[0].split('\t')
  birth_data = [[float(x) for x in y.split('\t') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]
  with open(birth_weight_file, "w") as f:
    writer = csv.writer(f)
    writer.writerow(birth_header)
    writer.writerows(birth_data)
    f.close()

# read birth weight data into memory
birth_data = []
with open(birth_weight_file, newline='') as csvfile:
   csv_reader = csv.reader(csvfile)
   birth_header = next(csv_reader)
   for row in csv_reader:
     birth_data.append(row)

birth_data = [[float(x) for x in row] for row in birth_data]

# Pull out target variable
y_vals = np.array([x[0] for x in birth_data])
# Pull out predictor variables (not id, not target, and not birthweight)
x_vals = np.array([x[1:8] for x in birth_data])

# set for reproducible results
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)

# Split data into train/test = 80%/20%
# 分割数据集为测试集和训练集
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

# Normalize by column (min-max norm)
# 将所有特征缩放到0和1区间(min-max缩放),逻辑回归收敛的效果更好
# 归一化特征
def normalize_cols(m):
  col_max = m.max(axis=0)
  col_min = m.min(axis=0)
  return (m-col_min) / (col_max - col_min)

x_vals_train = np.nan_to_num(normalize_cols(x_vals_train))
x_vals_test = np.nan_to_num(normalize_cols(x_vals_test))

###
# Define Tensorflow computational graph¶
###

# Declare batch size
batch_size = 25

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 7], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[7,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function (Cross Entropy loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=model_output, labels=y_target))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

###
# Train model
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Actual Prediction
# 除记录损失函数外,也需要记录分类器在训练集和测试集上的准确度。
# 所以创建一个返回准确度的预测函数
prediction = tf.round(tf.sigmoid(model_output))
predictions_correct = tf.cast(tf.equal(prediction, y_target), tf.float32)
accuracy = tf.reduce_mean(predictions_correct)

# Training loop
# 开始遍历迭代训练,记录损失值和准确度
loss_vec = []
train_acc = []
test_acc = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = x_vals_train[rand_index]
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  temp_acc_train = sess.run(accuracy, feed_dict={x_data: x_vals_train, y_target: np.transpose([y_vals_train])})
  train_acc.append(temp_acc_train)
  temp_acc_test = sess.run(accuracy, feed_dict={x_data: x_vals_test, y_target: np.transpose([y_vals_test])})
  test_acc.append(temp_acc_test)
  if (i+1)%300==0:
    print('Loss = ' + str(temp_loss))


###
# Display model performance
###

# 绘制损失和准确度
plt.plot(loss_vec, 'k-')
plt.title('Cross Entropy Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Cross Entropy Loss')
plt.show()

# Plot train and test accuracy
plt.plot(train_acc, 'k-', label='Train Set Accuracy')
plt.plot(test_acc, 'r--', label='Test Set Accuracy')
plt.title('Train and Test Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

数据结果:

Loss = 0.845124
Loss = 0.658061
Loss = 0.471852
Loss = 0.643469
Loss = 0.672077

详解用TensorFlow实现逻辑回归算法

迭代1500次的交叉熵损失图

详解用TensorFlow实现逻辑回归算法

迭代1500次的测试集和训练集的准确度图

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
将Python中的数据存储到系统本地的简单方法
Apr 11 Python
python实现微信跳一跳辅助工具步骤详解
Jan 04 Python
Python中判断输入是否为数字的实现代码
May 26 Python
Python使用post及get方式提交数据的实例
Jan 24 Python
Python并发:多线程与多进程的详解
Jan 24 Python
python实现二维数组的对角线遍历
Mar 02 Python
python3中pip3安装出错,找不到SSL的解决方式
Dec 12 Python
python通过nmap扫描在线设备并尝试AAA登录(实例代码)
Dec 30 Python
Python callable内置函数原理解析
Mar 05 Python
django项目中新增app的2种实现方法
Apr 01 Python
基于python实现MQTT发布订阅过程原理解析
Jul 27 Python
python opencv肤色检测的实现示例
Dec 21 Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
Python3实现的简单验证码识别功能示例
May 02 #Python
利用Python在一个文件的头部插入数据的实例
May 02 #Python
You might like
PHP 图片水印类代码
2012/08/27 PHP
PHP递归复制、移动目录的自定义函数分享
2014/11/18 PHP
thinkphp模板继承实例简述
2014/11/26 PHP
PHP设计模式之工厂模式(Factory Pattern)的讲解
2019/03/21 PHP
微信公众平台开发教程⑥ 微信开发集成类的使用图文详解
2019/04/10 PHP
PHP SESSION跨页面传递失败解决方案
2020/12/11 PHP
javascript 添加和移除函数的通用方法
2009/10/20 Javascript
JS关键字球状旋转效果的实例代码
2013/11/29 Javascript
一个html5播放视频的video控件只支持android的默认格式mp4和3gp
2014/05/08 Javascript
JS获取及设置TextArea或input文本框选择文本位置的方法
2015/03/24 Javascript
JS如何实现文本框随文本的长度而增长
2015/07/30 Javascript
js带前后翻页的图片切换效果代码分享
2015/09/08 Javascript
JointJS JavaScript流程图绘制框架解析
2019/08/15 Javascript
a标签调用js的方法总结
2019/09/05 Javascript
js 判断当前时间是否处于某个一个时间段内
2019/09/19 Javascript
[01:28:44]DOTA2-DPC中国联赛定级赛 RNG vs iG BO3第一场 1月10日
2021/03/11 DOTA
python中找出numpy array数组的最值及其索引方法
2018/04/17 Python
python线程池threadpool实现篇
2018/04/27 Python
python 中不同包 类 方法 之间的调用详解
2020/03/09 Python
Python3实现建造者模式的示例代码
2020/06/28 Python
PyTorch的torch.cat用法
2020/06/28 Python
英国顶级家庭折扣店:The Works
2017/09/06 全球购物
服装设计专业毕业生推荐信
2013/11/09 职场文书
电子商务网站的创业计划书
2014/01/05 职场文书
劳动竞赛活动方案
2014/02/20 职场文书
八荣八耻的活动方案
2014/08/16 职场文书
小学生放飞梦想演讲稿
2014/08/26 职场文书
中学生运动会新闻稿
2014/09/24 职场文书
教师个人教学总结
2015/02/11 职场文书
保卫工作个人总结
2015/03/03 职场文书
2015年度质量工作总结报告
2015/04/27 职场文书
运动会加油稿30字
2015/07/21 职场文书
大学军训通讯稿(2016最新版)
2015/12/21 职场文书
小学生优秀作文范文(六篇)
2019/07/10 职场文书
Python办公自动化解决world文件批量转换
2021/09/15 Python