tensorflow 自定义损失函数示例代码


Posted in Python onFebruary 05, 2020

这个自定义损失函数的背景:(一般回归用的损失函数是MSE, 但要看实际遇到的情况而有所改变)

我们现在想要做一个回归,来预估某个商品的销量,现在我们知道,一件商品的成本是1元,售价是10元。

如果我们用均方差来算的话,如果预估多一个,则损失一块钱,预估少一个,则损失9元钱(少赚的)。

显然,我宁愿预估多了,也不想预估少了。

所以,我们就自己定义一个损失函数,用来分段地看,当yhat 比 y大时怎么样,当yhat比y小时怎么样。

(yhat沿用吴恩达课堂中的叫法)

import tensorflow as tf
from numpy.random import RandomState
batch_size = 8
# 两个输入节点
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
# 回归问题一般只有一个输出节点
y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y-input")
# 定义了一个单层的神经网络前向传播的过程,这里就是简单加权和
w1 = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w1)
# 定义预测多了和预测少了的成本
loss_less = 10
loss_more = 1
#在windows下,下面用这个where替代,因为调用tf.select会报错
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_)*loss_more, (y_-y)*loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
#通过随机数生成一个模拟数据集
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
"""
设置回归的正确值为两个输入的和加上一个随机量,之所以要加上一个随机量是
为了加入不可预测的噪音,否则不同损失函数的意义就不大了,因为不同损失函数
都会在能完全预测正确的时候最低。一般来说,噪音为一个均值为0的小量,所以
这里的噪音设置为-0.05, 0.05的随机数。
"""
Y = [[x1 + x2 + rdm.rand()/10.0-0.05] for (x1, x2) in X]
with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 steps = 5000
 for i in range(steps):
  start = (i * batch_size) % dataset_size
  end = min(start + batch_size, dataset_size)
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
 print(sess.run(w1))

[[ 1.01934695]
[ 1.04280889]

最终结果如上面所示。

因为我们当初生成训练数据的时候,y是x1 + x2,所以回归结果应该是1,1才对。
但是,由于我们加了自己定义的损失函数,所以,倾向于预估多一点。

如果,我们将loss_less和loss_more对调,我们看一下结果:

[[ 0.95525807]
[ 0.9813394 ]]

通过这个例子,我们可以看出,对于相同的神经网络,不同的损失函数会对训练出来的模型产生重要的影响。

引用:以上实例为《Tensorflow实战 Google深度学习框架》中提供。

总结

以上所述是小编给大家介绍的tensorflow 自定义损失函数示例,希望对大家有所帮助!

Python 相关文章推荐
Python生成验证码实例
Aug 21 Python
10种检测Python程序运行时间、CPU和内存占用的方法
Apr 01 Python
用Python编写简单的微博爬虫
Mar 04 Python
python中函数传参详解
Jul 03 Python
深入解析Python的Tornado框架中内置的模板引擎
Jul 11 Python
浅谈终端直接执行py文件,不需要python命令
Jan 23 Python
python中for循环输出列表索引与对应的值方法
Nov 07 Python
python自定义时钟类、定时任务类
Feb 22 Python
Django CBV与FBV原理及实例详解
Aug 12 Python
Python单链表原理与实现方法详解
Feb 22 Python
Python用户自定义异常的实现
Dec 25 Python
Python turtle编写简单的球类小游戏
Mar 31 Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
TensorFlow实现指数衰减学习率的方法
Feb 05 #Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
You might like
PHP封装的字符串加密解密函数
2015/12/18 PHP
php、mysql查询当天,查询本周,查询本月的数据实例(字段是时间戳)
2017/02/04 PHP
php实现微信企业号支付个人的方法详解
2017/07/26 PHP
对textarea框的代码调试,而且功能上使用非常方便,酷
2006/06/30 Javascript
Javascript String对象扩展HTML编码和解码的方法
2009/06/02 Javascript
在JavaScript中获取请求的URL参数
2010/12/22 Javascript
javascript嵌套函数和在函数内调用外部函数的区别分析
2016/01/31 Javascript
Javascript6中字符串的四个新用法分享
2016/09/11 Javascript
浅谈JS之iframe中的窗口
2016/09/13 Javascript
JS中事件冒泡和事件捕获介绍
2016/12/13 Javascript
nodejs入门教程一:概念与用法简介
2017/04/24 NodeJs
vue.js实现数据动态响应 Vue.set的简单应用
2017/06/15 Javascript
浅谈JS中的常用选择器及属性、方法的调用
2017/07/28 Javascript
vue2.0使用swiper组件实现轮播的示例代码
2018/03/03 Javascript
vue实现一个炫酷的日历组件
2018/10/08 Javascript
解决vue+ element ui 表单验证有值但验证失败问题
2020/01/16 Javascript
element-ui 实现响应式导航栏的示例代码
2020/05/08 Javascript
python中List的sort方法指南
2014/09/01 Python
python 3利用BeautifulSoup抓取div标签的方法示例
2017/05/28 Python
Python爬虫DNS解析缓存方法实例分析
2017/06/02 Python
开源软件包和环境管理系统Anaconda的安装使用
2017/09/04 Python
Python网络编程使用select实现socket全双工异步通信功能示例
2018/04/09 Python
python实现简单五子棋游戏
2019/06/18 Python
Python进度条的制作代码实例
2019/08/31 Python
PyTorch中torch.tensor与torch.Tensor的区别详解
2020/05/18 Python
Python 利用Entrez库筛选下载PubMed文献摘要的示例
2020/11/24 Python
详解Css3新特性应用之过渡与动画
2017/01/10 HTML / CSS
CSS3的 fit-content实现水平居中
2017/09/07 HTML / CSS
全球酒店预订网站:Hotels.com
2016/08/10 全球购物
当当网官方旗舰店:中国图书销售夺金品牌
2018/04/02 全球购物
植村秀美国官网:Shu Uemura美国
2019/03/19 全球购物
cf收人广告词大全
2014/03/14 职场文书
珍惜资源的建议书
2014/08/26 职场文书
党风廉政建设心得体会
2019/05/21 职场文书
Redis实战高并发之扣减库存项目
2022/04/14 Redis
CSS实现背景图片全屏铺满自适应的3种方式
2022/07/07 HTML / CSS