tensorflow 自定义损失函数示例代码


Posted in Python onFebruary 05, 2020

这个自定义损失函数的背景:(一般回归用的损失函数是MSE, 但要看实际遇到的情况而有所改变)

我们现在想要做一个回归,来预估某个商品的销量,现在我们知道,一件商品的成本是1元,售价是10元。

如果我们用均方差来算的话,如果预估多一个,则损失一块钱,预估少一个,则损失9元钱(少赚的)。

显然,我宁愿预估多了,也不想预估少了。

所以,我们就自己定义一个损失函数,用来分段地看,当yhat 比 y大时怎么样,当yhat比y小时怎么样。

(yhat沿用吴恩达课堂中的叫法)

import tensorflow as tf
from numpy.random import RandomState
batch_size = 8
# 两个输入节点
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
# 回归问题一般只有一个输出节点
y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y-input")
# 定义了一个单层的神经网络前向传播的过程,这里就是简单加权和
w1 = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w1)
# 定义预测多了和预测少了的成本
loss_less = 10
loss_more = 1
#在windows下,下面用这个where替代,因为调用tf.select会报错
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_)*loss_more, (y_-y)*loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
#通过随机数生成一个模拟数据集
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
"""
设置回归的正确值为两个输入的和加上一个随机量,之所以要加上一个随机量是
为了加入不可预测的噪音,否则不同损失函数的意义就不大了,因为不同损失函数
都会在能完全预测正确的时候最低。一般来说,噪音为一个均值为0的小量,所以
这里的噪音设置为-0.05, 0.05的随机数。
"""
Y = [[x1 + x2 + rdm.rand()/10.0-0.05] for (x1, x2) in X]
with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 steps = 5000
 for i in range(steps):
  start = (i * batch_size) % dataset_size
  end = min(start + batch_size, dataset_size)
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
 print(sess.run(w1))

[[ 1.01934695]
[ 1.04280889]

最终结果如上面所示。

因为我们当初生成训练数据的时候,y是x1 + x2,所以回归结果应该是1,1才对。
但是,由于我们加了自己定义的损失函数,所以,倾向于预估多一点。

如果,我们将loss_less和loss_more对调,我们看一下结果:

[[ 0.95525807]
[ 0.9813394 ]]

通过这个例子,我们可以看出,对于相同的神经网络,不同的损失函数会对训练出来的模型产生重要的影响。

引用:以上实例为《Tensorflow实战 Google深度学习框架》中提供。

总结

以上所述是小编给大家介绍的tensorflow 自定义损失函数示例,希望对大家有所帮助!

Python 相关文章推荐
python回调函数用法实例分析
May 09 Python
python学习之第三方包安装方法(两种方法)
Jul 30 Python
python爬取网页内容转换为PDF文件
Jul 28 Python
Sanic框架安装与简单入门示例
Jul 16 Python
Python的argparse库使用详解
Oct 09 Python
python利用ffmpeg进行录制屏幕的方法
Jan 10 Python
python设定并获取socket超时时间的方法
Jan 12 Python
selenium+python自动化测试之环境搭建
Jan 23 Python
python实现将视频按帧读取到自定义目录
Dec 10 Python
python获取系统内存占用信息的实例方法
Jul 17 Python
python实现MySQL指定表增量同步数据到clickhouse的脚本
Feb 26 Python
在 Golang 中实现 Cache::remember 方法详解
Mar 30 Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
TensorFlow实现指数衰减学习率的方法
Feb 05 #Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
You might like
初品cakephp 入门基础
2012/02/16 PHP
php图片加水印原理(超简单的实例代码)
2013/01/18 PHP
php实现统计邮件大小的方法
2013/08/06 PHP
Laravel框架实现发送短信验证功能代码
2016/06/06 PHP
mac系统下为 php 添加 pcntl 扩展
2016/08/28 PHP
thinkPHP5项目中实现QQ第三方登录功能
2017/10/20 PHP
php爬取天猫和淘宝商品数据
2018/02/23 PHP
JavaScript Prototype对象
2009/01/07 Javascript
js中设置元素class的三种方法小结
2011/08/28 Javascript
javaScript 页面自动加载事件详解
2014/02/10 Javascript
js实现点击左右按钮轮播图片效果实例
2015/01/29 Javascript
jQuery获取上传文件的名称的正则表达式
2015/05/21 Javascript
ion content 滚动到底部会遮住一部分视图的快速解决方法
2016/09/06 Javascript
jQuery实现表格文本框淡入更改值后淡出效果
2016/09/27 Javascript
jquery实现左右轮播图效果
2017/09/28 jQuery
Vue2.0 实现移动端图片上传功能
2018/05/30 Javascript
5分钟快速掌握JS中var、let和const的异同
2018/09/19 Javascript
python字典多条件排序方法实例
2014/06/30 Python
Python迭代器和生成器介绍
2015/03/06 Python
Python中字典的setdefault()方法教程
2017/02/07 Python
Python 中urls.py:URL dispatcher(路由配置文件)详解
2017/03/24 Python
使用python进行拆分大文件的方法
2018/12/10 Python
Python JSON格式数据的提取和保存的实现
2019/03/22 Python
python 并发编程 多路复用IO模型详解
2019/08/20 Python
Pytorch 计算误判率,计算准确率,计算召回率的例子
2020/01/18 Python
HTML5重塑Web世界它将如何改变互联网
2012/12/17 HTML / CSS
介绍一下except的用法和作用
2015/01/22 面试题
上班迟到检讨书
2014/01/10 职场文书
《悯农》教学反思
2014/04/28 职场文书
企业安全标语
2014/06/07 职场文书
道路运输企业安全生产责任书
2014/07/28 职场文书
2014年团支书工作总结
2014/11/14 职场文书
中学生社区服务活动报告
2015/02/05 职场文书
高老头读书笔记
2015/06/30 职场文书
党章党规党纪学习心得体会
2016/01/14 职场文书
详解redis分布式锁的这些坑
2021/05/19 Redis