详解用TensorFlow实现逻辑回归算法


Posted in Python onMay 02, 2018

本文将实现逻辑回归算法,预测低出生体重的概率。

# Logistic Regression
# 逻辑回归
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve logistic regression.
# y = sigmoid(Ax + b)
#
# We will use the low birth weight data, specifically:
# y = 0 or 1 = low birth weight
# x = demographic and medical history data

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import requests
from tensorflow.python.framework import ops
import os.path
import csv


ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Obtain and prepare data for modeling
###

# name of data file
birth_weight_file = 'birth_weight.csv'

# download data and create data file if file does not exist in current directory
if not os.path.exists(birth_weight_file):
  birthdata_url = 'https://github.com/nfmcclure/tensorflow_cookbook/raw/master/01_Introduction/07_Working_with_Data_Sources/birthweight_data/birthweight.dat'
  birth_file = requests.get(birthdata_url)
  birth_data = birth_file.text.split('\r\n')
  birth_header = birth_data[0].split('\t')
  birth_data = [[float(x) for x in y.split('\t') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]
  with open(birth_weight_file, "w") as f:
    writer = csv.writer(f)
    writer.writerow(birth_header)
    writer.writerows(birth_data)
    f.close()

# read birth weight data into memory
birth_data = []
with open(birth_weight_file, newline='') as csvfile:
   csv_reader = csv.reader(csvfile)
   birth_header = next(csv_reader)
   for row in csv_reader:
     birth_data.append(row)

birth_data = [[float(x) for x in row] for row in birth_data]

# Pull out target variable
y_vals = np.array([x[0] for x in birth_data])
# Pull out predictor variables (not id, not target, and not birthweight)
x_vals = np.array([x[1:8] for x in birth_data])

# set for reproducible results
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)

# Split data into train/test = 80%/20%
# 分割数据集为测试集和训练集
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

# Normalize by column (min-max norm)
# 将所有特征缩放到0和1区间(min-max缩放),逻辑回归收敛的效果更好
# 归一化特征
def normalize_cols(m):
  col_max = m.max(axis=0)
  col_min = m.min(axis=0)
  return (m-col_min) / (col_max - col_min)

x_vals_train = np.nan_to_num(normalize_cols(x_vals_train))
x_vals_test = np.nan_to_num(normalize_cols(x_vals_test))

###
# Define Tensorflow computational graph¶
###

# Declare batch size
batch_size = 25

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 7], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[7,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function (Cross Entropy loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=model_output, labels=y_target))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

###
# Train model
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Actual Prediction
# 除记录损失函数外,也需要记录分类器在训练集和测试集上的准确度。
# 所以创建一个返回准确度的预测函数
prediction = tf.round(tf.sigmoid(model_output))
predictions_correct = tf.cast(tf.equal(prediction, y_target), tf.float32)
accuracy = tf.reduce_mean(predictions_correct)

# Training loop
# 开始遍历迭代训练,记录损失值和准确度
loss_vec = []
train_acc = []
test_acc = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = x_vals_train[rand_index]
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  temp_acc_train = sess.run(accuracy, feed_dict={x_data: x_vals_train, y_target: np.transpose([y_vals_train])})
  train_acc.append(temp_acc_train)
  temp_acc_test = sess.run(accuracy, feed_dict={x_data: x_vals_test, y_target: np.transpose([y_vals_test])})
  test_acc.append(temp_acc_test)
  if (i+1)%300==0:
    print('Loss = ' + str(temp_loss))


###
# Display model performance
###

# 绘制损失和准确度
plt.plot(loss_vec, 'k-')
plt.title('Cross Entropy Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Cross Entropy Loss')
plt.show()

# Plot train and test accuracy
plt.plot(train_acc, 'k-', label='Train Set Accuracy')
plt.plot(test_acc, 'r--', label='Test Set Accuracy')
plt.title('Train and Test Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

数据结果:

Loss = 0.845124
Loss = 0.658061
Loss = 0.471852
Loss = 0.643469
Loss = 0.672077

详解用TensorFlow实现逻辑回归算法

迭代1500次的交叉熵损失图

详解用TensorFlow实现逻辑回归算法

迭代1500次的测试集和训练集的准确度图

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python控制台显示时钟的示例
Feb 24 Python
对于Python装饰器使用的一些建议
Jun 03 Python
快速查询Python文档方法分享
Dec 27 Python
对python中的pop函数和append函数详解
May 04 Python
Python使用pyodbc访问数据库操作方法详解
Jul 05 Python
python 3.3 下载固定链接文件并保存的方法
Dec 18 Python
django ModelForm修改显示缩略图 imagefield类型的实例
Jul 28 Python
python实现udp传输图片功能
Mar 20 Python
基于python实现生成指定大小txt文档
Jul 20 Python
django表单中的按钮获取数据的实例分析
Jul 31 Python
如何在python中实现线性回归
Aug 10 Python
python 中[0]*2与0*2的区别说明
May 10 Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
Python3实现的简单验证码识别功能示例
May 02 #Python
利用Python在一个文件的头部插入数据的实例
May 02 #Python
You might like
re0第二季蕾姆被制作组打入冷宫!艾米莉亚女主扶正,原因唏嘘
2020/04/02 日漫
Php Image Resize图片大小调整的函数代码
2011/01/17 PHP
如何用php获取程序执行的时间
2013/06/09 PHP
php限制ip地址范围的方法
2015/03/31 PHP
PHP自定义函数实现数组比较功能示例
2017/10/19 PHP
Avengerls vs Newbee BO3 第三场2.18
2021/03/10 DOTA
JS对URL字符串进行编码/解码分析
2008/10/25 Javascript
javascript nextSibling 与 getNextElement(node) 使用介绍
2011/10/13 Javascript
模拟select的代码
2011/10/19 Javascript
VBS通过WMI监视注册表变动的代码
2011/10/27 Javascript
js操作IE浏览器弹出浏览文件夹可以返回目录路径
2014/07/14 Javascript
node.js中的fs.createWriteStream方法使用说明
2014/12/17 Javascript
Angular.js中ng-include用法及多标签页面的实现方式详解
2017/05/07 Javascript
谈谈vue中mixin的一点理解
2017/12/12 Javascript
JS实现区分中英文并统计字符个数的方法示例
2018/06/09 Javascript
Webpack devServer中的 proxy 实现跨域的解决
2018/06/15 Javascript
vue强制刷新组件的方法示例
2019/02/28 Javascript
用原生JS实现爱奇艺首页导航栏代码实例
2019/09/19 Javascript
vue 更改连接后台的api示例
2019/11/11 Javascript
浅析Vue 防抖与节流的使用
2019/11/14 Javascript
JavaScript链式调用原理与实现方法详解
2020/05/16 Javascript
Python3的urllib.parse常用函数小结(urlencode,quote,quote_plus,unquote,unquote_plus等)
2016/09/18 Python
Python实现获取系统临时目录及临时文件的方法示例
2019/06/26 Python
python向图片里添加文字
2019/11/26 Python
关于多元线性回归分析——Python&SPSS
2020/02/24 Python
Python全面分析系统的时域特性和频率域特性
2020/02/26 Python
pyinstaller将含有多个py文件的python程序做成exe
2020/04/29 Python
解决Keras中CNN输入维度报错问题
2020/06/29 Python
法国综合购物网站:RueDuCommerce
2016/09/12 全球购物
捷科时代的软件测试笔试题
2015/11/09 面试题
小学六年级学生评语
2014/04/22 职场文书
公司采购主管岗位职责
2014/06/17 职场文书
授权委托书范文
2014/07/31 职场文书
2014年党务工作总结
2014/11/25 职场文书
2015年度党风廉政建设工作情况汇报
2015/01/02 职场文书
顶岗实习计划书
2015/01/16 职场文书