Implementing Logistic Regression with TensorFlow, Explained


Posted in Python on May 02, 2018

This article implements the logistic regression algorithm to predict the probability of low birth weight.
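The model computes a logit z = Ax + b for each record and turns it into a probability with the sigmoid function. The script's loss, `tf.nn.sigmoid_cross_entropy_with_logits`, is documented as evaluating the binary cross-entropy in the numerically stable form max(z, 0) - z*y + log(1 + exp(-|z|)). As a minimal NumPy sketch (not part of the original script; the features, weights `A`, `b`, and labels are made up for illustration), the snippet below checks that this form matches the textbook -y*log(p) - (1 - y)*log(1 - p), and applies the same 0.5 threshold the script later uses via `tf.round(tf.sigmoid(...))`.

```python
import numpy as np

def sigmoid(z):
    # Logistic function: maps any real logit to a probability in (0, 1)
    return 1.0 / (1.0 + np.exp(-z))

def stable_sigmoid_xent(logits, labels):
    # Numerically stable form documented for tf.nn.sigmoid_cross_entropy_with_logits
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

def naive_sigmoid_xent(logits, labels):
    # Textbook binary cross-entropy computed from the probabilities
    p = sigmoid(logits)
    return -labels * np.log(p) - (1 - labels) * np.log(1 - p)

# Made-up example: 3 records with 2 features each, and hypothetical weights A, b
x = np.array([[0.2, 0.7], [0.9, 0.1], [0.5, 0.5]])
A = np.array([[1.5], [-2.0]])
b = np.array([[0.3]])
y = np.array([[1.0], [0.0], [1.0]])

logits = x @ A + b           # z = Ax + b, shape (3, 1)
probs = sigmoid(logits)      # predicted probability of class 1
preds = np.round(probs)      # threshold at 0.5

loss_stable = stable_sigmoid_xent(logits, y).mean()
loss_naive = naive_sigmoid_xent(logits, y).mean()
print(preds.ravel(), loss_stable, loss_naive)
```

The stable form matters for large logits: with z = 1000 and y = 0 the naive version computes log(0), while the stable version returns 1000.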

# Logistic Regression
#----------------------------------
#
# This script shows how to use TensorFlow to
# solve logistic regression.
# y = sigmoid(Ax + b)
#
# We will use the low birth weight data, specifically:
# y = 0 or 1 = low birth weight
# x = demographic and medical history data

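import matplotlib.pyplot as plt
import numpy as np
import requests
import os.path
import csv
# The code uses the TensorFlow 1.x graph API (placeholders and sessions);
# importing it through tensorflow.compat.v1 keeps it working under TensorFlow 2.x
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Clear any existing default graph
tf.reset_default_graph()

# Start a graph session
sess = tf.Session()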

###
# Obtain and prepare data for modeling
###

# name of data file
birth_weight_file = 'birth_weight.csv'

# download data and create data file if file does not exist in current directory
if not os.path.exists(birth_weight_file):
  birthdata_url = 'https://github.com/nfmcclure/tensorflow_cookbook/raw/master/01_Introduction/07_Working_with_Data_Sources/birthweight_data/birthweight.dat'
  birth_file = requests.get(birthdata_url)
  birth_data = birth_file.text.split('\r\n')
  birth_header = birth_data[0].split('\t')
  birth_data = [[float(x) for x in y.split('\t') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]
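  # newline='' stops the csv module from writing blank rows on Windows;
  # the with block closes the file, so no explicit close() is needed
  with open(birth_weight_file, "w", newline='') as f:
    writer = csv.writer(f)
    writer.writerow(birth_header)
    writer.writerows(birth_data)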

# read birth weight data into memory
birth_data = []
with open(birth_weight_file, newline='') as csvfile:
  csv_reader = csv.reader(csvfile)
  birth_header = next(csv_reader)
  for row in csv_reader:
    birth_data.append(row)

birth_data = [[float(x) for x in row] for row in birth_data]

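# Pull out target variable (column 0: 1 = low birth weight, 0 = not)
y_vals = np.array([x[0] for x in birth_data])
# Pull out the 7 predictor variables (columns 1-7): everything except
# the target (column 0) and the birth weight itself (column 8)
x_vals = np.array([x[1:8] for x in birth_data])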

# set for reproducible results
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)

# Split data into train/test = 80%/20%
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

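# Normalize by column (min-max norm)
# Scale every feature to the [0, 1] range (min-max scaling);
# logistic regression converges better on features of similar scale.
# The test set is scaled with the training set's min/max so that no
# information from the test data leaks into preprocessing.
def normalize_cols(m, col_min=None, col_max=None):
  if col_min is None:
    col_min = m.min(axis=0)
  if col_max is None:
    col_max = m.max(axis=0)
  return (m-col_min) / (col_max - col_min)

train_min = x_vals_train.min(axis=0)
train_max = x_vals_train.max(axis=0)
# nan_to_num replaces NaNs from constant columns (0/0) with 0
x_vals_train = np.nan_to_num(normalize_cols(x_vals_train, train_min, train_max))
x_vals_test = np.nan_to_num(normalize_cols(x_vals_test, train_min, train_max))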

###
# Define Tensorflow computational graph
###

# Declare batch size
batch_size = 25

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 7], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for the linear part of the model, Ax + b
A = tf.Variable(tf.random_normal(shape=[7,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations: the raw logits Ax + b (the sigmoid is applied inside the loss)
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function (Cross Entropy loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=model_output, labels=y_target))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

###
# Train model
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Actual prediction
# Besides the loss, we also track the classifier's accuracy on the training
# and test sets, so we define prediction and accuracy operations
prediction = tf.round(tf.sigmoid(model_output))
predictions_correct = tf.cast(tf.equal(prediction, y_target), tf.float32)
accuracy = tf.reduce_mean(predictions_correct)

# Training loop
# Run the training iterations, recording the loss and accuracy at each step
loss_vec = []
train_acc = []
test_acc = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = x_vals_train[rand_index]
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  temp_acc_train = sess.run(accuracy, feed_dict={x_data: x_vals_train, y_target: np.transpose([y_vals_train])})
  train_acc.append(temp_acc_train)
  temp_acc_test = sess.run(accuracy, feed_dict={x_data: x_vals_test, y_target: np.transpose([y_vals_test])})
  test_acc.append(temp_acc_test)
  if (i+1)%300==0:
    print('Loss = ' + str(temp_loss))


###
# Display model performance
###

# Plot loss and accuracy
plt.plot(loss_vec, 'k-')
plt.title('Cross Entropy Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Cross Entropy Loss')
plt.show()

# Plot train and test accuracy
plt.plot(train_acc, 'k-', label='Train Set Accuracy')
plt.plot(test_acc, 'r--', label='Test Set Accuracy')
plt.title('Train and Test Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

Output:

Loss = 0.845124
Loss = 0.658061
Loss = 0.471852
Loss = 0.643469
Loss = 0.672077
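Each printed value is the loss on a single random batch of 25 records (`temp_loss` is evaluated on `rand_x`), so it fluctuates from batch to batch; the rise from 0.471852 to 0.672077 does not necessarily mean training diverged. One way to see the underlying trend is to smooth `loss_vec` with a moving average before plotting. The sketch below is an addition, not part of the original script; the helper name `moving_average`, the 50-step window, and the synthetic loss curve are all assumptions for illustration.

```python
import numpy as np

def moving_average(values, window=50):
    # Mean over a sliding window; returns len(values) - window + 1 points
    values = np.asarray(values, dtype=float)
    if len(values) < window:
        raise ValueError("need at least `window` values")
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode='valid')

# Synthetic stand-in for loss_vec: a decaying trend plus batch-to-batch noise
rng = np.random.default_rng(0)
steps = np.arange(1500)
fake_loss = 0.5 + 0.5 * np.exp(-steps / 300) + rng.normal(scale=0.1, size=1500)
smoothed = moving_average(fake_loss, window=50)
print(len(smoothed))
```

In the script itself, this would amount to plotting `moving_average(loss_vec)` instead of (or alongside) `loss_vec`.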


Cross-entropy loss over 1,500 iterations
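`GradientDescentOptimizer(0.01).minimize(loss)` differentiates the loss automatically. To show the update it performs, the following NumPy sketch hand-codes the same mini-batch gradient descent: for the mean sigmoid cross-entropy, the gradient with respect to the logits is (sigmoid(z) - y)/n, so the gradient for A is X^T (sigmoid(z) - y)/n and the gradient for b is the sum of (sigmoid(z) - y)/n. This is an illustration on synthetic data, not the birth weight dataset; the helper `train_logreg`, the `true_A` weights, the noise level, and the learning rate of 0.5 are arbitrary assumptions.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def mean_xent(logits, y):
    # Stable sigmoid cross-entropy (same formula as the TF op), averaged
    return np.mean(np.maximum(logits, 0) - logits * y + np.log1p(np.exp(-np.abs(logits))))

def train_logreg(x, y, batch_size=25, lr=0.01, steps=1500, seed=99):
    rng = np.random.default_rng(seed)
    A = rng.normal(size=(x.shape[1], 1))  # random normal init, like tf.random_normal
    b = rng.normal(size=(1, 1))
    losses = []
    for _ in range(steps):
        # Sample a mini-batch with replacement, as np.random.choice does in the script
        idx = rng.choice(len(x), size=batch_size)
        xb, yb = x[idx], y[idx].reshape(-1, 1)
        # Gradient of the mean cross-entropy w.r.t. the logits: (sigmoid(z) - y) / n
        err = (sigmoid(xb @ A + b) - yb) / batch_size
        A -= lr * (xb.T @ err)
        b -= lr * err.sum(keepdims=True)
        # Record the batch loss after the update, as the script does
        losses.append(mean_xent(xb @ A + b, yb))
    return A, b, losses

# Synthetic data: 200 records, 7 features already scaled to [0, 1]
rng = np.random.default_rng(0)
x_demo = rng.uniform(size=(200, 7))
true_A = np.array([2.0, -3.0, 1.0, 0.0, 0.0, 1.5, -1.0])
y_demo = (x_demo @ true_A + rng.normal(scale=0.5, size=200) > 0.25).astype(float)

A_hat, b_hat, losses = train_logreg(x_demo, y_demo, lr=0.5)
acc = np.mean(np.round(sigmoid(x_demo @ A_hat + b_hat)).ravel() == y_demo)
print(np.mean(losses[:50]), np.mean(losses[-50:]), acc)
```

A larger learning rate than the script's 0.01 is used here only so the toy example converges quickly; on the real data the original value may be the better starting point.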


Training and test set accuracy over 1,500 iterations
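The stated goal is the probability of low birth weight, while the accuracy curves only score rounded 0/1 predictions. After training, the learned parameters can be pulled out of the session, for example with `A_val, b_val = sess.run([A, b])`, and a new raw record must be scaled with the same min/max statistics used for the training data before computing sigmoid(xA + b). The NumPy sketch below illustrates this; the function name `predict_proba` and every number in it are hypothetical, and it uses two features instead of seven for brevity.

```python
import numpy as np

def predict_proba(x_raw, A_val, b_val, col_min, col_max):
    # Apply the training set's min-max scaling, then sigmoid(xA + b)
    x_scaled = (np.asarray(x_raw, dtype=float) - col_min) / (col_max - col_min)
    logits = x_scaled @ A_val + b_val
    return 1.0 / (1.0 + np.exp(-logits))

# Hypothetical trained parameters and scaling statistics (not real results)
A_val = np.array([[0.5], [-1.0]])
b_val = np.array([[0.0]])
col_min = np.array([10.0, 0.0])
col_max = np.array([30.0, 1.0])

x_new = np.array([[20.0, 0.5]])   # one new raw record
p = predict_proba(x_new, A_val, b_val, col_min, col_max)
print(p)  # probability of class 1 (low birth weight)
```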
