用TensorFlow实现lasso回归和岭回归算法的示例


Posted in Python onMay 02, 2018

也有些正则方法可以限制回归算法输出结果中系数的影响,其中最常用的两种正则方法是lasso回归和岭回归。

lasso回归和岭回归算法跟常规线性回归算法极其相似,有一点不同的是,在公式中增加正则项来限制斜率(或者净斜率)。这样做的主要原因是限制特征对因变量的影响,通过增加一个依赖斜率A的损失函数实现。

对于lasso回归算法,在损失函数上增加一项:斜率A的某个给定倍数。我们使用TensorFlow的逻辑操作,但没有这些操作相关的梯度,而是使用阶跃函数的连续估计,也称作连续阶跃函数,其会在截止点跳跃扩大。一会就可以看到如何使用lasso回归算法。

对于岭回归算法,增加一个L2范数,即斜率系数的L2正则。

# LASSO and Ridge Regression
# lasso回归和岭回归
# 
# This function shows how to use TensorFlow to solve LASSO or 
# Ridge regression for 
# y = Ax + b
# 
# We will use the iris data, specifically: 
#  y = Sepal Length 
#  x = Petal Width

# import required libraries
import matplotlib.pyplot as plt
import sys
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops


# Specify 'Ridge' or 'LASSO'
regression_type = 'LASSO'

# clear out old graph
ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Load iris data
###

# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

###
# Model Parameters
###

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# make results reproducible
seed = 13
np.random.seed(seed)
tf.set_random_seed(seed)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

###
# Loss Functions
###

# Select appropriate loss function based on regression type

if regression_type == 'LASSO':
  # Declare Lasso loss function
  # 增加损失函数,其为改良过的连续阶跃函数,lasso回归的截止点设为0.9。
  # 这意味着限制斜率系数不超过0.9
  # Lasso Loss = L2_Loss + heavyside_step,
  # Where heavyside_step ~ 0 if A < constant, otherwise ~ 99
  lasso_param = tf.constant(0.9)
  heavyside_step = tf.truediv(1., tf.add(1., tf.exp(tf.multiply(-50., tf.subtract(A, lasso_param)))))
  regularization_param = tf.multiply(heavyside_step, 99.)
  loss = tf.add(tf.reduce_mean(tf.square(y_target - model_output)), regularization_param)

elif regression_type == 'Ridge':
  # Declare the Ridge loss function
  # Ridge loss = L2_loss + L2 norm of slope
  ridge_param = tf.constant(1.)
  ridge_loss = tf.reduce_mean(tf.square(A))
  loss = tf.expand_dims(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), tf.multiply(ridge_param, ridge_loss)), 0)

else:
  print('Invalid regression_type parameter value',file=sys.stderr)


###
# Optimizer
###

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.001)
train_step = my_opt.minimize(loss)

###
# Run regression
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss[0])
  if (i+1)%300==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))
    print('\n')

###
# Extract regression results
###

# Get the optimal coefficients
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# Get best fit line
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)


###
# Plot results
###

# Plot regression line against data points
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title(regression_type + ' Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出结果:

Step #300 A = [[ 0.77170753]] b = [[ 1.82499862]]
Loss = [[ 10.26473045]]
Step #600 A = [[ 0.75908542]] b = [[ 3.2220633]]
Loss = [[ 3.06292033]]
Step #900 A = [[ 0.74843585]] b = [[ 3.9975822]]
Loss = [[ 1.23220456]]
Step #1200 A = [[ 0.73752165]] b = [[ 4.42974091]]
Loss = [[ 0.57872057]]
Step #1500 A = [[ 0.72942668]] b = [[ 4.67253113]]
Loss = [[ 0.40874988]]

用TensorFlow实现lasso回归和岭回归算法的示例 

用TensorFlow实现lasso回归和岭回归算法的示例

通过在标准线性回归估计的基础上,增加一个连续的阶跃函数,实现lasso回归算法。由于阶跃函数的坡度,我们需要注意步长,因为太大的步长会导致最终不收敛。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 实现删除文件或文件夹实例详解
Dec 04 Python
python 实现自动远程登陆scp文件实例代码
Mar 13 Python
Python如何发布程序的详细教程
Oct 09 Python
解读python如何实现决策树算法
Oct 11 Python
基于python历史天气采集的分析
Feb 14 Python
Python语言检测模块langid和langdetect的使用实例
Feb 19 Python
Python Gitlab Api 使用方法
Aug 28 Python
PyTorch预训练的实现
Sep 18 Python
最新2019Pycharm安装教程 亲测
Feb 28 Python
Python通过Schema实现数据验证方式
Nov 12 Python
python读取excel数据并且画图的实现示例
Feb 08 Python
如何在向量化NumPy数组上进行移动窗口
May 18 Python
Python实现确认字符串是否包含指定字符串的实例
May 02 #Python
详解用TensorFlow实现逻辑回归算法
May 02 #Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
You might like
PHP文本操作类
2006/11/25 PHP
discuz Passport 通行证 整合笔记
2008/06/30 PHP
使用PHP实现阻止用户上传成人照片或者裸照
2014/12/25 PHP
在JavaScript中实现命名空间
2006/11/23 Javascript
JavaScript开发规范要求(规范化代码)
2010/08/16 Javascript
js中复制行和删除行的操作实例
2013/06/25 Javascript
JS文本框追加多个下拉框的值的简单实例
2013/07/12 Javascript
Javascript字符串对象的常用方法简明版
2014/06/26 Javascript
js实现文字垂直滚动和鼠标悬停效果
2015/12/31 Javascript
AngularJS入门教程之更多模板详解
2016/08/19 Javascript
Javascript 实现微信分享(QQ、朋友圈、分享给朋友)
2016/10/21 Javascript
jQuery插件zTree实现更新根节点中第i个节点名称的方法示例
2017/03/08 Javascript
微信小程序 动态绑定数据及动态事件处理
2017/03/14 Javascript
深入理解JavaScript创建对象的多种方式以及优缺点
2017/06/01 Javascript
详解如何写出一个利于扩展的vue路由配置
2019/05/16 Javascript
[04:45]DOTA2-DPC中国联赛正赛 iG vs LBZS 赛后选手采访
2021/03/11 DOTA
python ip正则式
2009/05/07 Python
零基础写python爬虫之抓取糗事百科代码分享
2014/11/06 Python
使用Python对MySQL数据操作
2017/04/06 Python
Python使用Matplotlib模块时坐标轴标题中文及各种特殊符号显示方法
2018/05/04 Python
Python多线程同步---文件读写控制方法
2019/02/12 Python
Python 转换RGB颜色值的示例代码
2019/10/13 Python
Python编程快速上手——PDF文件操作案例分析
2020/02/28 Python
Python 输出详细的异常信息(traceback)方式
2020/04/08 Python
使用keras根据层名称来初始化网络
2020/05/21 Python
意大利团购网站:Groupon意大利
2016/10/11 全球购物
美国领先的奢侈美容零售商:Bluemercury
2017/07/26 全球购物
连卡佛中国官网:Lane Crawford中文站
2018/01/27 全球购物
联想台湾官网:Lenovo TW
2018/05/09 全球购物
某公司部分笔试题
2013/11/05 面试题
领导班子三严三实对照检查材料
2014/09/25 职场文书
安全员岗位职责
2015/02/10 职场文书
大学生简历自我评价2015
2015/03/03 职场文书
药店收银员岗位职责
2015/04/07 职场文书
航班延误投诉信
2015/07/02 职场文书
golang 在windows中设置环境变量的操作
2021/04/29 Golang