用TensorFlow实现lasso回归和岭回归算法的示例


Posted in Python onMay 02, 2018

也有些正则方法可以限制回归算法输出结果中系数的影响,其中最常用的两种正则方法是lasso回归和岭回归。

lasso回归和岭回归算法跟常规线性回归算法极其相似,有一点不同的是,在公式中增加正则项来限制斜率(或者净斜率)。这样做的主要原因是限制特征对因变量的影响,通过增加一个依赖斜率A的损失函数实现。

对于lasso回归算法,在损失函数上增加一项:斜率A的某个给定倍数。我们使用TensorFlow的逻辑操作,但没有这些操作相关的梯度,而是使用阶跃函数的连续估计,也称作连续阶跃函数,其会在截止点跳跃扩大。一会就可以看到如何使用lasso回归算法。

对于岭回归算法,增加一个L2范数,即斜率系数的L2正则。

# LASSO and Ridge Regression
# lasso回归和岭回归
# 
# This function shows how to use TensorFlow to solve LASSO or 
# Ridge regression for 
# y = Ax + b
# 
# We will use the iris data, specifically: 
#  y = Sepal Length 
#  x = Petal Width

# import required libraries
import matplotlib.pyplot as plt
import sys
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops


# Specify 'Ridge' or 'LASSO'
regression_type = 'LASSO'

# clear out old graph
ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Load iris data
###

# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

###
# Model Parameters
###

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# make results reproducible
seed = 13
np.random.seed(seed)
tf.set_random_seed(seed)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

###
# Loss Functions
###

# Select appropriate loss function based on regression type

if regression_type == 'LASSO':
  # Declare Lasso loss function
  # 增加损失函数,其为改良过的连续阶跃函数,lasso回归的截止点设为0.9。
  # 这意味着限制斜率系数不超过0.9
  # Lasso Loss = L2_Loss + heavyside_step,
  # Where heavyside_step ~ 0 if A < constant, otherwise ~ 99
  lasso_param = tf.constant(0.9)
  heavyside_step = tf.truediv(1., tf.add(1., tf.exp(tf.multiply(-50., tf.subtract(A, lasso_param)))))
  regularization_param = tf.multiply(heavyside_step, 99.)
  loss = tf.add(tf.reduce_mean(tf.square(y_target - model_output)), regularization_param)

elif regression_type == 'Ridge':
  # Declare the Ridge loss function
  # Ridge loss = L2_loss + L2 norm of slope
  ridge_param = tf.constant(1.)
  ridge_loss = tf.reduce_mean(tf.square(A))
  loss = tf.expand_dims(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), tf.multiply(ridge_param, ridge_loss)), 0)

else:
  print('Invalid regression_type parameter value',file=sys.stderr)


###
# Optimizer
###

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.001)
train_step = my_opt.minimize(loss)

###
# Run regression
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss[0])
  if (i+1)%300==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))
    print('\n')

###
# Extract regression results
###

# Get the optimal coefficients
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# Get best fit line
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)


###
# Plot results
###

# Plot regression line against data points
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title(regression_type + ' Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出结果:

Step #300 A = [[ 0.77170753]] b = [[ 1.82499862]]
Loss = [[ 10.26473045]]
Step #600 A = [[ 0.75908542]] b = [[ 3.2220633]]
Loss = [[ 3.06292033]]
Step #900 A = [[ 0.74843585]] b = [[ 3.9975822]]
Loss = [[ 1.23220456]]
Step #1200 A = [[ 0.73752165]] b = [[ 4.42974091]]
Loss = [[ 0.57872057]]
Step #1500 A = [[ 0.72942668]] b = [[ 4.67253113]]
Loss = [[ 0.40874988]]

用TensorFlow实现lasso回归和岭回归算法的示例 

用TensorFlow实现lasso回归和岭回归算法的示例

通过在标准线性回归估计的基础上,增加一个连续的阶跃函数,实现lasso回归算法。由于阶跃函数的坡度,我们需要注意步长,因为太大的步长会导致最终不收敛。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python聚类算法之DBSACN实例分析
Nov 20 Python
解读python logging模块的使用方法
Apr 17 Python
python中virtualenvwrapper安装与使用
May 20 Python
浅谈python中str字符串和unicode对象字符串的拼接问题
Dec 04 Python
解析Python的缩进规则的使用
Jan 16 Python
Python 把序列转换为元组的函数tuple方法
Jun 27 Python
python pickle存储、读取大数据量列表、字典数据的方法
Jul 07 Python
Python实现制度转换(货币,温度,长度)
Jul 14 Python
Django实现whoosh搜索引擎使用jieba分词
Apr 08 Python
python opencv通过4坐标剪裁图片
Jun 05 Python
Python3中PyQt5简单实现文件打开及保存
Jun 10 Python
opencv用VS2013调试时用Image Watch插件查看图片
Jul 26 Python
Python实现确认字符串是否包含指定字符串的实例
May 02 #Python
详解用TensorFlow实现逻辑回归算法
May 02 #Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
You might like
div li的多行多列 无刷新分页示例代码
2013/10/16 PHP
PHP中使用sleep函数实现定时任务实例分享
2014/08/21 PHP
php保存信息到当前Session的方法
2015/03/16 PHP
Thinkphp关闭缓存的方法
2015/06/26 PHP
PHP随手笔记整理之PHP脚本和JAVA连接mysql数据库
2015/11/25 PHP
PHP+redis实现微博的拉模型案例详解
2019/07/10 PHP
Extjs优化(二)Form表单提交通用实现
2013/04/15 Javascript
5秒后跳转效果(setInterval/SetTimeOut)
2013/05/03 Javascript
jquery中的ajax方法怎样通过JSONP进行远程调用
2014/05/04 Javascript
JS+CSS实现弹出全屏灰黑色透明遮罩效果的方法
2014/12/20 Javascript
jquery移动节点实例
2015/01/14 Javascript
jQuery实现可用于博客的动态滑动菜单
2015/03/09 Javascript
瀑布流的实现方式(原生js+jquery+css3)
2020/06/28 Javascript
微信小程序 Nginx环境配置详细介绍
2017/02/14 Javascript
Node.js创建Web、TCP服务器
2017/12/05 Javascript
node.js使用express框架进行文件上传详解
2019/03/03 Javascript
记一次react前端项目打包优化的方法
2020/03/30 Javascript
[01:01:14]完美世界DOTA2联赛PWL S2 SZ vs Rebirth 第一场 11.21
2020/11/23 DOTA
用Python的Django框架编写从Google Adsense中获得报表的应用
2015/04/17 Python
python 迭代器和iter()函数详解及实例
2017/03/21 Python
Python使用回溯法子集树模板获取最长公共子序列(LCS)的方法
2017/09/08 Python
python Socket之客户端和服务端握手详解
2017/09/18 Python
对python字典过滤条件的实例详解
2019/01/22 Python
Python实现查找数组中任意第k大的数字算法示例
2019/01/23 Python
超简单的Python HTTP服务
2019/07/22 Python
Win10系统下安装labelme及json文件批量转化方法
2019/07/30 Python
详解Python中pyautogui库的最全使用方法
2020/04/01 Python
基于python实现生成指定大小txt文档
2020/07/20 Python
Lacoste澳大利亚官网:服装、鞋类及配饰
2018/11/14 全球购物
幼儿园中班开学寄语
2014/04/03 职场文书
先进班组事迹材料
2014/12/25 职场文书
元宵节寄语大全
2015/02/27 职场文书
酒店采购员岗位职责
2015/04/03 职场文书
2016年“抗战胜利纪念日”71周年校园广播稿
2015/12/18 职场文书
《槐乡的孩子》教学反思
2016/02/20 职场文书
Java异常处理try catch的基本用法
2021/12/06 Java/Android