tensorflow 自定义损失函数示例代码


Posted in Python onFebruary 05, 2020

这个自定义损失函数的背景:(一般回归用的损失函数是MSE, 但要看实际遇到的情况而有所改变)

我们现在想要做一个回归,来预估某个商品的销量,现在我们知道,一件商品的成本是1元,售价是10元。

如果我们用均方差来算的话,如果预估多一个,则损失一块钱,预估少一个,则损失9元钱(少赚的)。

显然,我宁愿预估多了,也不想预估少了。

所以,我们就自己定义一个损失函数,用来分段地看,当yhat 比 y大时怎么样,当yhat比y小时怎么样。

(yhat沿用吴恩达课堂中的叫法)

import tensorflow as tf
from numpy.random import RandomState
batch_size = 8
# 两个输入节点
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
# 回归问题一般只有一个输出节点
y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y-input")
# 定义了一个单层的神经网络前向传播的过程,这里就是简单加权和
w1 = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w1)
# 定义预测多了和预测少了的成本
loss_less = 10
loss_more = 1
#在windows下,下面用这个where替代,因为调用tf.select会报错
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_)*loss_more, (y_-y)*loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
#通过随机数生成一个模拟数据集
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
"""
设置回归的正确值为两个输入的和加上一个随机量,之所以要加上一个随机量是
为了加入不可预测的噪音,否则不同损失函数的意义就不大了,因为不同损失函数
都会在能完全预测正确的时候最低。一般来说,噪音为一个均值为0的小量,所以
这里的噪音设置为-0.05, 0.05的随机数。
"""
Y = [[x1 + x2 + rdm.rand()/10.0-0.05] for (x1, x2) in X]
with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 steps = 5000
 for i in range(steps):
  start = (i * batch_size) % dataset_size
  end = min(start + batch_size, dataset_size)
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
 print(sess.run(w1))

[[ 1.01934695]
[ 1.04280889]

最终结果如上面所示。

因为我们当初生成训练数据的时候,y是x1 + x2,所以回归结果应该是1,1才对。
但是,由于我们加了自己定义的损失函数,所以,倾向于预估多一点。

如果,我们将loss_less和loss_more对调,我们看一下结果:

[[ 0.95525807]
[ 0.9813394 ]]

通过这个例子,我们可以看出,对于相同的神经网络,不同的损失函数会对训练出来的模型产生重要的影响。

引用:以上实例为《Tensorflow实战 Google深度学习框架》中提供。

总结

以上所述是小编给大家介绍的tensorflow 自定义损失函数示例,希望对大家有所帮助!

Python 相关文章推荐
python list语法学习(带例子)
Nov 01 Python
python基础入门详解(文件输入/输出 内建类型 字典操作使用方法)
Dec 08 Python
Python的Bottle框架的一些使用技巧介绍
Apr 08 Python
python字典排序实例详解
May 20 Python
Python+树莓派+YOLO打造一款人工智能照相机
Jan 02 Python
Python实现的括号匹配判断功能示例
Aug 25 Python
TensorFlow内存管理bfc算法实例
Feb 03 Python
Python字符编码转码之GBK,UTF8互转
Feb 09 Python
解决keras使用cov1D函数的输入问题
Jun 29 Python
Python环境搭建过程从安装到Hello World
Feb 05 Python
python中if和elif的区别介绍
Nov 07 Python
python在package下继续嵌套一个package
Apr 14 Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
TensorFlow实现指数衰减学习率的方法
Feb 05 #Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
You might like
PHP数组内存耗用太多问题的解决方法
2010/04/05 PHP
编译PHP报错configure error Cannot find libmysqlclient under usr的解决方法
2014/06/27 PHP
PHP对象实例化单例方法
2017/01/19 PHP
javascript 设置文本框中焦点的位置
2009/11/20 Javascript
xss文件页面内容读取(解决)
2010/11/28 Javascript
Jquery attr("checked") 返回checked或undefined 获取选中失效
2013/10/10 Javascript
javascript创建和存储cookie示例
2014/01/07 Javascript
如何防止回车(enter)键提交表单
2014/05/11 Javascript
2014 年最热门的21款JavaScript框架推荐
2014/12/25 Javascript
javascript+html5实现绘制圆环的方法
2015/07/28 Javascript
js实现简洁的滑动门菜单(选项卡)效果代码
2015/09/04 Javascript
jquery实现图片预加载
2015/12/25 Javascript
textarea 在浏览器中固定大小和禁止拖动的实现方法
2016/12/03 Javascript
Bootstrap表格使用方法详解
2017/02/17 Javascript
javascript回调函数详解
2018/02/06 Javascript
vue+elementUI实现图片上传功能
2019/08/20 Javascript
如何在Vue.JS中使用图标组件
2020/08/04 Javascript
python文件操作相关知识点总结整理
2016/02/22 Python
JSON文件及Python对JSON文件的读写操作
2018/10/07 Python
python版DDOS攻击脚本
2019/06/12 Python
使用Tkinter制作信息提示框
2020/02/18 Python
Python3爬虫中关于中文分词的详解
2020/07/29 Python
Python用Jira库来操作Jira
2020/12/28 Python
前端面试必备之CSS3的新特性
2017/09/05 HTML / CSS
详解使用HTML5 Canvas创建动态粒子网格动画
2016/12/14 HTML / CSS
美国摩托车头盔、零件、齿轮及配件商店:Cycle Gear
2019/06/12 全球购物
教师学习培训邀请函
2014/02/04 职场文书
餐厅经理岗位职责范本
2014/02/17 职场文书
办公室主任岗位承诺书
2014/05/29 职场文书
幼儿园大班教师个人工作总结
2015/02/05 职场文书
人口与计划生育责任书
2015/05/09 职场文书
学雷锋活动简报
2015/07/20 职场文书
经销商会议开幕词
2016/03/04 职场文书
python数据库批量插入数据的实现(executemany的使用)
2021/04/30 Python
MySQL常见优化方案汇总
2022/01/18 MySQL
Flutter集成高德地图并添加自定义Maker的实践
2022/04/07 Java/Android