用TensorFlow实现lasso回归和岭回归算法的示例


Posted in Python onMay 02, 2018

也有些正则方法可以限制回归算法输出结果中系数的影响,其中最常用的两种正则方法是lasso回归和岭回归。

lasso回归和岭回归算法跟常规线性回归算法极其相似,有一点不同的是,在公式中增加正则项来限制斜率(或者净斜率)。这样做的主要原因是限制特征对因变量的影响,通过增加一个依赖斜率A的损失函数实现。

对于lasso回归算法,在损失函数上增加一项:斜率A的某个给定倍数。我们使用TensorFlow的逻辑操作,但没有这些操作相关的梯度,而是使用阶跃函数的连续估计,也称作连续阶跃函数,其会在截止点跳跃扩大。一会就可以看到如何使用lasso回归算法。

对于岭回归算法,增加一个L2范数,即斜率系数的L2正则。

# LASSO and Ridge Regression
# lasso回归和岭回归
# 
# This function shows how to use TensorFlow to solve LASSO or 
# Ridge regression for 
# y = Ax + b
# 
# We will use the iris data, specifically: 
#  y = Sepal Length 
#  x = Petal Width

# import required libraries
import matplotlib.pyplot as plt
import sys
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops


# Specify 'Ridge' or 'LASSO'
regression_type = 'LASSO'

# clear out old graph
ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Load iris data
###

# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

###
# Model Parameters
###

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# make results reproducible
seed = 13
np.random.seed(seed)
tf.set_random_seed(seed)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

###
# Loss Functions
###

# Select appropriate loss function based on regression type

if regression_type == 'LASSO':
  # Declare Lasso loss function
  # 增加损失函数,其为改良过的连续阶跃函数,lasso回归的截止点设为0.9。
  # 这意味着限制斜率系数不超过0.9
  # Lasso Loss = L2_Loss + heavyside_step,
  # Where heavyside_step ~ 0 if A < constant, otherwise ~ 99
  lasso_param = tf.constant(0.9)
  heavyside_step = tf.truediv(1., tf.add(1., tf.exp(tf.multiply(-50., tf.subtract(A, lasso_param)))))
  regularization_param = tf.multiply(heavyside_step, 99.)
  loss = tf.add(tf.reduce_mean(tf.square(y_target - model_output)), regularization_param)

elif regression_type == 'Ridge':
  # Declare the Ridge loss function
  # Ridge loss = L2_loss + L2 norm of slope
  ridge_param = tf.constant(1.)
  ridge_loss = tf.reduce_mean(tf.square(A))
  loss = tf.expand_dims(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), tf.multiply(ridge_param, ridge_loss)), 0)

else:
  print('Invalid regression_type parameter value',file=sys.stderr)


###
# Optimizer
###

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.001)
train_step = my_opt.minimize(loss)

###
# Run regression
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss[0])
  if (i+1)%300==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))
    print('\n')

###
# Extract regression results
###

# Get the optimal coefficients
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# Get best fit line
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)


###
# Plot results
###

# Plot regression line against data points
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title(regression_type + ' Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出结果:

Step #300 A = [[ 0.77170753]] b = [[ 1.82499862]]
Loss = [[ 10.26473045]]
Step #600 A = [[ 0.75908542]] b = [[ 3.2220633]]
Loss = [[ 3.06292033]]
Step #900 A = [[ 0.74843585]] b = [[ 3.9975822]]
Loss = [[ 1.23220456]]
Step #1200 A = [[ 0.73752165]] b = [[ 4.42974091]]
Loss = [[ 0.57872057]]
Step #1500 A = [[ 0.72942668]] b = [[ 4.67253113]]
Loss = [[ 0.40874988]]

用TensorFlow实现lasso回归和岭回归算法的示例 

用TensorFlow实现lasso回归和岭回归算法的示例

通过在标准线性回归估计的基础上,增加一个连续的阶跃函数,实现lasso回归算法。由于阶跃函数的坡度,我们需要注意步长,因为太大的步长会导致最终不收敛。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 字符串中的字符倒转
Sep 06 Python
Python获取linux主机ip的简单实现方法
Apr 18 Python
利用Python实现图书超期提醒
Aug 02 Python
python编辑用户登入界面的实现代码
Jul 16 Python
在Python dataframe中出生日期转化为年龄的实现方法
Oct 20 Python
Python魔法方法详解
Feb 13 Python
python如何获取列表中每个元素的下标位置
Jul 01 Python
Python Pandas数据结构简单介绍
Jul 03 Python
python调用并链接MATLAB脚本详解
Jul 05 Python
Python3+Appium实现多台移动设备操作的方法
Jul 05 Python
浅谈pandas.cut与pandas.qcut的使用方法及区别
Mar 03 Python
10个python爬虫入门实例(小结)
Nov 01 Python
Python实现确认字符串是否包含指定字符串的实例
May 02 #Python
详解用TensorFlow实现逻辑回归算法
May 02 #Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
You might like
用PHP进行MySQL删除记录操作代码
2008/06/07 PHP
PHP实现二叉树的深度优先与广度优先遍历方法
2015/09/28 PHP
PHP 设计模式系列之 specification规格模式
2016/01/10 PHP
Thinkphp 框架扩展之Widget扩展实现方法分析
2020/04/23 PHP
Laravel服务容器绑定的几种方法总结
2020/06/14 PHP
Javascript 网页水印(非图片水印)实现代码
2010/03/01 Javascript
JavaScript中常用的运算符小结
2012/01/18 Javascript
js自动查找select下拉的菜单并选择(示例代码)
2014/02/26 Javascript
javascript验证邮件地址和MX记录的方法
2015/06/16 Javascript
Bootstrap Metronic完全响应式管理模板之菜单栏学习笔记
2016/07/08 Javascript
Bootstrap基本组件学习笔记之列表组(11)
2016/12/07 Javascript
详解nodejs模板引擎制作
2017/06/14 NodeJs
shiro授权的实现原理
2017/09/21 Javascript
JavaScript实现封闭区域布尔运算的示例代码
2018/06/25 Javascript
vue 使用v-for进行循环的实例代码详解
2020/02/19 Javascript
Javascript实现html转pdf高清版(提高分辨率)
2020/02/19 Javascript
详解React的回调渲染模式
2020/09/10 Javascript
Python实现的检测网站挂马程序
2014/11/30 Python
python数据清洗系列之字符串处理详解
2017/02/12 Python
Django框架会话技术实例分析【Cookie与Session】
2019/05/24 Python
Python程序打包工具py2exe和PyInstaller详解
2019/06/28 Python
python读取.mat文件的数据及实例代码
2019/07/12 Python
浅谈Python3 numpy.ptp()最大值与最小值的差
2019/08/24 Python
对django 2.x版本中models.ForeignKey()外键说明介绍
2020/03/30 Python
如何在scrapy中集成selenium爬取网页的方法
2020/11/18 Python
canvas粒子动画背景的实现示例
2018/09/03 HTML / CSS
欧舒丹美国官网:L’Occitane美国
2018/02/23 全球购物
The Kooples美国官方网站:为情侣提供的法国当代时尚品牌
2019/01/03 全球购物
骨干教师培训制度
2014/01/13 职场文书
个人贷款担保书
2014/04/01 职场文书
小学评语大全
2014/04/22 职场文书
出纳试用期工作总结2015
2015/05/28 职场文书
实习单位意见
2015/06/04 职场文书
励志语录:你若不勇敢,谁替你坚强
2019/11/08 职场文书
Python实现Excel文件的合并(以新冠疫情数据为例)
2022/03/20 Python
不负正版帝国之名 《重返帝国》引领SLG手游制作新的标杆
2022/04/07 其他游戏