用TensorFlow实现lasso回归和岭回归算法的示例


Posted in Python onMay 02, 2018

也有些正则方法可以限制回归算法输出结果中系数的影响,其中最常用的两种正则方法是lasso回归和岭回归。

lasso回归和岭回归算法跟常规线性回归算法极其相似,有一点不同的是,在公式中增加正则项来限制斜率(或者净斜率)。这样做的主要原因是限制特征对因变量的影响,通过增加一个依赖斜率A的损失函数实现。

对于lasso回归算法,在损失函数上增加一项:斜率A的某个给定倍数。我们使用TensorFlow的逻辑操作,但没有这些操作相关的梯度,而是使用阶跃函数的连续估计,也称作连续阶跃函数,其会在截止点跳跃扩大。一会就可以看到如何使用lasso回归算法。

对于岭回归算法,增加一个L2范数,即斜率系数的L2正则。

# LASSO and Ridge Regression
# lasso回归和岭回归
# 
# This function shows how to use TensorFlow to solve LASSO or 
# Ridge regression for 
# y = Ax + b
# 
# We will use the iris data, specifically: 
#  y = Sepal Length 
#  x = Petal Width

# import required libraries
import matplotlib.pyplot as plt
import sys
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops


# Specify 'Ridge' or 'LASSO'
regression_type = 'LASSO'

# clear out old graph
ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Load iris data
###

# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

###
# Model Parameters
###

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# make results reproducible
seed = 13
np.random.seed(seed)
tf.set_random_seed(seed)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

###
# Loss Functions
###

# Select appropriate loss function based on regression type

if regression_type == 'LASSO':
  # Declare Lasso loss function
  # 增加损失函数,其为改良过的连续阶跃函数,lasso回归的截止点设为0.9。
  # 这意味着限制斜率系数不超过0.9
  # Lasso Loss = L2_Loss + heavyside_step,
  # Where heavyside_step ~ 0 if A < constant, otherwise ~ 99
  lasso_param = tf.constant(0.9)
  heavyside_step = tf.truediv(1., tf.add(1., tf.exp(tf.multiply(-50., tf.subtract(A, lasso_param)))))
  regularization_param = tf.multiply(heavyside_step, 99.)
  loss = tf.add(tf.reduce_mean(tf.square(y_target - model_output)), regularization_param)

elif regression_type == 'Ridge':
  # Declare the Ridge loss function
  # Ridge loss = L2_loss + L2 norm of slope
  ridge_param = tf.constant(1.)
  ridge_loss = tf.reduce_mean(tf.square(A))
  loss = tf.expand_dims(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), tf.multiply(ridge_param, ridge_loss)), 0)

else:
  print('Invalid regression_type parameter value',file=sys.stderr)


###
# Optimizer
###

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.001)
train_step = my_opt.minimize(loss)

###
# Run regression
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss[0])
  if (i+1)%300==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))
    print('\n')

###
# Extract regression results
###

# Get the optimal coefficients
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# Get best fit line
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)


###
# Plot results
###

# Plot regression line against data points
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title(regression_type + ' Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出结果:

Step #300 A = [[ 0.77170753]] b = [[ 1.82499862]]
Loss = [[ 10.26473045]]
Step #600 A = [[ 0.75908542]] b = [[ 3.2220633]]
Loss = [[ 3.06292033]]
Step #900 A = [[ 0.74843585]] b = [[ 3.9975822]]
Loss = [[ 1.23220456]]
Step #1200 A = [[ 0.73752165]] b = [[ 4.42974091]]
Loss = [[ 0.57872057]]
Step #1500 A = [[ 0.72942668]] b = [[ 4.67253113]]
Loss = [[ 0.40874988]]

用TensorFlow实现lasso回归和岭回归算法的示例 

用TensorFlow实现lasso回归和岭回归算法的示例

通过在标准线性回归估计的基础上,增加一个连续的阶跃函数,实现lasso回归算法。由于阶跃函数的坡度,我们需要注意步长,因为太大的步长会导致最终不收敛。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python逐行读取文件内容的三种方法
Jan 20 Python
python端口扫描系统实现方法
Nov 19 Python
简述Python中的面向对象编程的概念
Apr 27 Python
Python中import导入上一级目录模块及循环import问题的解决
Jun 04 Python
Python判断列表是否已排序的各种方法及其性能分析
Jun 20 Python
Python PyQt5实现的简易计算器功能示例
Aug 23 Python
python使用opencv对图像mask处理的方法
Jul 05 Python
Python坐标线性插值应用实现
Nov 13 Python
Pandas时间序列重采样(resample)方法中closed、label的作用详解
Dec 10 Python
TensorFlow实现保存训练模型为pd文件并恢复
Feb 06 Python
python 读取yaml文件的两种方法(在unittest中使用)
Dec 01 Python
如何在Python项目中引入日志
May 31 Python
Python实现确认字符串是否包含指定字符串的实例
May 02 #Python
详解用TensorFlow实现逻辑回归算法
May 02 #Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
You might like
小偷PHP+Html+缓存
2006/12/20 PHP
采集邮箱的php代码(抓取网页中的邮箱地址)
2012/07/17 PHP
thinkPHP基于ajax实现的菜单与分页示例
2016/07/12 PHP
js验证表单大全
2006/11/25 Javascript
ie和firefox中img对象区别的困惑
2006/12/27 Javascript
js 跨域和ajax 跨域问题小结
2009/07/01 Javascript
基于JQuery实现鼠标点击文本框显示隐藏提示文本
2012/02/23 Javascript
基于JavaScript实现继承机制之调用call()与apply()的方法详解
2013/05/07 Javascript
jQuery实现的多屏图像图层切换效果实例
2015/05/07 Javascript
关于JS中prototype的理解
2015/09/07 Javascript
JS实现图片平面旋转的方法
2016/03/01 Javascript
Bootstrap表格和栅格分页实例详解
2016/05/20 Javascript
JS实现的手机端精简幻灯片效果
2016/09/05 Javascript
d3.js中冷门却实用的内置函数总结
2017/02/04 Javascript
jquery实现静态搜索功能(可输入搜索文字)
2017/03/28 jQuery
微信小程序 功能函数小结(手机号验证*、密码验证*、获取验证码*)
2017/12/08 Javascript
原生js实现随机点名功能
2019/11/05 Javascript
vue+render+jsx实现可编辑动态多级表头table的实例代码
2020/04/01 Javascript
jquery html添加元素/删除元素操作实例详解
2020/05/20 jQuery
[03:09]2014DOTA2国际邀请赛 Mushi前队友送上祝福
2014/07/12 DOTA
[01:17:47]TNC vs VGJ.S 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
python实现颜色rgb和hex相互转换的函数
2015/03/19 Python
在Python中使用模块的教程
2015/04/27 Python
python 自动重连wifi windows的方法
2018/12/18 Python
python简单实现AES加密和解密
2019/03/28 Python
python 求某条线上特定x值或y值的点坐标方法
2019/07/09 Python
pytorch之Resize()函数具体使用详解
2020/02/27 Python
Python实现对adb命令封装
2020/03/06 Python
详解pandas.DataFrame.plot() 画图函数
2020/06/14 Python
Python -m参数原理及使用方法解析
2020/08/21 Python
HTML5 LocalStorage 本地存储详细概括(多图)
2017/08/18 HTML / CSS
JackJones官方旗舰店:杰克琼斯男装
2018/03/27 全球购物
高校学生干部的自我评价分享
2013/11/04 职场文书
学生检讨书范文
2014/10/30 职场文书
golang中的空slice案例
2021/04/27 Golang
python not运算符的实例用法
2021/06/30 Python