用tensorflow实现弹性网络回归算法


Posted in Python onJanuary 09, 2018

本文实例为大家分享了tensorflow实现弹性网络回归算法,供大家参考,具体内容如下

python代码:

#用tensorflow实现弹性网络算法(多变量) 
#使用鸢尾花数据集,后三个特征作为特征,用来预测第一个特征。 
 
 
#1 导入必要的编程库,创建计算图,加载数据集 
import matplotlib.pyplot as plt 
import tensorflow as tf 
import numpy as np 
from sklearn import datasets 
from tensorflow.python.framework import ops 
 
ops.get_default_graph() 
sess = tf.Session() 
iris = datasets.load_iris() 
 
x_vals = np.array([[x[1], x[2], x[3]] for x in iris.data]) 
y_vals = np.array([y[0] for y in iris.data]) 
 
 
#2 声明学习率,批量大小,占位符和模型变量,模型输出 
learning_rate = 0.001 
batch_size = 50 
x_data = tf.placeholder(shape=[None, 3], dtype=tf.float32) #占位符大小为3 
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32) 
A = tf.Variable(tf.random_normal(shape=[3,1])) 
b = tf.Variable(tf.random_normal(shape=[1,1])) 
model_output = tf.add(tf.matmul(x_data, A), b) 
 
 
#3 对于弹性网络回归算法,损失函数包括L1正则和L2正则 
elastic_param1 = tf.constant(1.) 
elastic_param2 = tf.constant(1.) 
l1_a_loss = tf.reduce_mean(abs(A)) 
l2_a_loss = tf.reduce_mean(tf.square(A)) 
e1_term = tf.multiply(elastic_param1, l1_a_loss) 
e2_term = tf.multiply(elastic_param2, l2_a_loss) 
loss = tf.expand_dims(tf.add(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), e1_term), e2_term), 0) 
 
 
 
#4 初始化变量, 声明优化器, 然后遍历迭代运行, 训练拟合得到参数 
init = tf.global_variables_initializer() 
sess.run(init) 
my_opt = tf.train.GradientDescentOptimizer(learning_rate) 
train_step = my_opt.minimize(loss) 
 
loss_vec = [] 
for i in range(1000): 
   rand_index = np.random.choice(len(x_vals), size=batch_size) 
   rand_x = x_vals[rand_index] 
   rand_y = np.transpose([y_vals[rand_index]]) 
   sess.run(train_step, feed_dict={x_data:rand_x, y_target:rand_y}) 
   temp_loss = sess.run(loss, feed_dict={x_data:rand_x, y_target:rand_y}) 
   loss_vec.append(temp_loss) 
   if (i+1)%250 == 0: 
     print('Step#' + str(i+1) +'A = ' + str(sess.run(A)) + 'b=' + str(sess.run(b))) 
     print('Loss= ' +str(temp_loss)) 
      
 
#现在能观察到, 随着训练迭代后损失函数已收敛。 
plt.plot(loss_vec, 'k--') 
plt.title('Loss per Generation') 
plt.xlabel('Generation') 
plt.ylabel('Loss') 
plt.show()

本文参考书《Tensorflow机器学习实战指南》

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python下载Bing图片(代码)
Nov 07 Python
Python类定义和类继承详解
May 08 Python
Python 爬虫爬取指定博客的所有文章
Feb 17 Python
python的变量与赋值详细分析
Nov 08 Python
Python实现的redis分布式锁功能示例
May 29 Python
OPENCV去除小连通区域,去除孔洞的实例讲解
Jun 21 Python
python 进程 进程池 进程间通信实现解析
Aug 23 Python
python+selenium 鼠标事件操作方法
Aug 24 Python
pytorch之ImageFolder使用详解
Jan 06 Python
Python基于staticmethod装饰器标示静态方法
Oct 17 Python
Restful_framework视图组件代码实例解析
Nov 17 Python
python自动化测试之Selenium详解
Mar 13 Python
Python+matplotlib实现计算两个信号的交叉谱密度实例
Jan 08 #Python
python matplotlib 注释文本箭头简单代码示例
Jan 08 #Python
Python自定义简单图轴简单实例
Jan 08 #Python
[原创]python爬虫(入门教程、视频教程)
Jan 08 #Python
小米5s微信跳一跳小程序python源码
Jan 08 #Python
Python实现判断字符串中包含某个字符的判断函数示例
Jan 08 #Python
Python实现的字典值比较功能示例
Jan 08 #Python
You might like
外媒评选出10支2020年最受欢迎的Dota2战队
2021/03/05 DOTA
php使用str_replace实现输入框回车替换br的方法
2014/11/24 PHP
Java中final关键字详解
2015/08/10 PHP
JQuery 学习笔记 选择器之六
2009/07/23 Javascript
javascript客户端解决方案 缓存提供程序
2010/07/14 Javascript
jquery 与NVelocity 产生冲突的解决方法
2011/06/13 Javascript
JS去除数组重复值的五种不同方法
2013/09/06 Javascript
ie9 提示'console' 未定义问题的解决方法
2014/03/20 Javascript
jQuery获取checkboxlist的value值的方法
2015/09/27 Javascript
基于JavaScript实现移除(删除)数组中指定元素
2016/01/04 Javascript
jQuery代码性能优化的10种方法
2016/06/21 Javascript
Vuejs第九篇之组件作用域及props数据传递实例详解
2016/09/05 Javascript
jsonp跨域请求实现示例
2017/03/13 Javascript
JavaScript编写的网页小游戏,很给力
2017/08/18 Javascript
Node.js自定义实现文件路由功能
2017/09/22 Javascript
使用vue-cli(vue脚手架)快速搭建项目的方法
2018/05/21 Javascript
node版本管理工具n包使用教程详解
2018/11/09 Javascript
微信小程序设置全局请求URL及封装wx.request请求操作示例
2019/04/02 Javascript
深入理解JavaScript 箭头函数
2019/05/30 Javascript
如何通过javaScript去除字符串两端的空白字符
2020/02/06 Javascript
全面解析js中的原型,原型对象,原型链
2021/01/25 Javascript
Python的Twisted框架中使用Deferred对象来管理回调函数
2016/05/25 Python
详解python的ORM中Pony用法
2018/02/09 Python
用python编写第一个IDA插件的实例
2018/05/29 Python
对numpy中数组转置的求解以及向量内积计算方法
2018/10/31 Python
python 阶乘累加和的实例
2019/02/01 Python
对python 自定义协议的方法详解
2019/02/13 Python
Series和DataFrame使用简单入门
2019/11/13 Python
Python 日期时间datetime 加一天,减一天,加减一小时一分钟,加减一年
2020/04/16 Python
基于Keras中Conv1D和Conv2D的区别说明
2020/06/19 Python
HTML5 创建canvas元素示例代码
2014/06/04 HTML / CSS
详解HTML5 Canvas标签及基本使用
2020/01/10 HTML / CSS
Gweniss格温妮丝女包官网:英国纯手工制造潮流包包品牌
2018/02/07 全球购物
销售人员职业生涯规划范文
2014/03/01 职场文书
保密协议书范本
2014/04/22 职场文书
践行党的群众路线心得体会
2014/11/05 职场文书