tensorflow 自定义损失函数示例代码


Posted in Python onFebruary 05, 2020

这个自定义损失函数的背景:(一般回归用的损失函数是MSE, 但要看实际遇到的情况而有所改变)

我们现在想要做一个回归,来预估某个商品的销量,现在我们知道,一件商品的成本是1元,售价是10元。

如果我们用均方差来算的话,如果预估多一个,则损失一块钱,预估少一个,则损失9元钱(少赚的)。

显然,我宁愿预估多了,也不想预估少了。

所以,我们就自己定义一个损失函数,用来分段地看,当yhat 比 y大时怎么样,当yhat比y小时怎么样。

(yhat沿用吴恩达课堂中的叫法)

import tensorflow as tf
from numpy.random import RandomState
batch_size = 8
# 两个输入节点
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
# 回归问题一般只有一个输出节点
y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y-input")
# 定义了一个单层的神经网络前向传播的过程,这里就是简单加权和
w1 = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w1)
# 定义预测多了和预测少了的成本
loss_less = 10
loss_more = 1
#在windows下,下面用这个where替代,因为调用tf.select会报错
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_)*loss_more, (y_-y)*loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
#通过随机数生成一个模拟数据集
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
"""
设置回归的正确值为两个输入的和加上一个随机量,之所以要加上一个随机量是
为了加入不可预测的噪音,否则不同损失函数的意义就不大了,因为不同损失函数
都会在能完全预测正确的时候最低。一般来说,噪音为一个均值为0的小量,所以
这里的噪音设置为-0.05, 0.05的随机数。
"""
Y = [[x1 + x2 + rdm.rand()/10.0-0.05] for (x1, x2) in X]
with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 steps = 5000
 for i in range(steps):
  start = (i * batch_size) % dataset_size
  end = min(start + batch_size, dataset_size)
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
 print(sess.run(w1))

[[ 1.01934695]
[ 1.04280889]

最终结果如上面所示。

因为我们当初生成训练数据的时候,y是x1 + x2,所以回归结果应该是1,1才对。
但是,由于我们加了自己定义的损失函数,所以,倾向于预估多一点。

如果,我们将loss_less和loss_more对调,我们看一下结果:

[[ 0.95525807]
[ 0.9813394 ]]

通过这个例子,我们可以看出,对于相同的神经网络,不同的损失函数会对训练出来的模型产生重要的影响。

引用:以上实例为《Tensorflow实战 Google深度学习框架》中提供。

总结

以上所述是小编给大家介绍的tensorflow 自定义损失函数示例,希望对大家有所帮助!

Python 相关文章推荐
python实现异步回调机制代码分享
Jan 10 Python
linux系统使用python监测网络接口获取网络的输入输出
Jan 15 Python
python实现上传样本到virustotal并查询扫描信息的方法
Oct 05 Python
在Python中使用next()方法操作文件的教程
May 24 Python
python xlsxwriter创建excel图表的方法
Jun 11 Python
flask session组件的使用示例
Dec 25 Python
基于python的selenium两种文件上传操作实现详解
Sep 19 Python
解决jupyter notebook打不开无反应 浏览器未启动的问题
Apr 10 Python
opencv 实现特定颜色线条提取与定位操作
Jun 02 Python
详解Python IO口多路复用
Jun 17 Python
如何利用python生成MD5并去重
Dec 07 Python
python如何进行基准测试
Apr 26 Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
TensorFlow实现指数衰减学习率的方法
Feb 05 #Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
You might like
农民和部队如何穿矿
2020/03/04 星际争霸
php中获取指定IP的物理地址的代码(正则表达式)
2011/06/23 PHP
destoon实现公司新闻详细页添加评论功能的方法
2014/07/15 PHP
PHP单例模式详细介绍
2015/07/01 PHP
yii2.0整合阿里云oss删除单个文件的方法
2017/09/19 PHP
PHP 计算两个时间段之间交集的天数示例
2019/10/24 PHP
PHP7 参数处理机制修改
2021/03/09 PHP
jquerymobile checkbox及时刷新才能获取其准确值
2012/04/14 Javascript
jQuery之ajax技术的详细介绍
2013/06/19 Javascript
seaJs的模块定义和模块加载浅析
2014/06/06 Javascript
浅谈JavaScript超时调用和间歇调用
2015/08/30 Javascript
javascript设计模式--策略模式之输入验证
2015/11/27 Javascript
AngularJS中$interval的用法详解
2016/02/02 Javascript
关于安卓手机微信浏览器中使用XMLHttpRequest 2上传图片显示字节数为0的解决办法
2016/05/17 Javascript
浅析$.getJSON异步请求和同步请求
2016/06/06 Javascript
JSON 数据格式详解
2017/09/13 Javascript
node简单实现一个更改头像功能的示例
2017/12/29 Javascript
Vue中的slot使用插槽分发内容的方法
2018/03/01 Javascript
详解在HTTPS 项目中使用百度地图 API
2019/04/26 Javascript
vue-autoui自匹配webapi的UI控件的实现
2020/03/20 Javascript
详解Vue数据驱动原理
2020/11/17 Javascript
[01:10:02]IG vs Winstrike 2018国际邀请赛小组赛BO2 第一场 8.19
2018/08/21 DOTA
Python编程入门的一些基本知识
2015/05/13 Python
Python实现文件按照日期命名的方法
2015/07/09 Python
Python实现的破解字符串找茬游戏算法示例
2017/09/25 Python
HTML5离线缓存Manifest是什么
2016/03/09 HTML / CSS
html5使用canvas压缩图片的示例代码
2018/09/11 HTML / CSS
波兰最大的度假胜地和城市公寓租赁运营商:Sun & Snow
2018/10/18 全球购物
一些Solaris面试题
2015/12/22 面试题
How to spawning asynchronous work in J2EE
2016/08/29 面试题
高二历史教学反思
2014/01/25 职场文书
公司委托书怎么写
2014/08/02 职场文书
个人主要事迹材料
2014/08/26 职场文书
工作失职检讨书(精华篇)
2014/10/15 职场文书
《最后一头战象》教学反思
2016/02/16 职场文书
JAVA SpringMVC实现自定义拦截器
2022/03/16 Python