用TensorFlow实现lasso回归和岭回归算法的示例


Posted in Python onMay 02, 2018

也有些正则方法可以限制回归算法输出结果中系数的影响,其中最常用的两种正则方法是lasso回归和岭回归。

lasso回归和岭回归算法跟常规线性回归算法极其相似,有一点不同的是,在公式中增加正则项来限制斜率(或者净斜率)。这样做的主要原因是限制特征对因变量的影响,通过增加一个依赖斜率A的损失函数实现。

对于lasso回归算法,在损失函数上增加一项:斜率A的某个给定倍数。我们使用TensorFlow的逻辑操作,但没有这些操作相关的梯度,而是使用阶跃函数的连续估计,也称作连续阶跃函数,其会在截止点跳跃扩大。一会就可以看到如何使用lasso回归算法。

对于岭回归算法,增加一个L2范数,即斜率系数的L2正则。

# LASSO and Ridge Regression
# lasso回归和岭回归
# 
# This function shows how to use TensorFlow to solve LASSO or 
# Ridge regression for 
# y = Ax + b
# 
# We will use the iris data, specifically: 
#  y = Sepal Length 
#  x = Petal Width

# import required libraries
import matplotlib.pyplot as plt
import sys
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops


# Specify 'Ridge' or 'LASSO'
regression_type = 'LASSO'

# clear out old graph
ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Load iris data
###

# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

###
# Model Parameters
###

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# make results reproducible
seed = 13
np.random.seed(seed)
tf.set_random_seed(seed)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

###
# Loss Functions
###

# Select appropriate loss function based on regression type

if regression_type == 'LASSO':
  # Declare Lasso loss function
  # 增加损失函数,其为改良过的连续阶跃函数,lasso回归的截止点设为0.9。
  # 这意味着限制斜率系数不超过0.9
  # Lasso Loss = L2_Loss + heavyside_step,
  # Where heavyside_step ~ 0 if A < constant, otherwise ~ 99
  lasso_param = tf.constant(0.9)
  heavyside_step = tf.truediv(1., tf.add(1., tf.exp(tf.multiply(-50., tf.subtract(A, lasso_param)))))
  regularization_param = tf.multiply(heavyside_step, 99.)
  loss = tf.add(tf.reduce_mean(tf.square(y_target - model_output)), regularization_param)

elif regression_type == 'Ridge':
  # Declare the Ridge loss function
  # Ridge loss = L2_loss + L2 norm of slope
  ridge_param = tf.constant(1.)
  ridge_loss = tf.reduce_mean(tf.square(A))
  loss = tf.expand_dims(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), tf.multiply(ridge_param, ridge_loss)), 0)

else:
  print('Invalid regression_type parameter value',file=sys.stderr)


###
# Optimizer
###

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.001)
train_step = my_opt.minimize(loss)

###
# Run regression
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss[0])
  if (i+1)%300==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))
    print('\n')

###
# Extract regression results
###

# Get the optimal coefficients
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# Get best fit line
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)


###
# Plot results
###

# Plot regression line against data points
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title(regression_type + ' Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出结果:

Step #300 A = [[ 0.77170753]] b = [[ 1.82499862]]
Loss = [[ 10.26473045]]
Step #600 A = [[ 0.75908542]] b = [[ 3.2220633]]
Loss = [[ 3.06292033]]
Step #900 A = [[ 0.74843585]] b = [[ 3.9975822]]
Loss = [[ 1.23220456]]
Step #1200 A = [[ 0.73752165]] b = [[ 4.42974091]]
Loss = [[ 0.57872057]]
Step #1500 A = [[ 0.72942668]] b = [[ 4.67253113]]
Loss = [[ 0.40874988]]

用TensorFlow实现lasso回归和岭回归算法的示例 

用TensorFlow实现lasso回归和岭回归算法的示例

通过在标准线性回归估计的基础上,增加一个连续的阶跃函数,实现lasso回归算法。由于阶跃函数的坡度,我们需要注意步长,因为太大的步长会导致最终不收敛。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现在sqlite动态创建表的方法
May 08 Python
详解Python list 与 NumPy.ndarry 切片之间的对比
Jul 24 Python
python3.6 +tkinter GUI编程 实现界面化的文本处理工具(推荐)
Dec 20 Python
Vue的el-scrollbar实现自定义滚动
May 29 Python
Pandas过滤dataframe中包含特定字符串的数据方法
Nov 07 Python
对python判断ip是否可达的实例详解
Jan 31 Python
matplotlib实现显示伪彩色图像及色度条
Dec 07 Python
开启Django博客的RSS功能的实现方法
Feb 17 Python
windows、linux下打包Python3程序详细方法
Mar 17 Python
详解Pycharm第三方库的安装及使用方法
Dec 29 Python
python 基于selectors库实现文件上传与下载
Dec 31 Python
Python趣味爬虫之用Python实现智慧校园一键评教
May 28 Python
Python实现确认字符串是否包含指定字符串的实例
May 02 #Python
详解用TensorFlow实现逻辑回归算法
May 02 #Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
You might like
国王的咖啡这么大来头,名字的由来是什么
2021/03/03 咖啡文化
MYSQL 小技巧 -- LAST_INSERT_ID
2009/11/24 PHP
PHP数组对比函数,存在交集则返回真,否则返回假
2011/02/03 PHP
PHP获取一年中每个星期的开始和结束日期的方法
2015/02/12 PHP
PHP版本升级到7.x后wordpress的一些修改及wordpress技巧
2015/12/25 PHP
PHP常见过waf webshell以及最简单的检测方法
2019/05/21 PHP
C#中TrimStart,TrimEnd,Trim在javascript上的实现
2011/01/17 Javascript
JavaScript 图像动画的小demo
2012/05/23 Javascript
jQuery自动添加表单项的方法
2015/07/13 Javascript
js判断空对象的实例(超简单)
2016/07/26 Javascript
Angular2使用Angular CLI快速搭建工程(一)
2017/05/21 Javascript
angularjs之$timeout指令详解
2017/06/13 Javascript
JS轮播图实现简单代码
2021/02/19 Javascript
vuex 使用文档小结篇
2018/01/11 Javascript
Vue页面骨架屏注入方法
2018/05/13 Javascript
JavaScript实现新年倒计时效果
2018/11/17 Javascript
使用webpack将ES6转化ES5的实现方法
2019/10/13 Javascript
JavaScript 浏览器对象模型BOM原理与常见用法实例分析
2019/12/16 Javascript
微信小程序图片自适应实现解析
2020/01/21 Javascript
基于vue.js仿淘宝收货地址并设置默认地址的案例分析
2020/08/20 Javascript
[41:13]完美世界DOTA2联赛PWL S2 Forest vs Rebirth 第一场 11.20
2020/11/20 DOTA
Python获取电脑硬件信息及状态的实现方法
2014/08/29 Python
一些Centos Python 生产环境的部署命令(推荐)
2018/05/07 Python
在python中使用with打开多个文件的方法
2019/01/07 Python
python3 property装饰器实现原理与用法示例
2019/05/15 Python
在pytorch 中计算精度、回归率、F1 score等指标的实例
2020/01/18 Python
Python post请求实现代码实例
2020/02/28 Python
python自动脚本的pyautogui入门学习
2020/04/01 Python
Pycharm 使用 Pipenv 新建的虚拟环境(图文详解)
2020/04/16 Python
什么是java序列化,如何实现java序列化
2012/11/14 面试题
通用C#笔试题附答案
2016/11/26 面试题
见习期自我鉴定
2013/11/07 职场文书
幼儿园保育员岗位职责
2014/04/13 职场文书
老干部工作先进集体事迹材料
2014/05/21 职场文书
个人违纪检讨书
2014/09/15 职场文书
综合素质评价自我评价
2015/03/06 职场文书