详解用TensorFlow实现逻辑回归算法


Posted in Python onMay 02, 2018

本文将实现逻辑回归算法,预测低出生体重的概率。

# Logistic Regression
# 逻辑回归
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve logistic regression.
# y = sigmoid(Ax + b)
#
# We will use the low birth weight data, specifically:
# y = 0 or 1 = low birth weight
# x = demographic and medical history data

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import requests
from tensorflow.python.framework import ops
import os.path
import csv


ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Obtain and prepare data for modeling
###

# name of data file
birth_weight_file = 'birth_weight.csv'

# download data and create data file if file does not exist in current directory
if not os.path.exists(birth_weight_file):
  birthdata_url = 'https://github.com/nfmcclure/tensorflow_cookbook/raw/master/01_Introduction/07_Working_with_Data_Sources/birthweight_data/birthweight.dat'
  birth_file = requests.get(birthdata_url)
  birth_data = birth_file.text.split('\r\n')
  birth_header = birth_data[0].split('\t')
  birth_data = [[float(x) for x in y.split('\t') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]
  with open(birth_weight_file, "w") as f:
    writer = csv.writer(f)
    writer.writerow(birth_header)
    writer.writerows(birth_data)
    f.close()

# read birth weight data into memory
birth_data = []
with open(birth_weight_file, newline='') as csvfile:
   csv_reader = csv.reader(csvfile)
   birth_header = next(csv_reader)
   for row in csv_reader:
     birth_data.append(row)

birth_data = [[float(x) for x in row] for row in birth_data]

# Pull out target variable
y_vals = np.array([x[0] for x in birth_data])
# Pull out predictor variables (not id, not target, and not birthweight)
x_vals = np.array([x[1:8] for x in birth_data])

# set for reproducible results
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)

# Split data into train/test = 80%/20%
# 分割数据集为测试集和训练集
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

# Normalize by column (min-max norm)
# 将所有特征缩放到0和1区间(min-max缩放),逻辑回归收敛的效果更好
# 归一化特征
def normalize_cols(m):
  col_max = m.max(axis=0)
  col_min = m.min(axis=0)
  return (m-col_min) / (col_max - col_min)

x_vals_train = np.nan_to_num(normalize_cols(x_vals_train))
x_vals_test = np.nan_to_num(normalize_cols(x_vals_test))

###
# Define Tensorflow computational graph¶
###

# Declare batch size
batch_size = 25

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 7], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[7,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function (Cross Entropy loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=model_output, labels=y_target))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

###
# Train model
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Actual Prediction
# 除记录损失函数外,也需要记录分类器在训练集和测试集上的准确度。
# 所以创建一个返回准确度的预测函数
prediction = tf.round(tf.sigmoid(model_output))
predictions_correct = tf.cast(tf.equal(prediction, y_target), tf.float32)
accuracy = tf.reduce_mean(predictions_correct)

# Training loop
# 开始遍历迭代训练,记录损失值和准确度
loss_vec = []
train_acc = []
test_acc = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = x_vals_train[rand_index]
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  temp_acc_train = sess.run(accuracy, feed_dict={x_data: x_vals_train, y_target: np.transpose([y_vals_train])})
  train_acc.append(temp_acc_train)
  temp_acc_test = sess.run(accuracy, feed_dict={x_data: x_vals_test, y_target: np.transpose([y_vals_test])})
  test_acc.append(temp_acc_test)
  if (i+1)%300==0:
    print('Loss = ' + str(temp_loss))


###
# Display model performance
###

# 绘制损失和准确度
plt.plot(loss_vec, 'k-')
plt.title('Cross Entropy Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Cross Entropy Loss')
plt.show()

# Plot train and test accuracy
plt.plot(train_acc, 'k-', label='Train Set Accuracy')
plt.plot(test_acc, 'r--', label='Test Set Accuracy')
plt.title('Train and Test Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

数据结果:

Loss = 0.845124
Loss = 0.658061
Loss = 0.471852
Loss = 0.643469
Loss = 0.672077

详解用TensorFlow实现逻辑回归算法

迭代1500次的交叉熵损失图

详解用TensorFlow实现逻辑回归算法

迭代1500次的测试集和训练集的准确度图

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python根据路径导入模块的方法
Sep 30 Python
python实现下载文件的三种方法
Feb 09 Python
高质量Python代码编写的5个优化技巧
Nov 16 Python
Python实现的归并排序算法示例
Nov 21 Python
详谈Pandas中iloc和loc以及ix的区别
Jun 08 Python
详解如何将python3.6软件的py文件打包成exe程序
Oct 09 Python
使用Python实现跳一跳自动跳跃功能
Jul 10 Python
Django ORM 常用字段与不常用字段汇总
Aug 09 Python
Python自动化完成tb喵币任务的操作方法
Oct 30 Python
pytorch模型存储的2种实现方法
Feb 14 Python
基于python实现操作redis及消息队列
Aug 27 Python
python 爬虫基本使用——统计杭电oj题目正确率并排序
Oct 26 Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
Python3实现的简单验证码识别功能示例
May 02 #Python
利用Python在一个文件的头部插入数据的实例
May 02 #Python
You might like
PHP输出控制功能在简繁体转换中的应用
2006/10/09 PHP
php字符编码转换之gb2312转为utf8
2013/10/28 PHP
php简单统计在线人数的方法
2016/05/10 PHP
深入理解PHP之OpCode原理详解
2016/06/01 PHP
thinkphp利用模型通用数据编辑添加和删除的实例代码
2016/11/20 PHP
关闭ie窗口清除Session的解决方法
2014/01/10 Javascript
JS实现让网页背景图片斜向移动的方法
2015/02/25 Javascript
js实现iframe框架取值的方法(兼容IE,firefox,chrome等)
2015/11/26 Javascript
全面解析标签页的切换方式
2016/08/21 Javascript
在Vue中使用Compass的方法
2018/03/02 Javascript
javascript使用正则实现去掉字符串前面的所有0
2018/07/23 Javascript
nodejs 递归拷贝、读取目录下所有文件和目录
2019/07/18 NodeJs
vue-router跳转时打开新页面的两种方法
2019/07/29 Javascript
Python编程入门之Hello World的三种实现方式
2015/11/13 Python
实例讲解Python的函数闭包使用中应注意的问题
2016/06/20 Python
Python 数据结构之队列的实现
2017/01/22 Python
python 正确保留多位小数的实例
2018/07/16 Python
Scrapy框架使用的基本知识
2018/10/21 Python
Python 离线工作环境搭建的方法步骤
2019/07/29 Python
Django生成PDF文档显示网页上以及PDF中文显示乱码的解决方法
2019/12/17 Python
pycharm激活码有效到2020年11月底
2020/09/18 Python
探秘TensorFlow 和 NumPy 的 Broadcasting 机制
2020/03/13 Python
用ldap作为django后端用户登录验证的实现
2020/12/07 Python
如何通过安装HomeBrew来安装Python3
2020/12/23 Python
Jones New York官网:美国女装品牌,受白领女性欢迎
2019/11/26 全球购物
护理学毕业生自荐信
2013/10/02 职场文书
餐饮业的创业计划书范文
2013/12/26 职场文书
园林技术个人的自我评价
2014/02/15 职场文书
社区科普工作方案
2014/06/03 职场文书
法学求职信
2014/06/22 职场文书
个人整改方案范文
2014/10/25 职场文书
党的群众路线教育实践活动总结大会主持词
2014/10/30 职场文书
小学国庆节活动总结
2015/03/23 职场文书
2016年党建工作简报
2015/11/26 职场文书
html+css 实现简易导航栏功能
2021/04/07 HTML / CSS
分位数回归模型quantile regeression应用详解及示例教程
2021/11/02 Python