用TensorFlow实现lasso回归和岭回归算法的示例


Posted in Python onMay 02, 2018

也有些正则方法可以限制回归算法输出结果中系数的影响,其中最常用的两种正则方法是lasso回归和岭回归。

lasso回归和岭回归算法跟常规线性回归算法极其相似,有一点不同的是,在公式中增加正则项来限制斜率(或者净斜率)。这样做的主要原因是限制特征对因变量的影响,通过增加一个依赖斜率A的损失函数实现。

对于lasso回归算法,在损失函数上增加一项:斜率A的某个给定倍数。我们使用TensorFlow的逻辑操作,但没有这些操作相关的梯度,而是使用阶跃函数的连续估计,也称作连续阶跃函数,其会在截止点跳跃扩大。一会就可以看到如何使用lasso回归算法。

对于岭回归算法,增加一个L2范数,即斜率系数的L2正则。

# LASSO and Ridge Regression
# lasso回归和岭回归
# 
# This function shows how to use TensorFlow to solve LASSO or 
# Ridge regression for 
# y = Ax + b
# 
# We will use the iris data, specifically: 
#  y = Sepal Length 
#  x = Petal Width

# import required libraries
import matplotlib.pyplot as plt
import sys
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops


# Specify 'Ridge' or 'LASSO'
regression_type = 'LASSO'

# clear out old graph
ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Load iris data
###

# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

###
# Model Parameters
###

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# make results reproducible
seed = 13
np.random.seed(seed)
tf.set_random_seed(seed)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

###
# Loss Functions
###

# Select appropriate loss function based on regression type

if regression_type == 'LASSO':
  # Declare Lasso loss function
  # 增加损失函数,其为改良过的连续阶跃函数,lasso回归的截止点设为0.9。
  # 这意味着限制斜率系数不超过0.9
  # Lasso Loss = L2_Loss + heavyside_step,
  # Where heavyside_step ~ 0 if A < constant, otherwise ~ 99
  lasso_param = tf.constant(0.9)
  heavyside_step = tf.truediv(1., tf.add(1., tf.exp(tf.multiply(-50., tf.subtract(A, lasso_param)))))
  regularization_param = tf.multiply(heavyside_step, 99.)
  loss = tf.add(tf.reduce_mean(tf.square(y_target - model_output)), regularization_param)

elif regression_type == 'Ridge':
  # Declare the Ridge loss function
  # Ridge loss = L2_loss + L2 norm of slope
  ridge_param = tf.constant(1.)
  ridge_loss = tf.reduce_mean(tf.square(A))
  loss = tf.expand_dims(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), tf.multiply(ridge_param, ridge_loss)), 0)

else:
  print('Invalid regression_type parameter value',file=sys.stderr)


###
# Optimizer
###

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.001)
train_step = my_opt.minimize(loss)

###
# Run regression
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss[0])
  if (i+1)%300==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))
    print('\n')

###
# Extract regression results
###

# Get the optimal coefficients
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# Get best fit line
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)


###
# Plot results
###

# Plot regression line against data points
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title(regression_type + ' Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出结果:

Step #300 A = [[ 0.77170753]] b = [[ 1.82499862]]
Loss = [[ 10.26473045]]
Step #600 A = [[ 0.75908542]] b = [[ 3.2220633]]
Loss = [[ 3.06292033]]
Step #900 A = [[ 0.74843585]] b = [[ 3.9975822]]
Loss = [[ 1.23220456]]
Step #1200 A = [[ 0.73752165]] b = [[ 4.42974091]]
Loss = [[ 0.57872057]]
Step #1500 A = [[ 0.72942668]] b = [[ 4.67253113]]
Loss = [[ 0.40874988]]

用TensorFlow实现lasso回归和岭回归算法的示例 

用TensorFlow实现lasso回归和岭回归算法的示例

通过在标准线性回归估计的基础上,增加一个连续的阶跃函数,实现lasso回归算法。由于阶跃函数的坡度,我们需要注意步长,因为太大的步长会导致最终不收敛。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
django model去掉unique_together报错的解决方案
Oct 18 Python
python模块smtplib实现纯文本邮件发送功能
May 22 Python
numpy.linspace 生成等差数组的方法
Jul 02 Python
python PIL和CV对 图片的读取,显示,裁剪,保存实现方法
Aug 07 Python
Python Django 简单分页的实现代码解析
Aug 21 Python
pyqt5、qtdesigner安装和环境设置教程
Sep 25 Python
用Python解数独的方法示例
Oct 24 Python
python中有函数重载吗
May 28 Python
python各种excel写入方式的速度对比
Nov 10 Python
python中sys模块的介绍与实例
Apr 17 Python
Python办公自动化之教你如何用Python将任意文件转为PDF格式
Jun 28 Python
python的变量和简单数字类型详解
Sep 15 Python
Python实现确认字符串是否包含指定字符串的实例
May 02 #Python
详解用TensorFlow实现逻辑回归算法
May 02 #Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
You might like
php实现的zip文件内容比较类
2014/09/24 PHP
laravel框架路由分组,中间件,命名空间,子域名,路由前缀实例分析
2020/02/18 PHP
用javascript获取当页面上鼠标光标位置和触发事件的对象的代码
2009/12/09 Javascript
基于jquery的划词搜索实现(备忘)
2010/09/14 Javascript
Jquery选择器中使用变量实现动态选择例子
2014/07/25 Javascript
jquery单选框radio绑定click事件实现方法
2015/01/14 Javascript
JavaScript实现基于十进制的四舍五入实例
2015/07/17 Javascript
jquery基础知识第一讲之认识jquery
2016/03/17 Javascript
关于cookie的初识和运用(js和jq)
2016/04/07 Javascript
javascript面向对象程序设计高级特性经典教程(值得收藏)
2016/05/19 Javascript
利用fecha进行JS日期处理
2016/11/21 Javascript
Express之get,pos请求参数的获取
2017/05/02 Javascript
微信小程序实现皮肤功能(夜间模式)
2017/06/18 Javascript
详谈js模块化规范
2017/07/07 Javascript
微信小程序的生命周期的详解
2017/10/19 Javascript
Vue.directive 自定义指令的问题小结
2018/03/04 Javascript
手写Node静态资源服务器的实现方法
2018/03/20 Javascript
js中位运算的运用实例分析
2018/12/11 Javascript
使用Karma做vue组件单元测试的实现
2020/01/16 Javascript
原生js实现html手机端城市列表索引选择城市
2020/06/24 Javascript
Python变量和数据类型详解
2017/02/15 Python
Python cookbook(数据结构与算法)从序列中移除重复项且保持元素间顺序不变的方法
2018/03/13 Python
python和shell获取文本内容的方法
2018/06/05 Python
python3+PyQt5 自定义窗口部件--使用窗口部件样式表的方法
2019/06/26 Python
Python Des加密解密如何实现软件注册码机器码
2020/01/08 Python
在Ubuntu 20.04中安装Pycharm 2020.1的图文教程
2020/04/30 Python
Django:使用filter的pk进行多值查询操作
2020/07/15 Python
印度手工编织服装和家居用品商店:Fabindi
2019/10/07 全球购物
财务部经理岗位职责
2014/02/03 职场文书
美术指导求职信
2014/03/17 职场文书
冬季安全检查方案
2014/05/23 职场文书
人力资源管理求职信
2014/08/07 职场文书
涪陵白鹤梁导游词
2015/02/09 职场文书
安全员岗位职责范本
2015/04/11 职场文书
千与千寻观后感
2015/06/04 职场文书
如何用threejs实现实时多边形折射
2021/05/07 Javascript