tensorflow 自定义损失函数示例代码


Posted in Python onFebruary 05, 2020

这个自定义损失函数的背景:(一般回归用的损失函数是MSE, 但要看实际遇到的情况而有所改变)

我们现在想要做一个回归,来预估某个商品的销量,现在我们知道,一件商品的成本是1元,售价是10元。

如果我们用均方差来算的话,如果预估多一个,则损失一块钱,预估少一个,则损失9元钱(少赚的)。

显然,我宁愿预估多了,也不想预估少了。

所以,我们就自己定义一个损失函数,用来分段地看,当yhat 比 y大时怎么样,当yhat比y小时怎么样。

(yhat沿用吴恩达课堂中的叫法)

import tensorflow as tf
from numpy.random import RandomState
batch_size = 8
# 两个输入节点
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
# 回归问题一般只有一个输出节点
y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y-input")
# 定义了一个单层的神经网络前向传播的过程,这里就是简单加权和
w1 = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w1)
# 定义预测多了和预测少了的成本
loss_less = 10
loss_more = 1
#在windows下,下面用这个where替代,因为调用tf.select会报错
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_)*loss_more, (y_-y)*loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
#通过随机数生成一个模拟数据集
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
"""
设置回归的正确值为两个输入的和加上一个随机量,之所以要加上一个随机量是
为了加入不可预测的噪音,否则不同损失函数的意义就不大了,因为不同损失函数
都会在能完全预测正确的时候最低。一般来说,噪音为一个均值为0的小量,所以
这里的噪音设置为-0.05, 0.05的随机数。
"""
Y = [[x1 + x2 + rdm.rand()/10.0-0.05] for (x1, x2) in X]
with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 steps = 5000
 for i in range(steps):
  start = (i * batch_size) % dataset_size
  end = min(start + batch_size, dataset_size)
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
 print(sess.run(w1))

[[ 1.01934695]
[ 1.04280889]

最终结果如上面所示。

因为我们当初生成训练数据的时候,y是x1 + x2,所以回归结果应该是1,1才对。
但是,由于我们加了自己定义的损失函数,所以,倾向于预估多一点。

如果,我们将loss_less和loss_more对调,我们看一下结果:

[[ 0.95525807]
[ 0.9813394 ]]

通过这个例子,我们可以看出,对于相同的神经网络,不同的损失函数会对训练出来的模型产生重要的影响。

引用:以上实例为《Tensorflow实战 Google深度学习框架》中提供。

总结

以上所述是小编给大家介绍的tensorflow 自定义损失函数示例,希望对大家有所帮助!

Python 相关文章推荐
用Python编写一个基于终端的实现翻译的脚本
Apr 24 Python
Python统计文件中去重后uuid个数的方法
Jul 30 Python
python3实现暴力穷举博客园密码
Jun 19 Python
Python基于回溯法子集树模板解决旅行商问题(TSP)实例
Sep 05 Python
python实现反转部分单向链表
Sep 27 Python
Python功能点实现:函数级/代码块级计时器
Jan 02 Python
在Pandas中处理NaN值的方法
Jun 25 Python
python搜索包的路径的实现方法
Jul 19 Python
Python Django框架模板渲染功能示例
Nov 08 Python
Python创建一个元素都为0的列表实例
Nov 28 Python
Pytorch DataLoader 变长数据处理方式
Jan 08 Python
在pytorch中动态调整优化器的学习率方式
Jun 24 Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
TensorFlow实现指数衰减学习率的方法
Feb 05 #Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
You might like
一些花式咖啡的配方
2021/03/03 冲泡冲煮
php简单统计字符串单词数量的方法
2015/06/19 PHP
浅谈php://filter的妙用
2019/03/05 PHP
Laravel框架查询构造器 CURD操作示例
2019/09/04 PHP
Javascript的构造函数和constructor属性
2010/01/09 Javascript
jquery的map与get方法详解
2013/11/04 Javascript
解决自定义$(id)的方法与jquery选择器$冲突的问题
2014/06/14 Javascript
jQuery实现ichat在线客服插件
2014/12/29 Javascript
DOM基础教程之使用DOM控制表单
2015/01/20 Javascript
JavaScript删除数组元素的方法
2015/03/20 Javascript
jQuery 更改checkbox的状态,无效的解决方法
2016/07/22 Javascript
js删除数组元素、清空数组的简单方法(必看)
2016/07/27 Javascript
jQuery插件ajaxFileUpload使用实例解析
2016/10/19 Javascript
angularjs之$timeout指令详解
2017/06/13 Javascript
关于express与koa的使用对比详解
2018/01/25 Javascript
vue-router之nuxt动态路由设置的两种方法小结
2018/09/26 Javascript
layui table 复选框跳页后再回来保持原来选中的状态示例
2019/10/26 Javascript
关于小程序优化的一些建议(小结)
2020/12/10 Javascript
[01:13]2014DOTA2西雅图邀请赛 舌尖上的TI4
2014/07/08 DOTA
浅析python 中__name__ = '__main__' 的作用
2014/07/05 Python
Python通过调用mysql存储过程实现更新数据功能示例
2018/04/03 Python
详解Python解决抓取内容乱码问题(decode和encode解码)
2019/03/29 Python
python打包成so文件过程解析
2019/09/28 Python
40行Python代码实现天气预报和每日鸡汤推送功能
2020/02/27 Python
pycharm工具连接mysql数据库失败问题
2020/04/01 Python
Pytorch 卷积中的 Input Shape用法
2020/06/29 Python
Python中Selenium模块的使用详解
2020/10/09 Python
详解Django中异步任务之django-celery
2020/11/05 Python
中国双语服务优势的在线购票及活动平台:247tickets
2018/10/26 全球购物
在子网210.27.48.21/30种有多少个可用地址?分别是什么?
2014/07/27 面试题
新闻学专业应届生求职信
2013/11/08 职场文书
魅力教师事迹材料
2014/01/10 职场文书
绿色小区申报材料
2014/08/22 职场文书
乡镇党建工作汇报材料
2014/10/27 职场文书
Python趣味挑战之用pygame实现简单的金币旋转效果
2021/05/31 Python
sql server偶发出现死锁的解决方法
2022/04/10 SQL Server