详解用TensorFlow实现逻辑回归算法


Posted in Python onMay 02, 2018

本文将实现逻辑回归算法,预测低出生体重的概率。

# Logistic Regression
# 逻辑回归
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve logistic regression.
# y = sigmoid(Ax + b)
#
# We will use the low birth weight data, specifically:
# y = 0 or 1 = low birth weight
# x = demographic and medical history data

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import requests
from tensorflow.python.framework import ops
import os.path
import csv


ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Obtain and prepare data for modeling
###

# name of data file
birth_weight_file = 'birth_weight.csv'

# download data and create data file if file does not exist in current directory
if not os.path.exists(birth_weight_file):
  birthdata_url = 'https://github.com/nfmcclure/tensorflow_cookbook/raw/master/01_Introduction/07_Working_with_Data_Sources/birthweight_data/birthweight.dat'
  birth_file = requests.get(birthdata_url)
  birth_data = birth_file.text.split('\r\n')
  birth_header = birth_data[0].split('\t')
  birth_data = [[float(x) for x in y.split('\t') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]
  with open(birth_weight_file, "w") as f:
    writer = csv.writer(f)
    writer.writerow(birth_header)
    writer.writerows(birth_data)
    f.close()

# read birth weight data into memory
birth_data = []
with open(birth_weight_file, newline='') as csvfile:
   csv_reader = csv.reader(csvfile)
   birth_header = next(csv_reader)
   for row in csv_reader:
     birth_data.append(row)

birth_data = [[float(x) for x in row] for row in birth_data]

# Pull out target variable
y_vals = np.array([x[0] for x in birth_data])
# Pull out predictor variables (not id, not target, and not birthweight)
x_vals = np.array([x[1:8] for x in birth_data])

# set for reproducible results
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)

# Split data into train/test = 80%/20%
# 分割数据集为测试集和训练集
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

# Normalize by column (min-max norm)
# 将所有特征缩放到0和1区间(min-max缩放),逻辑回归收敛的效果更好
# 归一化特征
def normalize_cols(m):
  col_max = m.max(axis=0)
  col_min = m.min(axis=0)
  return (m-col_min) / (col_max - col_min)

x_vals_train = np.nan_to_num(normalize_cols(x_vals_train))
x_vals_test = np.nan_to_num(normalize_cols(x_vals_test))

###
# Define Tensorflow computational graph¶
###

# Declare batch size
batch_size = 25

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 7], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[7,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function (Cross Entropy loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=model_output, labels=y_target))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

###
# Train model
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Actual Prediction
# 除记录损失函数外,也需要记录分类器在训练集和测试集上的准确度。
# 所以创建一个返回准确度的预测函数
prediction = tf.round(tf.sigmoid(model_output))
predictions_correct = tf.cast(tf.equal(prediction, y_target), tf.float32)
accuracy = tf.reduce_mean(predictions_correct)

# Training loop
# 开始遍历迭代训练,记录损失值和准确度
loss_vec = []
train_acc = []
test_acc = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = x_vals_train[rand_index]
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  temp_acc_train = sess.run(accuracy, feed_dict={x_data: x_vals_train, y_target: np.transpose([y_vals_train])})
  train_acc.append(temp_acc_train)
  temp_acc_test = sess.run(accuracy, feed_dict={x_data: x_vals_test, y_target: np.transpose([y_vals_test])})
  test_acc.append(temp_acc_test)
  if (i+1)%300==0:
    print('Loss = ' + str(temp_loss))


###
# Display model performance
###

# 绘制损失和准确度
plt.plot(loss_vec, 'k-')
plt.title('Cross Entropy Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Cross Entropy Loss')
plt.show()

# Plot train and test accuracy
plt.plot(train_acc, 'k-', label='Train Set Accuracy')
plt.plot(test_acc, 'r--', label='Test Set Accuracy')
plt.title('Train and Test Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

数据结果:

Loss = 0.845124
Loss = 0.658061
Loss = 0.471852
Loss = 0.643469
Loss = 0.672077

详解用TensorFlow实现逻辑回归算法

迭代1500次的交叉熵损失图

详解用TensorFlow实现逻辑回归算法

迭代1500次的测试集和训练集的准确度图

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的一些类型转换函数小结
Feb 10 Python
采用python实现简单QQ单用户机器人的方法
Jul 03 Python
Django基于ORM操作数据库的方法详解
Mar 27 Python
解决Pycharm中import时无法识别自己写的程序方法
May 18 Python
浅谈Python脚本开头及导包注释自动添加方法
Oct 27 Python
使用CodeMirror实现Python3在线编辑器的示例代码
Jan 14 Python
python读出当前时间精度到秒的代码
Jul 05 Python
在Matplotlib图中插入LaTex公式实例
Apr 17 Python
python+selenium+chrome批量文件下载并自动创建文件夹实例
Apr 27 Python
django日志默认打印request请求信息的方法示例
May 17 Python
从python读取sql的实例方法
Jul 21 Python
python用Configobj模块读取配置文件
Sep 26 Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
Python3实现的简单验证码识别功能示例
May 02 #Python
利用Python在一个文件的头部插入数据的实例
May 02 #Python
You might like
非常实用的PHP常用函数汇总
2014/12/17 PHP
PHP单例模式与工厂模式详解
2017/08/29 PHP
发现的以前不知道的函数
2006/09/19 Javascript
div+css布局的图片连续滚动js实现代码
2010/05/04 Javascript
防止页面被iframe(兼容IE,Firefox火狐)
2010/07/04 Javascript
jquery中ajax学习笔记4
2011/10/16 Javascript
用html+css+js实现的一个简单的图片切换特效
2014/05/28 Javascript
jQuery实现二级下拉菜单效果
2016/01/05 Javascript
jQuery实现拖拽页面元素并将其保存到cookie的方法
2016/06/12 Javascript
关于Iframe父页面与子页面之间的相互调用
2016/11/22 Javascript
JS排序之选择排序详解
2017/04/08 Javascript
json字符串传到前台input的方法
2018/08/06 Javascript
vue-cli 3.x 配置Axios(proxyTable)跨域代理方法
2018/09/19 Javascript
vue实现条件叠加搜索的解决方法
2019/05/28 Javascript
详解vue 2.6 中 slot 的新用法
2019/07/09 Javascript
layui button 按钮弹出提示窗口,确定才进行的方法
2019/09/06 Javascript
使用layer弹窗,制作编辑User信息页面的方法
2019/09/27 Javascript
jQuery实现鼠标放置名字上显示详细内容气泡提示框效果的方法分析
2020/04/04 jQuery
解决ant design vue中树形控件defaultExpandAll设置无效的问题
2020/10/26 Javascript
[47:46]完美世界DOTA2联赛 Magma vs GXR 第三场 11.07
2020/11/10 DOTA
python从网络读取图片并直接进行处理的方法
2015/05/22 Python
Python实现购物车程序
2018/04/16 Python
Python企业编码生成系统之主程序模块设计详解
2019/07/26 Python
python实现人性化显示金额数字实例详解
2020/09/25 Python
Pretty You London官网:英国拖鞋和睡衣品牌
2019/05/08 全球购物
迪卡侬比利时官网:Decathlon比利时
2019/12/28 全球购物
The North Face官方旗舰店:美国著名户外品牌
2020/09/28 全球购物
中专毕业自我鉴定
2013/10/16 职场文书
计算机应用职专应届生求职信
2013/11/12 职场文书
寒假实习自荐信
2014/01/26 职场文书
料理师求职信
2014/01/30 职场文书
餐厅总厨求职信
2014/03/04 职场文书
2014年班长个人工作总结
2014/11/14 职场文书
2016新年慰问信范文
2015/03/25 职场文书
《秋天的怀念》教学反思
2016/02/17 职场文书
Python读取文件夹下的所有文件实例代码
2021/04/02 Python