tensorflow 自定义损失函数示例代码


Posted in Python onFebruary 05, 2020

这个自定义损失函数的背景:(一般回归用的损失函数是MSE, 但要看实际遇到的情况而有所改变)

我们现在想要做一个回归,来预估某个商品的销量,现在我们知道,一件商品的成本是1元,售价是10元。

如果我们用均方差来算的话,如果预估多一个,则损失一块钱,预估少一个,则损失9元钱(少赚的)。

显然,我宁愿预估多了,也不想预估少了。

所以,我们就自己定义一个损失函数,用来分段地看,当yhat 比 y大时怎么样,当yhat比y小时怎么样。

(yhat沿用吴恩达课堂中的叫法)

import tensorflow as tf
from numpy.random import RandomState
batch_size = 8
# 两个输入节点
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
# 回归问题一般只有一个输出节点
y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y-input")
# 定义了一个单层的神经网络前向传播的过程,这里就是简单加权和
w1 = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w1)
# 定义预测多了和预测少了的成本
loss_less = 10
loss_more = 1
#在windows下,下面用这个where替代,因为调用tf.select会报错
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_)*loss_more, (y_-y)*loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
#通过随机数生成一个模拟数据集
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
"""
设置回归的正确值为两个输入的和加上一个随机量,之所以要加上一个随机量是
为了加入不可预测的噪音,否则不同损失函数的意义就不大了,因为不同损失函数
都会在能完全预测正确的时候最低。一般来说,噪音为一个均值为0的小量,所以
这里的噪音设置为-0.05, 0.05的随机数。
"""
Y = [[x1 + x2 + rdm.rand()/10.0-0.05] for (x1, x2) in X]
with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 steps = 5000
 for i in range(steps):
  start = (i * batch_size) % dataset_size
  end = min(start + batch_size, dataset_size)
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
 print(sess.run(w1))

[[ 1.01934695]
[ 1.04280889]

最终结果如上面所示。

因为我们当初生成训练数据的时候,y是x1 + x2,所以回归结果应该是1,1才对。
但是,由于我们加了自己定义的损失函数,所以,倾向于预估多一点。

如果,我们将loss_less和loss_more对调,我们看一下结果:

[[ 0.95525807]
[ 0.9813394 ]]

通过这个例子,我们可以看出,对于相同的神经网络,不同的损失函数会对训练出来的模型产生重要的影响。

引用:以上实例为《Tensorflow实战 Google深度学习框架》中提供。

总结

以上所述是小编给大家介绍的tensorflow 自定义损失函数示例,希望对大家有所帮助!

Python 相关文章推荐
python开发之thread线程基础实例入门
Nov 11 Python
python下调用pytesseract识别某网站验证码的实现方法
Jun 06 Python
django反向解析URL和URL命名空间的方法
Jun 05 Python
python进阶之多线程对同一个全局变量的处理方法
Nov 09 Python
深入了解Python在HDA中的应用
Sep 05 Python
Python异步编程之协程任务的调度操作实例分析
Feb 01 Python
python GUI库图形界面开发之PyQt5窗口背景与不规则窗口实例
Feb 25 Python
Python关键字及可变参数*args,**kw原理解析
Apr 04 Python
Python使用20行代码实现微信聊天机器人
Jun 05 Python
利用Vscode进行Python开发环境配置的步骤
Jun 22 Python
python实现数字炸弹游戏
Jul 17 Python
Python实现列表拼接和去重的三种方式
Jul 02 Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
TensorFlow实现指数衰减学习率的方法
Feb 05 #Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
You might like
PHP获取url的函数代码
2011/08/02 PHP
php数组函数序列之array_flip() 将数组键名与值对调
2011/11/07 PHP
Linux环境下搭建php开发环境的操作步骤
2013/06/17 PHP
Yii2.0 模态弹出框+ajax提交表单
2016/05/22 PHP
php利用ffmpeg提取视频中音频与视频画面的方法详解
2017/06/07 PHP
PHP微信企业号开发之回调模式开启与用法示例
2017/11/25 PHP
jQuery判断元素是否是隐藏的代码
2011/04/24 Javascript
JavaScript识别网页关键字并进行描红的方法
2015/11/09 Javascript
实例解析js中try、catch、finally的执行规则
2017/02/24 Javascript
javascript实现获取一个日期段内每天不同的价格(计算入住总价格)
2018/02/05 Javascript
vue 实现复制内容到粘贴板clipboard的方法
2018/03/17 Javascript
Vue刷新修改页面中数据的方法
2018/09/16 Javascript
nodejs实现范围请求的实现代码
2018/10/12 NodeJs
微信小程序学习笔记之登录API与获取用户信息操作图文详解
2019/03/29 Javascript
深入学习Vue nextTick的用法及原理
2019/10/08 Javascript
npx create-react-app xxx创建项目报错的解决办法
2020/02/17 Javascript
python互斥锁、加锁、同步机制、异步通信知识总结
2018/02/11 Python
Python lambda函数基本用法实例分析
2018/03/16 Python
Python打开文件,将list、numpy数组内容写入txt文件中的方法
2018/10/26 Python
python 实现得到当前时间偏移day天后的日期方法
2018/12/31 Python
python3去掉string中的标点符号方法
2019/01/22 Python
Python实现二叉搜索树BST的方法示例
2019/07/30 Python
基于python实现matlab filter函数过程详解
2020/06/08 Python
python爬虫破解字体加密案例详解
2021/03/02 Python
蔻驰意大利官网:COACH意大利
2019/01/16 全球购物
SQL里面IN比较快还是EXISTS比较快
2012/07/19 面试题
英语专业毕业生自我鉴定
2013/11/09 职场文书
公司培训心得体会
2014/01/03 职场文书
教师绩效工资方案
2014/02/01 职场文书
教师节学生演讲稿
2014/09/03 职场文书
2014年小学重阳节活动策划方案
2014/09/16 职场文书
戒毒悔改检讨书
2014/09/21 职场文书
小学中等生评语
2014/12/29 职场文书
自信主题班会
2015/08/14 职场文书
SpringBoot整合Minio文件存储
2022/04/03 Java/Android
vue3种table表格选项个数的控制方法
2022/04/14 Vue.js