用TensorFlow实现lasso回归和岭回归算法的示例


Posted in Python onMay 02, 2018

也有些正则方法可以限制回归算法输出结果中系数的影响,其中最常用的两种正则方法是lasso回归和岭回归。

lasso回归和岭回归算法跟常规线性回归算法极其相似,有一点不同的是,在公式中增加正则项来限制斜率(或者净斜率)。这样做的主要原因是限制特征对因变量的影响,通过增加一个依赖斜率A的损失函数实现。

对于lasso回归算法,在损失函数上增加一项:斜率A的某个给定倍数。我们使用TensorFlow的逻辑操作,但没有这些操作相关的梯度,而是使用阶跃函数的连续估计,也称作连续阶跃函数,其会在截止点跳跃扩大。一会就可以看到如何使用lasso回归算法。

对于岭回归算法,增加一个L2范数,即斜率系数的L2正则。

# LASSO and Ridge Regression
# lasso回归和岭回归
# 
# This function shows how to use TensorFlow to solve LASSO or 
# Ridge regression for 
# y = Ax + b
# 
# We will use the iris data, specifically: 
#  y = Sepal Length 
#  x = Petal Width

# import required libraries
import matplotlib.pyplot as plt
import sys
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops


# Specify 'Ridge' or 'LASSO'
regression_type = 'LASSO'

# clear out old graph
ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Load iris data
###

# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

###
# Model Parameters
###

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# make results reproducible
seed = 13
np.random.seed(seed)
tf.set_random_seed(seed)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

###
# Loss Functions
###

# Select appropriate loss function based on regression type

if regression_type == 'LASSO':
  # Declare Lasso loss function
  # 增加损失函数,其为改良过的连续阶跃函数,lasso回归的截止点设为0.9。
  # 这意味着限制斜率系数不超过0.9
  # Lasso Loss = L2_Loss + heavyside_step,
  # Where heavyside_step ~ 0 if A < constant, otherwise ~ 99
  lasso_param = tf.constant(0.9)
  heavyside_step = tf.truediv(1., tf.add(1., tf.exp(tf.multiply(-50., tf.subtract(A, lasso_param)))))
  regularization_param = tf.multiply(heavyside_step, 99.)
  loss = tf.add(tf.reduce_mean(tf.square(y_target - model_output)), regularization_param)

elif regression_type == 'Ridge':
  # Declare the Ridge loss function
  # Ridge loss = L2_loss + L2 norm of slope
  ridge_param = tf.constant(1.)
  ridge_loss = tf.reduce_mean(tf.square(A))
  loss = tf.expand_dims(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), tf.multiply(ridge_param, ridge_loss)), 0)

else:
  print('Invalid regression_type parameter value',file=sys.stderr)


###
# Optimizer
###

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.001)
train_step = my_opt.minimize(loss)

###
# Run regression
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss[0])
  if (i+1)%300==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))
    print('\n')

###
# Extract regression results
###

# Get the optimal coefficients
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# Get best fit line
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)


###
# Plot results
###

# Plot regression line against data points
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title(regression_type + ' Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出结果:

Step #300 A = [[ 0.77170753]] b = [[ 1.82499862]]
Loss = [[ 10.26473045]]
Step #600 A = [[ 0.75908542]] b = [[ 3.2220633]]
Loss = [[ 3.06292033]]
Step #900 A = [[ 0.74843585]] b = [[ 3.9975822]]
Loss = [[ 1.23220456]]
Step #1200 A = [[ 0.73752165]] b = [[ 4.42974091]]
Loss = [[ 0.57872057]]
Step #1500 A = [[ 0.72942668]] b = [[ 4.67253113]]
Loss = [[ 0.40874988]]

用TensorFlow实现lasso回归和岭回归算法的示例 

用TensorFlow实现lasso回归和岭回归算法的示例

通过在标准线性回归估计的基础上,增加一个连续的阶跃函数,实现lasso回归算法。由于阶跃函数的坡度,我们需要注意步长,因为太大的步长会导致最终不收敛。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python检测lvs real server状态
Jan 22 Python
python字符串与url编码的转换实例
May 10 Python
python使用Matplotlib画饼图
Sep 25 Python
python实现kmp算法的实例代码
Apr 03 Python
python3 xpath和requests应用详解
Mar 06 Python
django-利用session机制实现唯一登录的例子
Mar 16 Python
在jupyter notebook中调用.ipynb文件方式
Apr 14 Python
Numpy一维线性插值函数的用法
Apr 22 Python
python查找特定名称文件并按序号、文件名分行打印输出的方法
Apr 24 Python
flask开启多线程的具体方法
Aug 02 Python
Python3如何使用多线程升程序运行速度
Aug 11 Python
python实现移动木板小游戏
Oct 09 Python
Python实现确认字符串是否包含指定字符串的实例
May 02 #Python
详解用TensorFlow实现逻辑回归算法
May 02 #Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
You might like
Notice: Undefined index: page in E:\PHP\test.php on line 14
2010/11/02 PHP
Laravel 5 框架入门(三)
2015/04/09 PHP
PHP书写格式详解(必看)
2016/05/23 PHP
PHP convert_uudecode()函数讲解
2019/02/14 PHP
ajax 文件上传应用简单实现
2009/03/03 Javascript
无闪烁更新网页内容JS实现
2013/12/19 Javascript
在JavaScript的jQuery库中操作AJAX的方法讲解
2015/08/15 Javascript
jQuery抛物线运动实现方法(附完整demo源码下载)
2016/01/08 Javascript
详解webpack+angular2开发环境搭建
2017/06/28 Javascript
关于jquery form表单序列化的注意事项详解
2017/08/01 jQuery
深入浅析ES6 Class 中的 super 关键字
2017/10/20 Javascript
jQuery实现的鼠标响应缓冲动画效果示例
2018/02/13 jQuery
jQuery 改变P标签文本值方法
2018/02/24 jQuery
Vue中的组件及路由使用实例代码详解
2019/05/22 Javascript
vue实现商城秒杀倒计时功能
2019/12/12 Javascript
Vue实现 点击显示再点击隐藏效果(点击页面空白区域也隐藏效果)
2020/01/16 Javascript
使用Python脚本来获取Cisco设备信息的示例
2015/05/04 Python
Python使用面向对象方式创建线程实现12306售票系统
2015/12/24 Python
Python+Opencv识别两张相似图片
2020/03/23 Python
Python通过调用mysql存储过程实现更新数据功能示例
2018/04/03 Python
浅谈关于Python3中venv虚拟环境
2018/08/01 Python
Python文件读写常见用法总结
2019/02/22 Python
给Python学习者的文件读写指南(含基础与进阶)
2020/01/29 Python
Python实现疫情地图可视化
2021/02/05 Python
施华洛世奇美国官网:SWAROVSKI美国
2018/02/08 全球购物
Orvis官网:自1856年以来,优质服装、飞钓装备等
2018/12/17 全球购物
PHP高级工程师面试问题推荐
2013/01/18 面试题
大学生自荐书范文
2013/12/10 职场文书
护理专业毕业生自荐信范文
2014/01/05 职场文书
八一建军节部队活动方案
2014/02/04 职场文书
股权转让意向书
2014/04/01 职场文书
2015高考寄语集锦
2015/02/27 职场文书
集团财务总监岗位职责
2015/04/03 职场文书
监理中标通知书
2015/04/16 职场文书
2015年大学迎新工作总结
2015/07/16 职场文书
vue中使用mockjs配置和使用方式
2022/04/06 Vue.js