用tensorflow实现弹性网络回归算法


Posted in Python onJanuary 09, 2018

本文实例为大家分享了tensorflow实现弹性网络回归算法,供大家参考,具体内容如下

python代码:

#用tensorflow实现弹性网络算法(多变量) 
#使用鸢尾花数据集,后三个特征作为特征,用来预测第一个特征。 
 
 
#1 导入必要的编程库,创建计算图,加载数据集 
import matplotlib.pyplot as plt 
import tensorflow as tf 
import numpy as np 
from sklearn import datasets 
from tensorflow.python.framework import ops 
 
ops.get_default_graph() 
sess = tf.Session() 
iris = datasets.load_iris() 
 
x_vals = np.array([[x[1], x[2], x[3]] for x in iris.data]) 
y_vals = np.array([y[0] for y in iris.data]) 
 
 
#2 声明学习率,批量大小,占位符和模型变量,模型输出 
learning_rate = 0.001 
batch_size = 50 
x_data = tf.placeholder(shape=[None, 3], dtype=tf.float32) #占位符大小为3 
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32) 
A = tf.Variable(tf.random_normal(shape=[3,1])) 
b = tf.Variable(tf.random_normal(shape=[1,1])) 
model_output = tf.add(tf.matmul(x_data, A), b) 
 
 
#3 对于弹性网络回归算法,损失函数包括L1正则和L2正则 
elastic_param1 = tf.constant(1.) 
elastic_param2 = tf.constant(1.) 
l1_a_loss = tf.reduce_mean(abs(A)) 
l2_a_loss = tf.reduce_mean(tf.square(A)) 
e1_term = tf.multiply(elastic_param1, l1_a_loss) 
e2_term = tf.multiply(elastic_param2, l2_a_loss) 
loss = tf.expand_dims(tf.add(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), e1_term), e2_term), 0) 
 
 
 
#4 初始化变量, 声明优化器, 然后遍历迭代运行, 训练拟合得到参数 
init = tf.global_variables_initializer() 
sess.run(init) 
my_opt = tf.train.GradientDescentOptimizer(learning_rate) 
train_step = my_opt.minimize(loss) 
 
loss_vec = [] 
for i in range(1000): 
   rand_index = np.random.choice(len(x_vals), size=batch_size) 
   rand_x = x_vals[rand_index] 
   rand_y = np.transpose([y_vals[rand_index]]) 
   sess.run(train_step, feed_dict={x_data:rand_x, y_target:rand_y}) 
   temp_loss = sess.run(loss, feed_dict={x_data:rand_x, y_target:rand_y}) 
   loss_vec.append(temp_loss) 
   if (i+1)%250 == 0: 
     print('Step#' + str(i+1) +'A = ' + str(sess.run(A)) + 'b=' + str(sess.run(b))) 
     print('Loss= ' +str(temp_loss)) 
      
 
#现在能观察到, 随着训练迭代后损失函数已收敛。 
plt.plot(loss_vec, 'k--') 
plt.title('Loss per Generation') 
plt.xlabel('Generation') 
plt.ylabel('Loss') 
plt.show()

本文参考书《Tensorflow机器学习实战指南》

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
从Python程序中访问Java类的简单示例
Apr 20 Python
在Python的Flask框架中使用日期和时间的教程
Apr 21 Python
Python模拟随机游走图形效果示例
Feb 06 Python
python爬取淘宝商品详情页数据
Feb 23 Python
基于循环神经网络(RNN)的古诗生成器
Mar 26 Python
Python闭包执行时值的传递方式实例分析
Jun 04 Python
Python 处理图片像素点的实例
Jan 08 Python
Python如何爬取实时变化的WebSocket数据的方法
Mar 09 Python
使用Fabric自动化部署Django项目的实现
Sep 27 Python
Python 用turtle实现用正方形画圆的例子
Nov 21 Python
Python图像处理库PIL的ImageFont模块使用介绍
Feb 26 Python
Anaconda+Pycharm环境下的PyTorch配置方法
Mar 13 Python
Python+matplotlib实现计算两个信号的交叉谱密度实例
Jan 08 #Python
python matplotlib 注释文本箭头简单代码示例
Jan 08 #Python
Python自定义简单图轴简单实例
Jan 08 #Python
[原创]python爬虫(入门教程、视频教程)
Jan 08 #Python
小米5s微信跳一跳小程序python源码
Jan 08 #Python
Python实现判断字符串中包含某个字符的判断函数示例
Jan 08 #Python
Python实现的字典值比较功能示例
Jan 08 #Python
You might like
PHP中的traits简单使用实例
2015/05/13 PHP
遍历echsop的region表形成缓存的程序实例代码
2016/11/01 PHP
PHP 网站修改默认访问文件的nginx配置
2017/05/27 PHP
PHP实现链式操作的三种方法详解
2017/11/16 PHP
PHP如何防止用户重复提交表单
2020/12/09 PHP
js列举css中所有图标的实现代码
2011/07/04 Javascript
使用POST方式弹出窗口的两种方法示例介绍
2014/01/29 Javascript
JavaScript中实现单体模式分享
2015/01/29 Javascript
浅谈Javascript数组索引
2015/07/29 Javascript
js实现类似MSN提示的页面效果代码分享
2015/08/24 Javascript
安装使用Mongoose配合Node.js操作MongoDB的基础教程
2016/03/01 Javascript
jQuery基于扩展简单实现倒计时功能的方法
2016/05/14 Javascript
Javascript将字符串日期格式化为yyyy-mm-dd的方法
2016/10/27 Javascript
Vue 组件间的样式冲突污染
2017/08/31 Javascript
Vue2.0设置全局样式(less/sass和css)
2017/11/18 Javascript
移动端滑动切换组件封装 vue-swiper-router实例详解
2018/11/25 Javascript
在微信小程序中使用vant的方法
2019/06/07 Javascript
关于vue-cli3打包代码后白屏的解决方案
2020/09/02 Javascript
详解Vue中的watch和computed
2020/11/09 Javascript
python实现2048小游戏
2015/03/30 Python
python使用线程封装的一个简单定时器类实例
2015/05/16 Python
Python在Console下显示文本进度条的方法
2016/02/14 Python
Python 类的特殊成员解析
2018/06/20 Python
Python 中的range(),以及列表切片方法
2018/07/02 Python
Python小白不正确的使用类变量实例
2020/05/29 Python
python 实现一个图形界面的汇率计算器
2020/11/09 Python
Python系统公网私网流量监控实现流程
2020/11/23 Python
Python 利用argparse模块实现脚本命令行参数解析
2020/12/28 Python
使用CSS禁止textarea调整大小功能的方法
2015/03/13 HTML / CSS
基于HTML5+tracking.js实现刷脸支付功能
2020/04/16 HTML / CSS
美国知名的摄影器材销售网站:Adorama
2017/02/01 全球购物
EJB包括(SessionBean,EntityBean)说出他们的生命周期,及如何管理事务的
2015/07/24 面试题
环保口号大全
2014/06/12 职场文书
基层党员学习党的群众路线教育实践活动心得体会
2014/11/04 职场文书
2015年统计员个人工作总结
2015/07/23 职场文书
大学生如何逃脱“毕业季创业队即散伙”魔咒?
2019/08/19 职场文书