详解用TensorFlow实现逻辑回归算法


Posted in Python onMay 02, 2018

本文将实现逻辑回归算法,预测低出生体重的概率。

# Logistic Regression
# 逻辑回归
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve logistic regression.
# y = sigmoid(Ax + b)
#
# We will use the low birth weight data, specifically:
# y = 0 or 1 = low birth weight
# x = demographic and medical history data

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import requests
from tensorflow.python.framework import ops
import os.path
import csv


ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Obtain and prepare data for modeling
###

# name of data file
birth_weight_file = 'birth_weight.csv'

# download data and create data file if file does not exist in current directory
if not os.path.exists(birth_weight_file):
  birthdata_url = 'https://github.com/nfmcclure/tensorflow_cookbook/raw/master/01_Introduction/07_Working_with_Data_Sources/birthweight_data/birthweight.dat'
  birth_file = requests.get(birthdata_url)
  birth_data = birth_file.text.split('\r\n')
  birth_header = birth_data[0].split('\t')
  birth_data = [[float(x) for x in y.split('\t') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]
  with open(birth_weight_file, "w") as f:
    writer = csv.writer(f)
    writer.writerow(birth_header)
    writer.writerows(birth_data)
    f.close()

# read birth weight data into memory
birth_data = []
with open(birth_weight_file, newline='') as csvfile:
   csv_reader = csv.reader(csvfile)
   birth_header = next(csv_reader)
   for row in csv_reader:
     birth_data.append(row)

birth_data = [[float(x) for x in row] for row in birth_data]

# Pull out target variable
y_vals = np.array([x[0] for x in birth_data])
# Pull out predictor variables (not id, not target, and not birthweight)
x_vals = np.array([x[1:8] for x in birth_data])

# set for reproducible results
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)

# Split data into train/test = 80%/20%
# 分割数据集为测试集和训练集
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

# Normalize by column (min-max norm)
# 将所有特征缩放到0和1区间(min-max缩放),逻辑回归收敛的效果更好
# 归一化特征
def normalize_cols(m):
  col_max = m.max(axis=0)
  col_min = m.min(axis=0)
  return (m-col_min) / (col_max - col_min)

x_vals_train = np.nan_to_num(normalize_cols(x_vals_train))
x_vals_test = np.nan_to_num(normalize_cols(x_vals_test))

###
# Define Tensorflow computational graph¶
###

# Declare batch size
batch_size = 25

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 7], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[7,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function (Cross Entropy loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=model_output, labels=y_target))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

###
# Train model
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Actual Prediction
# 除记录损失函数外,也需要记录分类器在训练集和测试集上的准确度。
# 所以创建一个返回准确度的预测函数
prediction = tf.round(tf.sigmoid(model_output))
predictions_correct = tf.cast(tf.equal(prediction, y_target), tf.float32)
accuracy = tf.reduce_mean(predictions_correct)

# Training loop
# 开始遍历迭代训练,记录损失值和准确度
loss_vec = []
train_acc = []
test_acc = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = x_vals_train[rand_index]
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  temp_acc_train = sess.run(accuracy, feed_dict={x_data: x_vals_train, y_target: np.transpose([y_vals_train])})
  train_acc.append(temp_acc_train)
  temp_acc_test = sess.run(accuracy, feed_dict={x_data: x_vals_test, y_target: np.transpose([y_vals_test])})
  test_acc.append(temp_acc_test)
  if (i+1)%300==0:
    print('Loss = ' + str(temp_loss))


###
# Display model performance
###

# 绘制损失和准确度
plt.plot(loss_vec, 'k-')
plt.title('Cross Entropy Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Cross Entropy Loss')
plt.show()

# Plot train and test accuracy
plt.plot(train_acc, 'k-', label='Train Set Accuracy')
plt.plot(test_acc, 'r--', label='Test Set Accuracy')
plt.title('Train and Test Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

数据结果:

Loss = 0.845124
Loss = 0.658061
Loss = 0.471852
Loss = 0.643469
Loss = 0.672077

详解用TensorFlow实现逻辑回归算法

迭代1500次的交叉熵损失图

详解用TensorFlow实现逻辑回归算法

迭代1500次的测试集和训练集的准确度图

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
对于Python的框架中一些会话程序的管理
Apr 20 Python
pycharm中成功运行图片的配置教程
Oct 28 Python
详解Python 解压缩文件
Apr 09 Python
python集合的创建、添加及删除操作示例
Oct 08 Python
浅谈Python 参数与变量
Jun 20 Python
python反爬虫方法的优缺点分析
Nov 25 Python
python模拟点击玩游戏的实例讲解
Nov 26 Python
pycharm中leetcode插件使用图文详解
Dec 07 Python
Pandas之缺失数据的实现
Jan 06 Python
python绘图模块之利用turtle画图
Feb 12 Python
Python用SSH连接到网络设备
Feb 18 Python
浅谈Python 中的复数问题
May 19 Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
Python3实现的简单验证码识别功能示例
May 02 #Python
利用Python在一个文件的头部插入数据的实例
May 02 #Python
You might like
NT IIS下用ODBC连接数据库
2006/10/09 PHP
解析PHP中的正则表达式以及模式匹配
2013/06/19 PHP
ThinkPHP 404页面的设置方法
2015/01/14 PHP
浅析php如何实现爬取数据原理
2018/09/27 PHP
form中限制文本字节数js代码
2007/06/10 Javascript
用javascript作一个通用向导说明
2011/08/30 Javascript
多浏览器兼容性比较好的复制到剪贴板的js代码
2011/10/09 Javascript
浅析JavaScript中两种类型的全局对象/函数
2013/12/05 Javascript
Extjs的FileUploadField文件上传出现了两个上传按钮
2014/04/29 Javascript
使用jQuery和Bootstrap实现多层、自适应模态窗口
2014/12/22 Javascript
Js实现无刷新删除内容
2015/04/29 Javascript
相册展示PhotoSwipe.js插件实现
2016/08/25 Javascript
模板视图和AngularJS之间冲突的解决方法
2016/11/22 Javascript
Angular.js中数组操作的方法教程
2017/07/31 Javascript
vue中的ref和$refs的使用
2018/11/22 Javascript
js array数组对象操作方法汇总
2019/03/18 Javascript
vue百度地图 + 定位的详解
2019/05/13 Javascript
Vue 实现点击空白处隐藏某节点的三种方式(指令、普通、遮罩)
2019/10/23 Javascript
[55:16]Mski vs VGJ.S Supermajor小组赛C组 BO3 第二场 6.3
2018/06/04 DOTA
Python实现基本线性数据结构
2016/08/22 Python
简单谈谈Python流程控制语句
2016/12/04 Python
Python实现图片尺寸缩放脚本
2018/03/10 Python
Python使用MD5加密算法对字符串进行加密操作示例
2018/03/30 Python
详解Python3 pandas.merge用法
2019/09/05 Python
Python使用tkinter模块实现推箱子游戏
2019/10/08 Python
Python使用psutil获取进程信息的例子
2019/12/17 Python
python图形开发GUI库pyqt5的基本使用方法详解
2020/02/14 Python
TensorFlow2.1.0最新版本安装详细教程
2020/04/08 Python
Python实现电视里的5毛特效实例代码详解
2020/05/15 Python
美国彩妆品牌:Coastal Scents
2017/04/01 全球购物
美国翻新电子产品商店:The Store
2019/10/08 全球购物
怎样声明子类
2013/07/02 面试题
2015年建筑工程工作总结
2015/05/13 职场文书
当幸福来敲门英文观后感
2015/06/01 职场文书
秋季运动会加油词
2015/07/18 职场文书
总结三种用 Python 作为小程序后端的方式
2022/05/02 Python