tensorflow 自定义损失函数示例代码


Posted in Python onFebruary 05, 2020

这个自定义损失函数的背景:(一般回归用的损失函数是MSE, 但要看实际遇到的情况而有所改变)

我们现在想要做一个回归,来预估某个商品的销量,现在我们知道,一件商品的成本是1元,售价是10元。

如果我们用均方差来算的话,如果预估多一个,则损失一块钱,预估少一个,则损失9元钱(少赚的)。

显然,我宁愿预估多了,也不想预估少了。

所以,我们就自己定义一个损失函数,用来分段地看,当yhat 比 y大时怎么样,当yhat比y小时怎么样。

(yhat沿用吴恩达课堂中的叫法)

import tensorflow as tf
from numpy.random import RandomState
batch_size = 8
# 两个输入节点
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
# 回归问题一般只有一个输出节点
y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y-input")
# 定义了一个单层的神经网络前向传播的过程,这里就是简单加权和
w1 = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w1)
# 定义预测多了和预测少了的成本
loss_less = 10
loss_more = 1
#在windows下,下面用这个where替代,因为调用tf.select会报错
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_)*loss_more, (y_-y)*loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
#通过随机数生成一个模拟数据集
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
"""
设置回归的正确值为两个输入的和加上一个随机量,之所以要加上一个随机量是
为了加入不可预测的噪音,否则不同损失函数的意义就不大了,因为不同损失函数
都会在能完全预测正确的时候最低。一般来说,噪音为一个均值为0的小量,所以
这里的噪音设置为-0.05, 0.05的随机数。
"""
Y = [[x1 + x2 + rdm.rand()/10.0-0.05] for (x1, x2) in X]
with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 steps = 5000
 for i in range(steps):
  start = (i * batch_size) % dataset_size
  end = min(start + batch_size, dataset_size)
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
 print(sess.run(w1))

[[ 1.01934695]
[ 1.04280889]

最终结果如上面所示。

因为我们当初生成训练数据的时候,y是x1 + x2,所以回归结果应该是1,1才对。
但是,由于我们加了自己定义的损失函数,所以,倾向于预估多一点。

如果,我们将loss_less和loss_more对调,我们看一下结果:

[[ 0.95525807]
[ 0.9813394 ]]

通过这个例子,我们可以看出,对于相同的神经网络,不同的损失函数会对训练出来的模型产生重要的影响。

引用:以上实例为《Tensorflow实战 Google深度学习框架》中提供。

总结

以上所述是小编给大家介绍的tensorflow 自定义损失函数示例,希望对大家有所帮助!

Python 相关文章推荐
Python设计模式之中介模式简单示例
Jan 09 Python
Python实现的生产者、消费者问题完整实例
May 30 Python
基于python实现聊天室程序
Jul 27 Python
Selenium控制浏览器常见操作示例
Aug 13 Python
详解Numpy数组转置的三种方法T、transpose、swapaxes
May 27 Python
pandas通过字典生成dataframe的方法步骤
Jul 23 Python
python基于socket实现的UDP及TCP通讯功能示例
Nov 01 Python
python求质数列表的例子
Nov 24 Python
python3将变量写入SQL语句的实现方式
Mar 02 Python
python将音频进行变速的操作方法
Apr 08 Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 Python
python文件目录操作之os模块
May 08 Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
TensorFlow实现指数衰减学习率的方法
Feb 05 #Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
You might like
杏林同学录(六)
2006/10/09 PHP
台湾中原大学php教程孙仲岳主讲
2008/01/07 PHP
PHP 文件上传功能实现代码
2009/06/24 PHP
php设置允许大文件上传示例代码
2014/03/10 PHP
解决nginx不支持thinkphp中pathinfo的问题
2015/07/21 PHP
分享php邮件管理器源码
2016/01/06 PHP
php版微信小店API二次开发及使用示例
2016/11/12 PHP
php实现当前页面点击下载文件的实例代码
2016/11/16 PHP
php连接微软MSSQL(sql server)完全攻略
2016/11/27 PHP
JavaScript使用技巧精萃[代码非常实用]
2008/11/21 Javascript
JavaScript控制图片加载完成后调用回调函数的方法
2015/03/20 Javascript
基于js实现二级下拉联动
2016/12/17 Javascript
Canvas + JavaScript 制作图片粒子效果
2017/02/08 Javascript
AngularJS常见过滤器用法实例总结
2017/07/06 Javascript
在iframe中使bootstrap的模态框在父页面弹出问题
2017/08/07 Javascript
vue 标签属性数据绑定和拼接的实现方法
2018/05/17 Javascript
详解基于vue的服务端渲染框架NUXT
2018/06/20 Javascript
解决Vue中引入swiper,在数据渲染的时候,发生不滑动的问题
2018/09/27 Javascript
vue中的mescroll搜索运用及各种填坑处理
2019/10/30 Javascript
微信小程序获取当前时间及星期几的实例代码
2020/09/20 Javascript
Python 字典(Dictionary)操作详解
2014/03/11 Python
Python3读取UTF-8文件及统计文件行数的方法
2015/05/22 Python
Django的URLconf中使用缺省视图参数的方法
2015/07/18 Python
Python的pycurl包用法简介
2015/11/13 Python
浅谈Python 中整型对象的存储问题
2016/05/16 Python
Python中几种导入模块的方式总结
2017/04/27 Python
Python文件的读写和异常代码示例
2017/10/31 Python
python中的闭包函数
2018/02/09 Python
遗传算法python版
2018/03/19 Python
使用Python监控文件内容变化代码实例
2018/06/04 Python
Python 中 function(#) (X)格式 和 (#)在Python3.*中的注意事项
2018/11/30 Python
jupyter使用自动补全和切换默认浏览器的方法
2020/11/18 Python
国外软件测试工程师面试题
2016/12/09 面试题
公务员的自我鉴定
2013/10/26 职场文书
《灰椋鸟》教学反思
2014/04/27 职场文书
市场推广策划方案
2014/06/02 职场文书