用TensorFlow实现lasso回归和岭回归算法的示例


Posted in Python onMay 02, 2018

也有些正则方法可以限制回归算法输出结果中系数的影响,其中最常用的两种正则方法是lasso回归和岭回归。

lasso回归和岭回归算法跟常规线性回归算法极其相似,有一点不同的是,在公式中增加正则项来限制斜率(或者净斜率)。这样做的主要原因是限制特征对因变量的影响,通过增加一个依赖斜率A的损失函数实现。

对于lasso回归算法,在损失函数上增加一项:斜率A的某个给定倍数。我们使用TensorFlow的逻辑操作,但没有这些操作相关的梯度,而是使用阶跃函数的连续估计,也称作连续阶跃函数,其会在截止点跳跃扩大。一会就可以看到如何使用lasso回归算法。

对于岭回归算法,增加一个L2范数,即斜率系数的L2正则。

# LASSO and Ridge Regression
# lasso回归和岭回归
# 
# This function shows how to use TensorFlow to solve LASSO or 
# Ridge regression for 
# y = Ax + b
# 
# We will use the iris data, specifically: 
#  y = Sepal Length 
#  x = Petal Width

# import required libraries
import matplotlib.pyplot as plt
import sys
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops


# Specify 'Ridge' or 'LASSO'
regression_type = 'LASSO'

# clear out old graph
ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Load iris data
###

# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

###
# Model Parameters
###

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# make results reproducible
seed = 13
np.random.seed(seed)
tf.set_random_seed(seed)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

###
# Loss Functions
###

# Select appropriate loss function based on regression type

if regression_type == 'LASSO':
  # Declare Lasso loss function
  # 增加损失函数,其为改良过的连续阶跃函数,lasso回归的截止点设为0.9。
  # 这意味着限制斜率系数不超过0.9
  # Lasso Loss = L2_Loss + heavyside_step,
  # Where heavyside_step ~ 0 if A < constant, otherwise ~ 99
  lasso_param = tf.constant(0.9)
  heavyside_step = tf.truediv(1., tf.add(1., tf.exp(tf.multiply(-50., tf.subtract(A, lasso_param)))))
  regularization_param = tf.multiply(heavyside_step, 99.)
  loss = tf.add(tf.reduce_mean(tf.square(y_target - model_output)), regularization_param)

elif regression_type == 'Ridge':
  # Declare the Ridge loss function
  # Ridge loss = L2_loss + L2 norm of slope
  ridge_param = tf.constant(1.)
  ridge_loss = tf.reduce_mean(tf.square(A))
  loss = tf.expand_dims(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), tf.multiply(ridge_param, ridge_loss)), 0)

else:
  print('Invalid regression_type parameter value',file=sys.stderr)


###
# Optimizer
###

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.001)
train_step = my_opt.minimize(loss)

###
# Run regression
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss[0])
  if (i+1)%300==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))
    print('\n')

###
# Extract regression results
###

# Get the optimal coefficients
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# Get best fit line
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)


###
# Plot results
###

# Plot regression line against data points
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title(regression_type + ' Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出结果:

Step #300 A = [[ 0.77170753]] b = [[ 1.82499862]]
Loss = [[ 10.26473045]]
Step #600 A = [[ 0.75908542]] b = [[ 3.2220633]]
Loss = [[ 3.06292033]]
Step #900 A = [[ 0.74843585]] b = [[ 3.9975822]]
Loss = [[ 1.23220456]]
Step #1200 A = [[ 0.73752165]] b = [[ 4.42974091]]
Loss = [[ 0.57872057]]
Step #1500 A = [[ 0.72942668]] b = [[ 4.67253113]]
Loss = [[ 0.40874988]]

用TensorFlow实现lasso回归和岭回归算法的示例 

用TensorFlow实现lasso回归和岭回归算法的示例

通过在标准线性回归估计的基础上,增加一个连续的阶跃函数,实现lasso回归算法。由于阶跃函数的坡度,我们需要注意步长,因为太大的步长会导致最终不收敛。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python海龟绘图实例教程
Jul 24 Python
python实现的jpg格式图片修复代码
Apr 21 Python
简单介绍使用Python解析并修改XML文档的方法
Oct 15 Python
Python工程师面试题 与Python基础语法相关
Jan 14 Python
Python中矩阵库Numpy基本操作详解
Nov 21 Python
python中abs&amp;map&amp;reduce简介
Feb 20 Python
浅谈pandas中DataFrame关于显示值省略的解决方法
Apr 08 Python
使用pandas对两个dataframe进行join的实例
Jun 08 Python
python远程连接服务器MySQL数据库
Jul 02 Python
Python检查图片是否损坏及图片类型是否正确过程详解
Sep 30 Python
计算Python Numpy向量之间的欧氏距离实例
May 22 Python
Python GUI库Tkiner使用方法代码示例
Nov 27 Python
Python实现确认字符串是否包含指定字符串的实例
May 02 #Python
详解用TensorFlow实现逻辑回归算法
May 02 #Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
You might like
PHP与javascript实现变量交互的示例代码
2013/07/23 PHP
php实现把数组按指定的个数分隔
2014/02/17 PHP
thinkphp3.x中display方法及show方法的用法实例
2016/05/19 PHP
RSA实现JS前端加密与PHP后端解密功能示例
2019/08/05 PHP
JavaScript与C# Windows应用程序交互方法
2007/06/29 Javascript
javascript FormatNumber函数实现方法
2008/12/30 Javascript
vue-router2.0 组件之间传参及获取动态参数的方法
2017/11/10 Javascript
arcgis for js栅格图层叠加(Raster Layer)问题
2017/11/22 Javascript
Vue2.0系列之过滤器的使用
2018/03/01 Javascript
Bootstrap table中toolbar新增条件查询及refresh参数使用方法
2018/05/18 Javascript
webpack+vue-cil中proxyTable处理跨域的方法
2018/07/20 Javascript
jQuery实现input输入框获取焦点与失去焦点时提示的消失与显示功能示例
2019/05/27 jQuery
[01:02:04]EG vs Liquid 2019国际邀请赛淘汰赛 败者组 BO3 第一场 8.23
2019/09/05 DOTA
详解Python中用于计算指数的exp()方法
2015/05/14 Python
Python自动调用IE打开某个网站的方法
2015/06/03 Python
Python使用urllib2模块实现断点续传下载的方法
2015/06/17 Python
python自动zip压缩目录的方法
2015/06/28 Python
python的依赖管理的实现
2019/05/14 Python
PYQT5实现控制台显示功能的方法
2019/06/25 Python
python中如何实现将数据分成训练集与测试集的方法
2019/09/13 Python
python实现输入任意一个大写字母生成金字塔的示例
2019/10/27 Python
pygame库实现移动底座弹球小游戏
2020/04/14 Python
python如何实时获取tcpdump输出
2020/09/16 Python
Django+Django-Celery+Celery的整合实战
2021/01/20 Python
英国玛莎百货美国官网:Marks & Spencer美国
2018/11/06 全球购物
Onzie官网:美国时尚瑜伽品牌
2019/08/21 全球购物
个人近期表现材料
2014/02/11 职场文书
交通事故私了协议书
2014/04/16 职场文书
中学生纪念九一八事变演讲稿
2014/09/14 职场文书
信息合作协议书
2014/10/09 职场文书
普宁寺导游词
2015/02/04 职场文书
党支部创先争优公开承诺书
2015/04/30 职场文书
邓小平文选读书笔记
2015/06/29 职场文书
小学一年级班主任工作经验交流材料
2015/11/02 职场文书
妇联2016年六一国际儿童节活动总结
2016/04/06 职场文书
手把手教你导入Go语言第三方库
2021/08/04 Golang