TensorFlow自定义损失函数来预测商品销售量


Posted in Python onFebruary 05, 2020

在预测商品销量时,如果预测多了(预测值比真实销量大),商家损失的是生产商品的成本;而如果预测少了(预测值比真实销量小),损失的则是商品的利润。因为一般商品的成本和商品的利润不会严格相等,比如如果一个商品的成本是1元,但是利润是10元,那么少预测一个就少挣10元;而多预测一个才少挣1元,所以如果神经网络模型最小化的是均方误差损失函数,那么很有可能此模型就无法最大化预期的销售利润。

为了最大化预期利润,需要将损失函数和利润直接联系起来,需要注意的是,损失函数定义的是损失,所以要将利润最大化,定义的损失函数应该刻画成本或者代价,下面的公式给出了一个当预测多于真实值和预测少于真实值时有不同损失系数的损失函数:

TensorFlow自定义损失函数来预测商品销售量

其中,yi为一个batch中第i个数据的真实值,yi'为神经网络得到的预测值,a和b是常量,比如在上面介绍的销量预测问题中,a就等于10 (真实值多于预测值的代价),而b等于1 (真实值少于预测值的代价)。

通过对这个自定义损失函数的优化,模型提供的预测值更有可能最大化收益,在TensorFlow中,可以通过以下代码来实现这个损失函数:

loss = tf.reduce_sum(tf.where(tf.greater(y_, y), (y_ - y) * loss_less, (y - y_) * loss_more))

①tf.greater函数的输入是两个张量,此函数会比较这两个输入张量中每一个元素的大小,并返回比较结果,当tf.greater的输入张量维度不一样时,TensorFlow会进行类似NumPy广播操作(broadcasting)的处理;

②tf.where函数有三个参数,第一个为选择条件,当选择条件为True时,tf.where函数会选择第二个参数中的值,否则使用第三个参数中的值,需要注意的是,tf.where函数的判断和选择都是在元素级别进行的。

接下来使用一段TensorFlow代码展示这两个函数的使用:

import tensorflow as tf
v1 = tf.constant([1.0, 2.0, 3.0, 4.0])
v2 = tf.constant([4.0, 3.0, 2.0, 1.0])
with tf.Session() as sess:
 print(sess.run(tf.greater(v1, v2)))
 print(sess.run(tf.where(tf.greater(v1, v2), v1, v2)))
 '''输出结果为:
 [False False True True]
 [4. 3. 3. 4.]'''

在了解如何使用这两个函数之后,我们来看一看刚才的预测商品销售量的实例如何通过具体的TensorFlow代码实现:

import tensorflow as tf
from numpy.random import RandomState

#声明wl、W2两个变量,通过seed参数设定了随机种子,这样可以保证每次运行得到的结果是一样的
w = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))

x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y-input")

#定义神经网络结构
y = tf.matmul(x, w)

#定义真实值与预测值之间的交叉熵损失函数,来刻画真实值与预测值之间的差距
loss_less = 10
loss_more = 1
loss = tf.reduce_sum(tf.where(tf.greater(y_, y), (y_ - y) * loss_less, (y - y_) * loss_more))

#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

#设置随机数种子
rdm = RandomState(seed=1)
#设置随机数据集大小
dataset_size = 128
X = rdm.rand(dataset_size, 2)
'''设置回归的正确值为两个输入的和加上一个随机量。
之所以要加上一个随机量是为了加入不可预测的噪音,否则不同损失函数的意义就不大了,因为不同损失函数都会在能完全预测正确的时候最低。
一般来说噪音为一个均值为0的小量,所以这里的噪音设置为-0.05——0.05的随机数。'''
Y = [[x1 + x2 + rdm.rand()/10.0 -0.05] for x1,x2 in X]

#创建会话
with tf.Session() as sess:
 #初始化变量
 init_op = tf.global_variables_initializer()
 sess.run(init_op)
 
 print(sess.run(w))
 
 #设置batch训练数据的大小
 batch_size = 8
 #设置训练得轮数
 STEPS = 5000
 for i in range(STEPS):
  #每次选取batch_size个样本进行训练
  start = (i * batch_size) % dataset_size
  end = min(start + batch_size, dataset_size)

  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
 
 print(sess.run(w))
 '''输出结果为:
 [[-0.8113182]
 [ 1.4845988]]
 [[1.019347 ]
 [1.0428089]]'''

可以看到参数w优化后,预测函数为1.019347 * x1 + 1.0428089 * x2,显然是大于实际的预测函数x1 + x2的,这是因为我们的损失函数中指定预测少了的损失更大(loss_less > loss_more),所以模型会偏向于预测多一点。

如果我们更换代码,改为:

loss_less = 1
loss_more = 10

那么我们的结果就会变为:

[[-0.8113182]
 [ 1.4845988]]
[[0.95561105]
 [0.98101896]]

预测函数变为了0.95561105 * x1 + 0.98101896 * x2,可以看到这时候模型就会偏向于预测少一点。

因此,我们可以得出结论:对于相同的神经网络,不同的损失函数会对训练得到的模型产生不同效果。

总结

以上所述是小编给大家介绍的TensorFlow自定义损失函数来预测商品销售量,希望对大家有所帮助!

Python 相关文章推荐
python实现将文本转换成语音的方法
May 28 Python
Python入门之后再看点什么好?
Mar 05 Python
解决Django后台ManyToManyField显示成Object的问题
Aug 09 Python
Django对models里的objects的使用详解
Aug 17 Python
python爬虫-模拟微博登录功能
Sep 12 Python
python 实现turtle画图并导出图片格式的文件
Dec 07 Python
Python程序控制语句用法实例分析
Jan 14 Python
pytorch的batch normalize使用详解
Jan 15 Python
Win10下安装并使用tensorflow-gpu1.8.0+python3.6全过程分析(显卡MX250+CUDA9.0+cudnn)
Feb 17 Python
Python 读取xml数据,cv2裁剪图片实例
Mar 10 Python
如何将PySpark导入Python的放实现(2种)
Apr 26 Python
django rest framework 过滤时间操作
Jul 12 Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
TensorFlow实现指数衰减学习率的方法
Feb 05 #Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
You might like
PHP的Yii框架中Model模型的学习教程
2016/03/29 PHP
ThinkPHP开发--使用七牛云储存
2017/09/14 PHP
js新闻滚动 js如何实现新闻滚动效果
2013/01/07 Javascript
用jQuery获取IE9下拉框默认值问题探讨
2013/07/22 Javascript
jquery利用ajax调用后台方法实例
2013/08/23 Javascript
javascript unicode与GBK2312(中文)编码转换方法
2013/11/14 Javascript
js二维数组定义和初始化的三种方法总结
2014/03/03 Javascript
简介JavaScript中search()方法的使用
2015/06/06 Javascript
JavaScript获得url查询参数的方法
2015/07/02 Javascript
基于javascript显示当前时间以及倒计时功能
2016/03/18 Javascript
jquery Banner轮播选项卡
2016/12/26 Javascript
javascript构造函数以及原型对象的理解
2017/01/13 Javascript
Express+Nodejs 下的登录拦截实现代码
2017/07/01 NodeJs
vue.js实例对象+组件树的详细介绍
2017/10/20 Javascript
JS实现显示当前日期的实例代码
2018/07/03 Javascript
35个最好用的Vue开源库(史上最全)
2019/01/03 Javascript
我所理解的JavaScript中的this指向
2020/09/04 Javascript
[02:43]中国五虎出征TI3视频
2013/08/02 DOTA
Python xlrd读取excel日期类型的2种方法
2015/04/28 Python
python简单获取本机计算机名和IP地址的方法
2015/06/03 Python
Tensorflow实现卷积神经网络的详细代码
2018/05/24 Python
Python 微信之获取好友昵称并制作wordcloud的实例
2019/02/21 Python
python3.7通过thrift操作hbase的示例代码
2020/01/14 Python
opencv中图像叠加/图像融合/按位操作的实现
2020/04/01 Python
Django多数据库联用实现方法解析
2020/11/12 Python
浅谈css3新单位vw、vh、vmin、vmax的使用详解
2017/12/01 HTML / CSS
HTML5对手机页面长按会粘贴复制禁用的解决方法
2016/07/19 HTML / CSS
大一工商管理职业生涯规划:有梦最美,行动相随
2014/09/18 职场文书
文体活动总结
2015/02/04 职场文书
2015毕业设计工作总结
2015/07/24 职场文书
导游词之台湾阿里山
2019/10/23 职场文书
mysql批量新增和存储的方法实例
2021/04/07 MySQL
SQL Server2019数据库之简单子查询的具有方法
2021/04/27 SQL Server
新手入门Mysql--sql执行过程
2021/06/20 MySQL
JavaScript高级程序设计之基本引用类型
2021/11/17 Javascript
浅谈css清除浮动(clearfix和clear)的用法
2023/05/21 HTML / CSS