A Detailed Guide to Implementing Logistic Regression with TensorFlow


Posted in Python on May 02, 2018

This article implements logistic regression to predict the probability of low birth weight.
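
As a rough illustration of what the model computes, the sketch below evaluates sigmoid(xA + b) in plain NumPy for a single row of seven features. The weights, bias, and feature values are made-up numbers chosen for illustration, not values learned from the birth weight data, and `predict_proba` is a hypothetical helper name.

```python
import numpy as np

def sigmoid(z):
    """Map a real-valued logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

def predict_proba(x, A, b):
    """Return P(y = 1 | x) for each row of x under a logistic regression model."""
    return sigmoid(x @ A + b)

# Hypothetical parameters and one hypothetical, already-normalized feature row.
A_demo = np.zeros((7, 1))
b_demo = np.zeros((1, 1))
x_demo = np.full((1, 7), 0.5)

# With all-zero weights the logit is 0, so the predicted probability is 0.5.
p_demo = predict_proba(x_demo, A_demo, b_demo)
```

Training amounts to adjusting `A` and `b` so that these probabilities move toward the observed 0/1 labels.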

# Logistic Regression
#----------------------------------
#
# This script shows how to use TensorFlow (1.x graph API) to
# solve logistic regression.
# y = sigmoid(Ax + b)
#
# We will use the low birth weight data, specifically:
# y = 0 or 1 (1 = low birth weight)
# x = demographic and medical history data

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import requests
from tensorflow.python.framework import ops
import os.path
import csv


ops.reset_default_graph()

# Create a session to run the graph
sess = tf.Session()

###
# Obtain and prepare data for modeling
###

# name of data file
birth_weight_file = 'birth_weight.csv'

# download data and create data file if file does not exist in current directory
if not os.path.exists(birth_weight_file):
  birthdata_url = 'https://github.com/nfmcclure/tensorflow_cookbook/raw/master/01_Introduction/07_Working_with_Data_Sources/birthweight_data/birthweight.dat'
  birth_file = requests.get(birthdata_url)
  birth_data = birth_file.text.split('\r\n')
  birth_header = birth_data[0].split('\t')
  birth_data = [[float(x) for x in y.split('\t') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]
  # newline='' stops the csv module from writing blank rows on Windows
  with open(birth_weight_file, "w", newline='') as f:
    writer = csv.writer(f)
    writer.writerow(birth_header)
    writer.writerows(birth_data)

# read birth weight data into memory
birth_data = []
with open(birth_weight_file, newline='') as csvfile:
  csv_reader = csv.reader(csvfile)
  birth_header = next(csv_reader)
  for row in csv_reader:
    birth_data.append(row)

birth_data = [[float(x) for x in row] for row in birth_data]

# Pull out target variable
y_vals = np.array([x[0] for x in birth_data])
# Pull out predictor variables (not id, not target, and not birthweight)
x_vals = np.array([x[1:8] for x in birth_data])

# set for reproducible results
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)

# Split data into train/test = 80%/20%
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

# Normalize by column (min-max norm)
# Scaling every feature to the [0, 1] range helps logistic regression converge
def normalize_cols(m):
  col_max = m.max(axis=0)
  col_min = m.min(axis=0)
  return (m-col_min) / (col_max - col_min)

x_vals_train = np.nan_to_num(normalize_cols(x_vals_train))
x_vals_test = np.nan_to_num(normalize_cols(x_vals_test))

###
# Define TensorFlow computational graph
###

# Declare batch size
batch_size = 25

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 7], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for logistic regression
A = tf.Variable(tf.random_normal(shape=[7,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function (Cross Entropy loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=model_output, labels=y_target))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

###
# Train model
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Actual prediction
# Besides the loss, we also track the classifier's accuracy on the training and test sets,
# so we build ops that return the prediction accuracy
prediction = tf.round(tf.sigmoid(model_output))
predictions_correct = tf.cast(tf.equal(prediction, y_target), tf.float32)
accuracy = tf.reduce_mean(predictions_correct)

# Training loop
# Iterate, recording the loss and accuracy at each step
loss_vec = []
train_acc = []
test_acc = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = x_vals_train[rand_index]
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  temp_acc_train = sess.run(accuracy, feed_dict={x_data: x_vals_train, y_target: np.transpose([y_vals_train])})
  train_acc.append(temp_acc_train)
  temp_acc_test = sess.run(accuracy, feed_dict={x_data: x_vals_test, y_target: np.transpose([y_vals_test])})
  test_acc.append(temp_acc_test)
  if (i+1)%300==0:
    print('Loss = ' + str(temp_loss))


###
# Display model performance
###

# Plot the loss and accuracy
plt.plot(loss_vec, 'k-')
plt.title('Cross Entropy Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Cross Entropy Loss')
plt.show()

# Plot train and test accuracy
plt.plot(train_acc, 'k-', label='Train Set Accuracy')
plt.plot(test_acc, 'r--', label='Test Set Accuracy')
plt.title('Train and Test Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()
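
The script above scales the training and test sets separately, each with its own column minimum and maximum. A common alternative is to compute the statistics on the training set only and reuse them for the test set, so that both sets go through the same transformation. The sketch below shows that variant on small made-up arrays; `fit_min_max` and `apply_min_max` are hypothetical helper names, not part of the original script.

```python
import numpy as np

def fit_min_max(train):
    """Compute per-column min and max from the training data only."""
    return train.min(axis=0), train.max(axis=0)

def apply_min_max(m, col_min, col_max):
    """Scale columns with the given statistics; constant columns map to 0."""
    span = col_max - col_min
    safe_span = np.where(span == 0, 1.0, span)  # avoid division by zero
    return (m - col_min) / safe_span

# Small made-up arrays standing in for x_vals_train / x_vals_test.
train_demo = np.array([[1.0, 10.0], [3.0, 10.0], [5.0, 10.0]])
test_demo = np.array([[2.0, 10.0], [6.0, 10.0]])

col_min, col_max = fit_min_max(train_demo)
train_scaled = apply_min_max(train_demo, col_min, col_max)
test_scaled = apply_min_max(test_demo, col_min, col_max)
```

With this approach, test values can fall slightly outside [0, 1] when they lie beyond the training range.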

Output:

Loss = 0.845124
Loss = 0.658061
Loss = 0.471852
Loss = 0.643469
Loss = 0.672077
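
Each printed value is the mean cross-entropy on a single random batch of 25 rows, so the numbers fluctuate from one printout to the next rather than decreasing steadily. As a sketch of what is being averaged, the NumPy function below uses the numerically stable form given in the TensorFlow documentation for `tf.nn.sigmoid_cross_entropy_with_logits`, namely `max(x, 0) - x * z + log(1 + exp(-abs(x)))`, where `x` is the logit and `z` the label. The input arrays are made-up examples.

```python
import numpy as np

def sigmoid_cross_entropy(logits, labels):
    """Element-wise sigmoid cross-entropy computed directly from logits."""
    x = np.asarray(logits, dtype=np.float64)
    z = np.asarray(labels, dtype=np.float64)
    # Stable form: never exponentiates a large positive number.
    return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))

# Made-up logits and labels for a batch of three rows.
logits_demo = np.array([[0.0], [2.0], [-3.0]])
labels_demo = np.array([[1.0], [1.0], [0.0]])

# Mean over the batch, analogous to tf.reduce_mean in the script.
batch_loss = sigmoid_cross_entropy(logits_demo, labels_demo).mean()
```

Working from logits rather than from probabilities avoids taking the logarithm of a value that has been rounded to 0 or 1.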

Figure: cross-entropy loss over 1,500 iterations

Figure: training and test accuracy over 1,500 iterations
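
The accuracy curves come from thresholding the predicted probability at 0.5 and comparing the result against the labels. A minimal NumPy sketch of that calculation follows; the logits and labels are made-up values, and `accuracy_from_logits` is a hypothetical helper name.

```python
import numpy as np

def accuracy_from_logits(logits, labels):
    """Fraction of rows whose rounded sigmoid output equals the 0/1 label."""
    probs = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float64)))
    preds = np.round(probs)  # 1.0 where probability > 0.5, else 0.0
    return float(np.mean(preds == np.asarray(labels, dtype=np.float64)))

# Made-up logits and labels for four rows; the third row is misclassified.
logits_demo = np.array([[1.5], [-0.3], [0.2], [-2.0]])
labels_demo = np.array([[1.0], [0.0], [0.0], [0.0]])

acc_demo = accuracy_from_logits(logits_demo, labels_demo)
```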

That is all for this article. I hope it helps with your studies, and I hope you will continue to support 三水点靠木.
