详解用TensorFlow实现逻辑回归算法


Posted in Python onMay 02, 2018

本文将实现逻辑回归算法,预测低出生体重的概率。

# Logistic Regression
# 逻辑回归
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve logistic regression.
# y = sigmoid(Ax + b)
#
# We will use the low birth weight data, specifically:
# y = 0 or 1 = low birth weight
# x = demographic and medical history data

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import requests
from tensorflow.python.framework import ops
import os.path
import csv


ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Obtain and prepare data for modeling
###

# name of data file
birth_weight_file = 'birth_weight.csv'

# download data and create data file if file does not exist in current directory
if not os.path.exists(birth_weight_file):
  birthdata_url = 'https://github.com/nfmcclure/tensorflow_cookbook/raw/master/01_Introduction/07_Working_with_Data_Sources/birthweight_data/birthweight.dat'
  birth_file = requests.get(birthdata_url)
  birth_data = birth_file.text.split('\r\n')
  birth_header = birth_data[0].split('\t')
  birth_data = [[float(x) for x in y.split('\t') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]
  with open(birth_weight_file, "w") as f:
    writer = csv.writer(f)
    writer.writerow(birth_header)
    writer.writerows(birth_data)
    f.close()

# read birth weight data into memory
birth_data = []
with open(birth_weight_file, newline='') as csvfile:
   csv_reader = csv.reader(csvfile)
   birth_header = next(csv_reader)
   for row in csv_reader:
     birth_data.append(row)

birth_data = [[float(x) for x in row] for row in birth_data]

# Pull out target variable
y_vals = np.array([x[0] for x in birth_data])
# Pull out predictor variables (not id, not target, and not birthweight)
x_vals = np.array([x[1:8] for x in birth_data])

# set for reproducible results
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)

# Split data into train/test = 80%/20%
# 分割数据集为测试集和训练集
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

# Normalize by column (min-max norm)
# 将所有特征缩放到0和1区间(min-max缩放),逻辑回归收敛的效果更好
# 归一化特征
def normalize_cols(m):
  col_max = m.max(axis=0)
  col_min = m.min(axis=0)
  return (m-col_min) / (col_max - col_min)

x_vals_train = np.nan_to_num(normalize_cols(x_vals_train))
x_vals_test = np.nan_to_num(normalize_cols(x_vals_test))

###
# Define Tensorflow computational graph¶
###

# Declare batch size
batch_size = 25

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 7], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[7,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function (Cross Entropy loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=model_output, labels=y_target))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

###
# Train model
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Actual Prediction
# 除记录损失函数外,也需要记录分类器在训练集和测试集上的准确度。
# 所以创建一个返回准确度的预测函数
prediction = tf.round(tf.sigmoid(model_output))
predictions_correct = tf.cast(tf.equal(prediction, y_target), tf.float32)
accuracy = tf.reduce_mean(predictions_correct)

# Training loop
# 开始遍历迭代训练,记录损失值和准确度
loss_vec = []
train_acc = []
test_acc = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = x_vals_train[rand_index]
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  temp_acc_train = sess.run(accuracy, feed_dict={x_data: x_vals_train, y_target: np.transpose([y_vals_train])})
  train_acc.append(temp_acc_train)
  temp_acc_test = sess.run(accuracy, feed_dict={x_data: x_vals_test, y_target: np.transpose([y_vals_test])})
  test_acc.append(temp_acc_test)
  if (i+1)%300==0:
    print('Loss = ' + str(temp_loss))


###
# Display model performance
###

# 绘制损失和准确度
plt.plot(loss_vec, 'k-')
plt.title('Cross Entropy Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Cross Entropy Loss')
plt.show()

# Plot train and test accuracy
plt.plot(train_acc, 'k-', label='Train Set Accuracy')
plt.plot(test_acc, 'r--', label='Test Set Accuracy')
plt.title('Train and Test Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

数据结果:

Loss = 0.845124
Loss = 0.658061
Loss = 0.471852
Loss = 0.643469
Loss = 0.672077

详解用TensorFlow实现逻辑回归算法

迭代1500次的交叉熵损失图

详解用TensorFlow实现逻辑回归算法

迭代1500次的测试集和训练集的准确度图

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python33 urllib2使用方法细节讲解
Dec 03 Python
利用python生成一个导出数据库的bat脚本文件的方法
Dec 30 Python
django实现用户登陆功能详解
Dec 11 Python
windows下python 3.6.4安装配置图文教程
Aug 21 Python
Python实现钉钉发送报警消息的方法
Feb 20 Python
Django如何防止定时任务并发浅析
May 14 Python
Pandas之Fillna填充缺失数据的方法
Jun 25 Python
Python下opencv图像阈值处理的使用笔记
Aug 04 Python
linux环境下安装python虚拟环境及注意事项
Jan 07 Python
python 如何对logging日志封装
Dec 02 Python
Python  Asyncio模块实现的生产消费者模型的方法
Mar 01 Python
python 实现图与图之间的间距调整subplots_adjust
May 21 Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
Python3实现的简单验证码识别功能示例
May 02 #Python
利用Python在一个文件的头部插入数据的实例
May 02 #Python
You might like
实用函数7
2007/11/08 PHP
用PHP和Shell写Hadoop的MapReduce程序
2014/04/15 PHP
php使用str_shuffle()函数生成随机字符串的方法分析
2017/02/17 PHP
PHP使用imagick扩展实现合并图像的方法
2017/04/25 PHP
Ecshop 后台添加新功能栏目及管理权限设置教程
2017/11/21 PHP
Laravel 之url参数,获取路由参数的例子
2019/10/21 PHP
javascript 数据类型转换(parseInt,parseFloat)
2010/07/20 Javascript
原生js结合html5制作简易的双色子游戏
2015/03/30 Javascript
使用struts2+Ajax+jquery验证用户名是否已被注册
2016/03/22 Javascript
关于JS 预解释的相关理解
2016/06/28 Javascript
Js动态设置rem来实现移动端字体的自适应代码
2016/10/14 Javascript
jQuery获取Table某列的值(推荐)
2017/03/03 Javascript
vue.js 获取当前自定义属性值
2017/06/01 Javascript
VUE项目中加载已保存的笔记实例方法
2019/09/14 Javascript
javascript+css实现进度条效果
2020/03/25 Javascript
Python实现TCP/IP协议下的端口转发及重定向示例
2016/06/14 Python
Python中支持向量机SVM的使用方法详解
2017/12/26 Python
用Cython加速Python到“起飞”(推荐)
2019/08/01 Python
python读取大文件越来越慢的原因与解决
2019/08/08 Python
Python多线程及其基本使用方法实例分析
2019/10/29 Python
PyTorch 对应点相乘、矩阵相乘实例
2019/12/27 Python
Python面向对象编程基础实例分析
2020/01/17 Python
安装多个版本的TensorFlow的方法步骤
2020/04/21 Python
Python调用.net动态库实现过程解析
2020/06/05 Python
使用Numpy对特征中的异常值进行替换及条件替换方式
2020/06/08 Python
使paramiko库执行命令时在给定的时间强制退出功能的实现
2021/03/03 Python
Canon佳能美国官方商店:购买数码相机、数码单反相机、镜头和打印机
2016/11/15 全球购物
实习生自我鉴定范文
2013/12/05 职场文书
优秀教师主要事迹
2014/02/01 职场文书
《我的伯父鲁迅先生》教学反思
2014/02/12 职场文书
青蓝工程实施方案
2014/03/27 职场文书
品牌服务方案
2014/06/03 职场文书
公司任命书模板
2014/06/06 职场文书
2015年中学元旦晚会活动方案
2014/12/09 职场文书
签订劳动合同通知书
2015/04/16 职场文书
卫生保健工作总结2015
2015/05/18 职场文书