tensorflow 自定义损失函数示例代码


Posted in Python onFebruary 05, 2020

这个自定义损失函数的背景:(一般回归用的损失函数是MSE, 但要看实际遇到的情况而有所改变)

我们现在想要做一个回归,来预估某个商品的销量,现在我们知道,一件商品的成本是1元,售价是10元。

如果我们用均方差来算的话,如果预估多一个,则损失一块钱,预估少一个,则损失9元钱(少赚的)。

显然,我宁愿预估多了,也不想预估少了。

所以,我们就自己定义一个损失函数,用来分段地看,当yhat 比 y大时怎么样,当yhat比y小时怎么样。

(yhat沿用吴恩达课堂中的叫法)

import tensorflow as tf
from numpy.random import RandomState
batch_size = 8
# 两个输入节点
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
# 回归问题一般只有一个输出节点
y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y-input")
# 定义了一个单层的神经网络前向传播的过程,这里就是简单加权和
w1 = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w1)
# 定义预测多了和预测少了的成本
loss_less = 10
loss_more = 1
#在windows下,下面用这个where替代,因为调用tf.select会报错
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_)*loss_more, (y_-y)*loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
#通过随机数生成一个模拟数据集
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
"""
设置回归的正确值为两个输入的和加上一个随机量,之所以要加上一个随机量是
为了加入不可预测的噪音,否则不同损失函数的意义就不大了,因为不同损失函数
都会在能完全预测正确的时候最低。一般来说,噪音为一个均值为0的小量,所以
这里的噪音设置为-0.05, 0.05的随机数。
"""
Y = [[x1 + x2 + rdm.rand()/10.0-0.05] for (x1, x2) in X]
with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 steps = 5000
 for i in range(steps):
  start = (i * batch_size) % dataset_size
  end = min(start + batch_size, dataset_size)
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
 print(sess.run(w1))

[[ 1.01934695]
[ 1.04280889]

最终结果如上面所示。

因为我们当初生成训练数据的时候,y是x1 + x2,所以回归结果应该是1,1才对。
但是,由于我们加了自己定义的损失函数,所以,倾向于预估多一点。

如果,我们将loss_less和loss_more对调,我们看一下结果:

[[ 0.95525807]
[ 0.9813394 ]]

通过这个例子,我们可以看出,对于相同的神经网络,不同的损失函数会对训练出来的模型产生重要的影响。

引用:以上实例为《Tensorflow实战 Google深度学习框架》中提供。

总结

以上所述是小编给大家介绍的tensorflow 自定义损失函数示例,希望对大家有所帮助!

Python 相关文章推荐
Python正则表达式介绍
Aug 06 Python
Python内置的HTTP协议服务器SimpleHTTPServer使用指南
Mar 30 Python
Python3 socket同步通信简单示例
Jun 07 Python
在cmd中查看python的安装路径方法
Jul 03 Python
Python实现时间序列可视化的方法
Aug 06 Python
python单向循环链表原理与实现方法示例
Dec 03 Python
pandas 对group进行聚合的例子
Dec 27 Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 Python
python 5个顶级异步框架推荐
Sep 09 Python
Python pip 常用命令汇总
Oct 19 Python
详解python第三方库的安装、PyInstaller库、random库
Mar 03 Python
Python多个MP4合成视频的实现方法
Jul 16 Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
TensorFlow实现指数衰减学习率的方法
Feb 05 #Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
You might like
50个PHP程序性能优化的方法
2014/06/02 PHP
php字符串函数学习之strstr()
2015/03/27 PHP
PHP下的浮点运算不准的解决方法
2016/10/27 PHP
CI框架AR数据库操作常用函数总结
2016/11/21 PHP
PHP实现重载的常用方法实例详解
2017/10/18 PHP
thinkphp中U方法按路由规则生成url的方法
2018/03/12 PHP
枚举JavaScript对象的函数
2006/12/22 Javascript
15个款优秀的 jQuery 图片特效插件推荐
2011/11/21 Javascript
js实现拖拽 闭包函数详细介绍
2012/11/25 Javascript
推荐6款基于jQuery实现图片效果插件
2014/12/07 Javascript
jquery操作复选框checkbox的方法汇总
2015/02/05 Javascript
jQuery延迟加载图片插件Lazy Load使用指南
2015/03/25 Javascript
JavaScript中数组添加值和访问值常见问题
2016/02/06 Javascript
使用jQuery Rotare实现微信大转盘抽奖功能
2016/06/20 Javascript
EasyUI在表单提交之前进行验证的实例代码
2016/06/24 Javascript
详解PHP中pathinfo()函数导致的安全问题
2017/01/05 Javascript
Vue.js项目部署到服务器的详细步骤
2017/07/17 Javascript
p5.js实现斐波那契螺旋的示例代码
2018/03/22 Javascript
详解Vue结合后台的列表增删改案例
2018/08/21 Javascript
JavaScript canvas实现雪花随机动态飘落
2020/02/08 Javascript
Vue如何实现变量表达式选择器
2021/02/18 Vue.js
Vue-router编程式导航的两种实现代码
2021/03/04 Vue.js
python判断完全平方数的方法
2018/11/13 Python
python实现键盘输入的实操方法
2019/07/16 Python
OpenCV利用python来实现图像的直方图均衡化
2020/10/21 Python
西班牙在线宠物商店:zooplus.es
2017/02/24 全球购物
PatPat德国:妈妈的每日优惠
2019/10/02 全球购物
Fox Racing官方网站:越野摩托车和山地自行车装备和服装
2019/12/23 全球购物
报纸媒体创意广告词
2014/03/17 职场文书
文体活动实施方案
2014/03/27 职场文书
地质灾害防治方案
2014/05/14 职场文书
2015年元旦促销方案书
2014/12/09 职场文书
党员违纪检讨书
2015/05/05 职场文书
2015年妇产科工作总结
2015/05/18 职场文书
主题班会开场白
2015/06/01 职场文书
三严三实·严以律己心得体会
2016/01/13 职场文书