用tensorflow实现弹性网络回归算法


Posted in Python onJanuary 09, 2018

本文实例为大家分享了tensorflow实现弹性网络回归算法,供大家参考,具体内容如下

python代码:

#用tensorflow实现弹性网络算法(多变量) 
#使用鸢尾花数据集,后三个特征作为特征,用来预测第一个特征。 
 
 
#1 导入必要的编程库,创建计算图,加载数据集 
import matplotlib.pyplot as plt 
import tensorflow as tf 
import numpy as np 
from sklearn import datasets 
from tensorflow.python.framework import ops 
 
ops.get_default_graph() 
sess = tf.Session() 
iris = datasets.load_iris() 
 
x_vals = np.array([[x[1], x[2], x[3]] for x in iris.data]) 
y_vals = np.array([y[0] for y in iris.data]) 
 
 
#2 声明学习率,批量大小,占位符和模型变量,模型输出 
learning_rate = 0.001 
batch_size = 50 
x_data = tf.placeholder(shape=[None, 3], dtype=tf.float32) #占位符大小为3 
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32) 
A = tf.Variable(tf.random_normal(shape=[3,1])) 
b = tf.Variable(tf.random_normal(shape=[1,1])) 
model_output = tf.add(tf.matmul(x_data, A), b) 
 
 
#3 对于弹性网络回归算法,损失函数包括L1正则和L2正则 
elastic_param1 = tf.constant(1.) 
elastic_param2 = tf.constant(1.) 
l1_a_loss = tf.reduce_mean(abs(A)) 
l2_a_loss = tf.reduce_mean(tf.square(A)) 
e1_term = tf.multiply(elastic_param1, l1_a_loss) 
e2_term = tf.multiply(elastic_param2, l2_a_loss) 
loss = tf.expand_dims(tf.add(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), e1_term), e2_term), 0) 
 
 
 
#4 初始化变量, 声明优化器, 然后遍历迭代运行, 训练拟合得到参数 
init = tf.global_variables_initializer() 
sess.run(init) 
my_opt = tf.train.GradientDescentOptimizer(learning_rate) 
train_step = my_opt.minimize(loss) 
 
loss_vec = [] 
for i in range(1000): 
   rand_index = np.random.choice(len(x_vals), size=batch_size) 
   rand_x = x_vals[rand_index] 
   rand_y = np.transpose([y_vals[rand_index]]) 
   sess.run(train_step, feed_dict={x_data:rand_x, y_target:rand_y}) 
   temp_loss = sess.run(loss, feed_dict={x_data:rand_x, y_target:rand_y}) 
   loss_vec.append(temp_loss) 
   if (i+1)%250 == 0: 
     print('Step#' + str(i+1) +'A = ' + str(sess.run(A)) + 'b=' + str(sess.run(b))) 
     print('Loss= ' +str(temp_loss)) 
      
 
#现在能观察到, 随着训练迭代后损失函数已收敛。 
plt.plot(loss_vec, 'k--') 
plt.title('Loss per Generation') 
plt.xlabel('Generation') 
plt.ylabel('Loss') 
plt.show()

本文参考书《Tensorflow机器学习实战指南》

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python编写批量卸载手机中安装的android应用脚本
Jul 21 Python
100行Python代码实现自动抢火车票(附源码)
Jan 11 Python
Python 实现一行输入多个值的方法
Apr 21 Python
python抓取网站的图片并下载到本地的方法
May 22 Python
python得到电脑的开机时间方法
Oct 15 Python
详解python pandas 分组统计的方法
Jul 30 Python
pytorch载入预训练模型后,实现训练指定层
Jan 06 Python
Python安装whl文件过程图解
Feb 18 Python
Python猴子补丁Monkey Patch用法实例解析
Mar 23 Python
keras中的loss、optimizer、metrics用法
Jun 15 Python
python实例化对象的具体方法
Jun 17 Python
基于Python实现将列表数据生成折线图
Mar 23 Python
Python+matplotlib实现计算两个信号的交叉谱密度实例
Jan 08 #Python
python matplotlib 注释文本箭头简单代码示例
Jan 08 #Python
Python自定义简单图轴简单实例
Jan 08 #Python
[原创]python爬虫(入门教程、视频教程)
Jan 08 #Python
小米5s微信跳一跳小程序python源码
Jan 08 #Python
Python实现判断字符串中包含某个字符的判断函数示例
Jan 08 #Python
Python实现的字典值比较功能示例
Jan 08 #Python
You might like
mayfish 数据入库验证代码
2010/04/30 PHP
PHP 获取远程文件大小的3种解决方法
2013/07/11 PHP
php中常量DIRECTORY_SEPARATOR用法深入分析
2014/11/14 PHP
PDO预处理语句PDOStatement对象使用总结
2014/11/20 PHP
PHP模板引擎Smarty内置变量调解器用法详解
2016/04/11 PHP
php动态读取数据清除最右边距的方法
2017/04/12 PHP
js或css文件后面跟参数的原因说明
2010/01/09 Javascript
Ext中下拉列表ComboBox组件store数据格式用法介绍
2013/07/15 Javascript
JavaScript中九种常用排序算法
2014/09/02 Javascript
浅谈JavaScript中Date(日期对象),Math对象
2015/02/05 Javascript
php利用curl获取远程图片实现方法
2015/10/26 Javascript
Bootstrap Paginator分页插件与ajax相结合实现动态无刷新分页效果
2016/05/27 Javascript
jQuery实现的简单拖拽功能示例
2016/09/13 Javascript
Angular如何引入第三方库的方法详解
2017/07/13 Javascript
JavaScript学习总结(一) ECMAScript、BOM、DOM(核心、浏览器对象模型与文档对象模型)
2018/01/07 Javascript
vue 2.0 购物车小球抛物线的示例代码
2018/02/01 Javascript
jquery使用FormData实现异步上传文件
2018/10/25 jQuery
JavaScript中关于base64的一些事
2019/05/06 Javascript
微信小程序获取当前位置和城市名
2019/11/13 Javascript
浅谈Vuex的this.$store.commit和在Vue项目中引用公共方法
2020/07/24 Javascript
OpenLayers加载缩放控件使用方法详解
2020/09/25 Javascript
JS实现炫酷轮播图
2020/11/15 Javascript
js实现简易计算器小功能
2020/11/18 Javascript
JavaScript canvas实现雨滴特效
2021/01/10 Javascript
跟老齐学Python之永远强大的函数
2014/09/14 Python
python获取mp3文件信息的方法
2015/06/15 Python
tensorflow实现简单的卷积神经网络
2018/05/24 Python
Python实现在某个数组中查找一个值的算法示例
2018/06/27 Python
多个版本的python共存时使用pip的正确做法
2020/10/26 Python
优衣库英国官网:UNIQLO英国
2016/12/25 全球购物
Eclipse面试题
2014/03/22 面试题
学生自我鉴定范文
2013/10/04 职场文书
明信片寄语大全
2014/04/08 职场文书
买房委托公证书
2014/04/08 职场文书
传承焦裕禄精神思想汇报2014
2014/09/10 职场文书
解决vue $http的get和post请求跨域问题
2021/06/07 Vue.js