详解用TensorFlow实现逻辑回归算法


Posted in Python onMay 02, 2018

本文将实现逻辑回归算法,预测低出生体重的概率。

# Logistic Regression
# 逻辑回归
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve logistic regression.
# y = sigmoid(Ax + b)
#
# We will use the low birth weight data, specifically:
# y = 0 or 1 = low birth weight
# x = demographic and medical history data

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import requests
from tensorflow.python.framework import ops
import os.path
import csv


ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Obtain and prepare data for modeling
###

# name of data file
birth_weight_file = 'birth_weight.csv'

# download data and create data file if file does not exist in current directory
if not os.path.exists(birth_weight_file):
  birthdata_url = 'https://github.com/nfmcclure/tensorflow_cookbook/raw/master/01_Introduction/07_Working_with_Data_Sources/birthweight_data/birthweight.dat'
  birth_file = requests.get(birthdata_url)
  birth_data = birth_file.text.split('\r\n')
  birth_header = birth_data[0].split('\t')
  birth_data = [[float(x) for x in y.split('\t') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]
  with open(birth_weight_file, "w") as f:
    writer = csv.writer(f)
    writer.writerow(birth_header)
    writer.writerows(birth_data)
    f.close()

# read birth weight data into memory
birth_data = []
with open(birth_weight_file, newline='') as csvfile:
   csv_reader = csv.reader(csvfile)
   birth_header = next(csv_reader)
   for row in csv_reader:
     birth_data.append(row)

birth_data = [[float(x) for x in row] for row in birth_data]

# Pull out target variable
y_vals = np.array([x[0] for x in birth_data])
# Pull out predictor variables (not id, not target, and not birthweight)
x_vals = np.array([x[1:8] for x in birth_data])

# set for reproducible results
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)

# Split data into train/test = 80%/20%
# 分割数据集为测试集和训练集
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

# Normalize by column (min-max norm)
# 将所有特征缩放到0和1区间(min-max缩放),逻辑回归收敛的效果更好
# 归一化特征
def normalize_cols(m):
  col_max = m.max(axis=0)
  col_min = m.min(axis=0)
  return (m-col_min) / (col_max - col_min)

x_vals_train = np.nan_to_num(normalize_cols(x_vals_train))
x_vals_test = np.nan_to_num(normalize_cols(x_vals_test))

###
# Define Tensorflow computational graph¶
###

# Declare batch size
batch_size = 25

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 7], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[7,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function (Cross Entropy loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=model_output, labels=y_target))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

###
# Train model
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Actual Prediction
# 除记录损失函数外,也需要记录分类器在训练集和测试集上的准确度。
# 所以创建一个返回准确度的预测函数
prediction = tf.round(tf.sigmoid(model_output))
predictions_correct = tf.cast(tf.equal(prediction, y_target), tf.float32)
accuracy = tf.reduce_mean(predictions_correct)

# Training loop
# 开始遍历迭代训练,记录损失值和准确度
loss_vec = []
train_acc = []
test_acc = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = x_vals_train[rand_index]
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  temp_acc_train = sess.run(accuracy, feed_dict={x_data: x_vals_train, y_target: np.transpose([y_vals_train])})
  train_acc.append(temp_acc_train)
  temp_acc_test = sess.run(accuracy, feed_dict={x_data: x_vals_test, y_target: np.transpose([y_vals_test])})
  test_acc.append(temp_acc_test)
  if (i+1)%300==0:
    print('Loss = ' + str(temp_loss))


###
# Display model performance
###

# 绘制损失和准确度
plt.plot(loss_vec, 'k-')
plt.title('Cross Entropy Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Cross Entropy Loss')
plt.show()

# Plot train and test accuracy
plt.plot(train_acc, 'k-', label='Train Set Accuracy')
plt.plot(test_acc, 'r--', label='Test Set Accuracy')
plt.title('Train and Test Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

数据结果:

Loss = 0.845124
Loss = 0.658061
Loss = 0.471852
Loss = 0.643469
Loss = 0.672077

详解用TensorFlow实现逻辑回归算法

迭代1500次的交叉熵损失图

详解用TensorFlow实现逻辑回归算法

迭代1500次的测试集和训练集的准确度图

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python正则表达式经典入门教程
May 22 Python
Python中一些不为人知的基础技巧总结
May 19 Python
Python logging模块用法示例
Aug 28 Python
python实现反转部分单向链表
Sep 27 Python
Django中的ajax请求
Oct 19 Python
Python面向对象之类的内置attr属性示例
Dec 14 Python
对python模块中多个类的用法详解
Jan 10 Python
Python对象与引用的介绍
Jan 24 Python
python 图片二值化处理(处理后为纯黑白的图片)
Nov 01 Python
python hash每次调用结果不同的原因
Nov 21 Python
Python数组拼接np.concatenate实现过程
Apr 18 Python
python glom模块的使用简介
Apr 13 Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
Python3实现的简单验证码识别功能示例
May 02 #Python
利用Python在一个文件的头部插入数据的实例
May 02 #Python
You might like
php array_map array_multisort 高效处理多维数组排序
2009/06/11 PHP
php 中文字符入库或显示乱码问题的解决方法
2010/04/12 PHP
PHP 中 Orientation 属性判断上传图片是否需要旋转
2015/10/16 PHP
php简单检测404页面的方法示例
2019/08/23 PHP
js滚动条多种样式,推荐
2007/02/05 Javascript
jquery实现点击文字可编辑并修改保存至数据库
2014/04/15 Javascript
滚动条响应鼠标滑轮事件实现上下滚动的js代码
2014/06/30 Javascript
上传图片预览JS脚本 Input file图片预览的实现示例
2014/10/23 Javascript
Javascript常用字符串判断函数代码分享
2014/12/08 Javascript
使用JQuery FancyBox插件实现图片展示特效
2015/11/16 Javascript
jQuery使用JSONP实现跨域获取数据的三种方法详解
2017/05/04 jQuery
微信小程序开发图片拖拽实例详解
2017/05/05 Javascript
用 Vue.js 递归组件实现可折叠的树形菜单(demo)
2017/12/25 Javascript
原生nodejs使用websocket代码分享
2018/04/07 NodeJs
详解vue-router 命名路由和命名视图
2018/06/01 Javascript
vue.js内置组件之keep-alive组件使用
2018/07/10 Javascript
微信小程序实现的3d轮播图效果示例【基于swiper组件】
2018/12/11 Javascript
vue中实现点击按钮滚动到页面对应位置的方法(使用c3平滑属性实现)
2019/12/29 Javascript
[01:45]2014DOTA2 TI预选赛预选赛 大神专访第二弹!
2014/05/20 DOTA
[01:07:02]DOTA2-DPC中国联赛 正赛 iG vs PSG.LGD BO3 第三场 2月26日
2021/03/11 DOTA
Python 分析Nginx访问日志并保存到MySQL数据库实例
2014/03/13 Python
Python中使用bidict模块双向字典结构的奇技淫巧
2016/07/12 Python
Python实现去除图片中指定颜色的像素功能示例
2019/04/13 Python
python ffmpeg任意提取视频帧的方法
2020/02/21 Python
Python 自由定制表格的实现示例
2020/03/20 Python
Python3爬虫中Selenium的用法详解
2020/07/10 Python
PyTorch 中的傅里叶卷积实现示例
2020/12/11 Python
使用Filters滤镜弥补CSS3的跨浏览器问题以及兼容低版本IE
2013/01/23 HTML / CSS
一些Unix笔试题和面试题
2013/01/22 面试题
高校自主招生自荐信
2013/12/09 职场文书
留学自荐信写作方法
2014/01/27 职场文书
入职担保书怎么写
2014/05/12 职场文书
幼儿园六一儿童节活动方案
2014/08/26 职场文书
擅自离岗检讨书
2014/09/12 职场文书
卫生院义诊活动总结
2015/05/07 职场文书
企业宣传语大全
2015/07/13 职场文书