用tensorflow实现弹性网络回归算法


Posted in Python onJanuary 09, 2018

本文实例为大家分享了tensorflow实现弹性网络回归算法,供大家参考,具体内容如下

python代码:

#用tensorflow实现弹性网络算法(多变量) 
#使用鸢尾花数据集,后三个特征作为特征,用来预测第一个特征。 
 
 
#1 导入必要的编程库,创建计算图,加载数据集 
import matplotlib.pyplot as plt 
import tensorflow as tf 
import numpy as np 
from sklearn import datasets 
from tensorflow.python.framework import ops 
 
ops.get_default_graph() 
sess = tf.Session() 
iris = datasets.load_iris() 
 
x_vals = np.array([[x[1], x[2], x[3]] for x in iris.data]) 
y_vals = np.array([y[0] for y in iris.data]) 
 
 
#2 声明学习率,批量大小,占位符和模型变量,模型输出 
learning_rate = 0.001 
batch_size = 50 
x_data = tf.placeholder(shape=[None, 3], dtype=tf.float32) #占位符大小为3 
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32) 
A = tf.Variable(tf.random_normal(shape=[3,1])) 
b = tf.Variable(tf.random_normal(shape=[1,1])) 
model_output = tf.add(tf.matmul(x_data, A), b) 
 
 
#3 对于弹性网络回归算法,损失函数包括L1正则和L2正则 
elastic_param1 = tf.constant(1.) 
elastic_param2 = tf.constant(1.) 
l1_a_loss = tf.reduce_mean(abs(A)) 
l2_a_loss = tf.reduce_mean(tf.square(A)) 
e1_term = tf.multiply(elastic_param1, l1_a_loss) 
e2_term = tf.multiply(elastic_param2, l2_a_loss) 
loss = tf.expand_dims(tf.add(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), e1_term), e2_term), 0) 
 
 
 
#4 初始化变量, 声明优化器, 然后遍历迭代运行, 训练拟合得到参数 
init = tf.global_variables_initializer() 
sess.run(init) 
my_opt = tf.train.GradientDescentOptimizer(learning_rate) 
train_step = my_opt.minimize(loss) 
 
loss_vec = [] 
for i in range(1000): 
   rand_index = np.random.choice(len(x_vals), size=batch_size) 
   rand_x = x_vals[rand_index] 
   rand_y = np.transpose([y_vals[rand_index]]) 
   sess.run(train_step, feed_dict={x_data:rand_x, y_target:rand_y}) 
   temp_loss = sess.run(loss, feed_dict={x_data:rand_x, y_target:rand_y}) 
   loss_vec.append(temp_loss) 
   if (i+1)%250 == 0: 
     print('Step#' + str(i+1) +'A = ' + str(sess.run(A)) + 'b=' + str(sess.run(b))) 
     print('Loss= ' +str(temp_loss)) 
      
 
#现在能观察到, 随着训练迭代后损失函数已收敛。 
plt.plot(loss_vec, 'k--') 
plt.title('Loss per Generation') 
plt.xlabel('Generation') 
plt.ylabel('Loss') 
plt.show()

本文参考书《Tensorflow机器学习实战指南》

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python fabric使用笔记
May 09 Python
python算法演练_One Rule 算法(详解)
May 17 Python
Python实现加载及解析properties配置文件的方法
Mar 29 Python
使用Python设计一个代码统计工具
Apr 04 Python
django允许外部访问的实例讲解
May 14 Python
python tkinter实现界面切换的示例代码
Jun 14 Python
Python模块汇总(常用第三方库)
Oct 07 Python
Python csv文件的读写操作实例详解
Nov 19 Python
基于SpringBoot构造器注入循环依赖及解决方式
Apr 26 Python
python识别验证码的思路及解决方案
Sep 13 Python
查找适用于matplotlib的中文字体名称与实际文件名对应关系的方法
Jan 05 Python
python中%格式表达式实例用法
Jun 18 Python
Python+matplotlib实现计算两个信号的交叉谱密度实例
Jan 08 #Python
python matplotlib 注释文本箭头简单代码示例
Jan 08 #Python
Python自定义简单图轴简单实例
Jan 08 #Python
[原创]python爬虫(入门教程、视频教程)
Jan 08 #Python
小米5s微信跳一跳小程序python源码
Jan 08 #Python
Python实现判断字符串中包含某个字符的判断函数示例
Jan 08 #Python
Python实现的字典值比较功能示例
Jan 08 #Python
You might like
PHP4(windows版本)中的COM函数
2006/10/09 PHP
PHP 读取文件的正确方法
2009/04/29 PHP
PHP数组实例总结与说明
2011/08/23 PHP
php常用数学函数汇总
2014/11/21 PHP
9条PHP编程小知识及易犯的小错误
2015/01/22 PHP
php把数组值转换成键的方法
2015/07/13 PHP
thinkphp5 加载静态资源路径与常量的方法
2017/12/24 PHP
解决laravel(5.5)访问public报错的问题
2019/10/12 PHP
Prototype使用指南之dom.js
2007/01/10 Javascript
JavaScript对象链式操作代码(jquery)
2010/07/04 Javascript
20款超赞的jQuery插件 Web开发人员必备
2011/02/26 Javascript
JS实现一个按钮的方法
2015/02/05 Javascript
一不小心就做错的JS闭包面试题
2015/11/25 Javascript
浅析jQuery中使用$所引发的问题
2016/05/29 Javascript
利用JavaScript实现栈的数据结构示例代码
2017/08/02 Javascript
vue+axios新手实践实现登陆的示例代码
2018/06/06 Javascript
Vue的H5页面唤起支付宝支付功能
2019/04/18 Javascript
详解VSCode配置启动Vue项目
2019/05/14 Javascript
微信小程序全局变量改变监听的实现方法
2019/07/15 Javascript
VueJS 取得 URL 参数值的方法
2019/07/19 Javascript
Vue 解决父组件跳转子路由后当前导航active样式消失问题
2020/07/21 Javascript
Python3使用requests包抓取并保存网页源码的方法
2016/03/15 Python
使用sklearn之LabelEncoder将Label标准化的方法
2018/07/11 Python
python实现AES和RSA加解密的方法
2019/03/28 Python
Django使用中间键实现csrf认证详解
2019/07/22 Python
聊聊python中的循环遍历
2020/09/07 Python
美国高端寝具品牌:Coyuchi
2017/02/08 全球购物
史蒂夫·马登加拿大官网:Steve Madden加拿大
2017/11/18 全球购物
ajax是什么及其工作原理
2012/02/08 面试题
转让协议书范本
2014/09/13 职场文书
在职员工证明书
2014/09/19 职场文书
2015年领导班子工作总结
2015/05/23 职场文书
2015教师个人年度工作总结
2015/10/23 职场文书
让文件路径提取变得更简单的Python Path库
2021/05/27 Python
MySQL如何修改字段类型和字段长度
2022/06/10 MySQL
教你如何用cmd快速登录服务器
2022/06/10 Servers