tensorflow 自定义损失函数示例代码


Posted in Python onFebruary 05, 2020

这个自定义损失函数的背景:(一般回归用的损失函数是MSE, 但要看实际遇到的情况而有所改变)

我们现在想要做一个回归,来预估某个商品的销量,现在我们知道,一件商品的成本是1元,售价是10元。

如果我们用均方差来算的话,如果预估多一个,则损失一块钱,预估少一个,则损失9元钱(少赚的)。

显然,我宁愿预估多了,也不想预估少了。

所以,我们就自己定义一个损失函数,用来分段地看,当yhat 比 y大时怎么样,当yhat比y小时怎么样。

(yhat沿用吴恩达课堂中的叫法)

import tensorflow as tf
from numpy.random import RandomState
batch_size = 8
# 两个输入节点
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
# 回归问题一般只有一个输出节点
y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y-input")
# 定义了一个单层的神经网络前向传播的过程,这里就是简单加权和
w1 = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w1)
# 定义预测多了和预测少了的成本
loss_less = 10
loss_more = 1
#在windows下,下面用这个where替代,因为调用tf.select会报错
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_)*loss_more, (y_-y)*loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
#通过随机数生成一个模拟数据集
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
"""
设置回归的正确值为两个输入的和加上一个随机量,之所以要加上一个随机量是
为了加入不可预测的噪音,否则不同损失函数的意义就不大了,因为不同损失函数
都会在能完全预测正确的时候最低。一般来说,噪音为一个均值为0的小量,所以
这里的噪音设置为-0.05, 0.05的随机数。
"""
Y = [[x1 + x2 + rdm.rand()/10.0-0.05] for (x1, x2) in X]
with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 steps = 5000
 for i in range(steps):
  start = (i * batch_size) % dataset_size
  end = min(start + batch_size, dataset_size)
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
 print(sess.run(w1))

[[ 1.01934695]
[ 1.04280889]

最终结果如上面所示。

因为我们当初生成训练数据的时候,y是x1 + x2,所以回归结果应该是1,1才对。
但是,由于我们加了自己定义的损失函数,所以,倾向于预估多一点。

如果,我们将loss_less和loss_more对调,我们看一下结果:

[[ 0.95525807]
[ 0.9813394 ]]

通过这个例子,我们可以看出,对于相同的神经网络,不同的损失函数会对训练出来的模型产生重要的影响。

引用:以上实例为《Tensorflow实战 Google深度学习框架》中提供。

总结

以上所述是小编给大家介绍的tensorflow 自定义损失函数示例,希望对大家有所帮助!

Python 相关文章推荐
python使用Image处理图片常用技巧分析
Jun 01 Python
Eclipse中Python开发环境搭建简单教程
Mar 23 Python
Python中使用装饰器来优化尾递归的示例
Jun 18 Python
Python利用Beautiful Soup模块搜索内容详解
Mar 29 Python
详解用python写一个抽奖程序
May 10 Python
使用Python做定时任务及时了解互联网动态
May 15 Python
python实现知乎高颜值图片爬取
Aug 12 Python
python实现多线程端口扫描
Aug 31 Python
Django实现简单网页弹出警告代码
Nov 15 Python
Python requests模块cookie实例解析
Apr 14 Python
python3 循环读取excel文件并写入json操作
Jul 14 Python
Python 类,对象,数据分类,函数参数传递详解
Sep 25 Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
TensorFlow实现指数衰减学习率的方法
Feb 05 #Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
You might like
php中用memcached实现页面防刷新功能
2014/08/19 PHP
跨浏览器PHP下载文件名中的中文乱码问题解决方法
2015/03/05 PHP
PHP内存使用情况如何获取
2015/10/10 PHP
详解PHP5.6.30与Apache2.4.x配置
2017/06/02 PHP
JavaScript 异步调用框架 (Part 2 - 用例设计)
2009/08/03 Javascript
Jquery中val()表单取值赋值的实例代码
2013/08/15 Javascript
JS+DIV实现鼠标划过切换层效果的实例代码
2013/11/26 Javascript
jquery datatable后台封装数据示例代码
2014/08/07 Javascript
jQuery超赞的评分插件(8款)
2015/08/20 Javascript
基于jQuery实现二级下拉菜单效果
2016/02/01 Javascript
JS控制伪元素的方法汇总
2016/04/06 Javascript
JS实现根据密码长度显示安全条功能
2017/03/08 Javascript
webpack配置的最佳实践分享
2017/04/21 Javascript
基于Node的React图片上传组件实现实例代码
2017/05/10 Javascript
js简易版购物车功能
2017/06/17 Javascript
node-sass安装失败的原因与解决方法
2017/09/04 Javascript
JavaScript模拟实现封装的三种方式及写法区别
2017/10/27 Javascript
jquery实现点击a链接,跳转之后,该a链接处显示背景色的方法
2018/01/18 jQuery
vuejs实现ready函数加载完之后执行某个函数的方法
2018/08/31 Javascript
微信公众平台 发送模板消息(Java接口开发)
2019/04/17 Javascript
梯度下降法介绍及利用Python实现的方法示例
2017/07/12 Python
pandas数据处理基础之筛选指定行或者指定列的数据
2018/05/03 Python
解决Python一行输出不显示的问题
2018/12/03 Python
在Python中构建增广矩阵的实现方法
2019/07/01 Python
Pytorch 的损失函数Loss function使用详解
2020/01/02 Python
python利用google翻译方法实例(翻译字幕文件)
2020/09/21 Python
Python:__eq__和__str__函数的使用示例
2020/09/26 Python
html+css3实现的登录界面
2020/12/09 HTML / CSS
英文求职信结束语大全
2013/10/26 职场文书
大堂副理的岗位职责范文
2014/02/17 职场文书
环保倡议书400字
2014/05/15 职场文书
2014年教师节活动总结
2014/08/29 职场文书
小学教师师德整改措施
2014/09/29 职场文书
新员工入职欢迎词
2015/01/23 职场文书
2015年会计年终工作总结
2015/05/26 职场文书
我的中国梦主题班会
2015/08/14 职场文书