TensorFlow自定义损失函数来预测商品销售量


Posted in Python onFebruary 05, 2020

在预测商品销量时,如果预测多了(预测值比真实销量大),商家损失的是生产商品的成本;而如果预测少了(预测值比真实销量小),损失的则是商品的利润。因为一般商品的成本和商品的利润不会严格相等,比如如果一个商品的成本是1元,但是利润是10元,那么少预测一个就少挣10元;而多预测一个才少挣1元,所以如果神经网络模型最小化的是均方误差损失函数,那么很有可能此模型就无法最大化预期的销售利润。

为了最大化预期利润,需要将损失函数和利润直接联系起来,需要注意的是,损失函数定义的是损失,所以要将利润最大化,定义的损失函数应该刻画成本或者代价,下面的公式给出了一个当预测多于真实值和预测少于真实值时有不同损失系数的损失函数:

TensorFlow自定义损失函数来预测商品销售量

其中,yi为一个batch中第i个数据的真实值,yi'为神经网络得到的预测值,a和b是常量,比如在上面介绍的销量预测问题中,a就等于10 (真实值多于预测值的代价),而b等于1 (真实值少于预测值的代价)。

通过对这个自定义损失函数的优化,模型提供的预测值更有可能最大化收益,在TensorFlow中,可以通过以下代码来实现这个损失函数:

loss = tf.reduce_sum(tf.where(tf.greater(y_, y), (y_ - y) * loss_less, (y - y_) * loss_more))

①tf.greater函数的输入是两个张量,此函数会比较这两个输入张量中每一个元素的大小,并返回比较结果,当tf.greater的输入张量维度不一样时,TensorFlow会进行类似NumPy广播操作(broadcasting)的处理;

②tf.where函数有三个参数,第一个为选择条件,当选择条件为True时,tf.where函数会选择第二个参数中的值,否则使用第三个参数中的值,需要注意的是,tf.where函数的判断和选择都是在元素级别进行的。

接下来使用一段TensorFlow代码展示这两个函数的使用:

import tensorflow as tf
v1 = tf.constant([1.0, 2.0, 3.0, 4.0])
v2 = tf.constant([4.0, 3.0, 2.0, 1.0])
with tf.Session() as sess:
 print(sess.run(tf.greater(v1, v2)))
 print(sess.run(tf.where(tf.greater(v1, v2), v1, v2)))
 '''输出结果为:
 [False False True True]
 [4. 3. 3. 4.]'''

在了解如何使用这两个函数之后,我们来看一看刚才的预测商品销售量的实例如何通过具体的TensorFlow代码实现:

import tensorflow as tf
from numpy.random import RandomState

#声明wl、W2两个变量,通过seed参数设定了随机种子,这样可以保证每次运行得到的结果是一样的
w = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))

x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y-input")

#定义神经网络结构
y = tf.matmul(x, w)

#定义真实值与预测值之间的交叉熵损失函数,来刻画真实值与预测值之间的差距
loss_less = 10
loss_more = 1
loss = tf.reduce_sum(tf.where(tf.greater(y_, y), (y_ - y) * loss_less, (y - y_) * loss_more))

#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

#设置随机数种子
rdm = RandomState(seed=1)
#设置随机数据集大小
dataset_size = 128
X = rdm.rand(dataset_size, 2)
'''设置回归的正确值为两个输入的和加上一个随机量。
之所以要加上一个随机量是为了加入不可预测的噪音,否则不同损失函数的意义就不大了,因为不同损失函数都会在能完全预测正确的时候最低。
一般来说噪音为一个均值为0的小量,所以这里的噪音设置为-0.05——0.05的随机数。'''
Y = [[x1 + x2 + rdm.rand()/10.0 -0.05] for x1,x2 in X]

#创建会话
with tf.Session() as sess:
 #初始化变量
 init_op = tf.global_variables_initializer()
 sess.run(init_op)
 
 print(sess.run(w))
 
 #设置batch训练数据的大小
 batch_size = 8
 #设置训练得轮数
 STEPS = 5000
 for i in range(STEPS):
  #每次选取batch_size个样本进行训练
  start = (i * batch_size) % dataset_size
  end = min(start + batch_size, dataset_size)

  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
 
 print(sess.run(w))
 '''输出结果为:
 [[-0.8113182]
 [ 1.4845988]]
 [[1.019347 ]
 [1.0428089]]'''

可以看到参数w优化后,预测函数为1.019347 * x1 + 1.0428089 * x2,显然是大于实际的预测函数x1 + x2的,这是因为我们的损失函数中指定预测少了的损失更大(loss_less > loss_more),所以模型会偏向于预测多一点。

如果我们更换代码,改为:

loss_less = 1
loss_more = 10

那么我们的结果就会变为:

[[-0.8113182]
 [ 1.4845988]]
[[0.95561105]
 [0.98101896]]

预测函数变为了0.95561105 * x1 + 0.98101896 * x2,可以看到这时候模型就会偏向于预测少一点。

因此,我们可以得出结论:对于相同的神经网络,不同的损失函数会对训练得到的模型产生不同效果。

总结

以上所述是小编给大家介绍的TensorFlow自定义损失函数来预测商品销售量,希望对大家有所帮助!

Python 相关文章推荐
用PyQt进行Python图形界面的程序的开发的入门指引
Apr 14 Python
python数组复制拷贝的实现方法
Jun 09 Python
Python函数中*args和**kwargs来传递变长参数的用法
Jan 26 Python
Python数据结构之顺序表的实现代码示例
Nov 15 Python
[原创]python爬虫(入门教程、视频教程)
Jan 08 Python
Python实现简单http服务器
Apr 12 Python
详解Python中pandas的安装操作说明(傻瓜版)
Apr 08 Python
解决django服务器重启端口被占用的问题
Jul 26 Python
用pandas划分数据集实现训练集和测试集
Jul 20 Python
如何将json数据转换为python数据
Sep 04 Python
python 中yaml文件用法大全
Jul 04 Python
Python可视化神器pyecharts绘制水球图
Jul 07 Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
TensorFlow实现指数衰减学习率的方法
Feb 05 #Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
You might like
按上下级层次关系输出内容的PHP代码
2010/07/17 PHP
PHP实现图片压缩的两则实例
2014/07/19 PHP
php获取汉字拼音首字母的方法
2015/10/21 PHP
PHP 计算两个时间段之间交集的天数示例
2019/10/24 PHP
提高 DHTML 页面性能
2006/12/25 Javascript
Javascript 个人笔记(没有整理,很乱)
2007/07/07 Javascript
js 未结束的字符串常量错误解决方法
2010/06/13 Javascript
枚举的实现求得1-1000所有出现1的数字并计算出现1的个数
2013/09/10 Javascript
JS实现上下左右对称的九九乘法表
2016/02/22 Javascript
JavaScript快速切换繁体中文和简体中文的方法及网站支持简繁体切换的绝招
2016/03/07 Javascript
JavaScript根据CSS的Media Queries来判断浏览设备的方法
2016/05/10 Javascript
vue.js实现单选框、复选框和下拉框示例
2017/07/18 Javascript
JavaScript实现三级联动菜单效果
2017/08/16 Javascript
Vue 项目分环境打包的方法示例
2018/08/03 Javascript
Vuejs+vue-router打包+Nginx配置的实例
2018/09/20 Javascript
vue-cli3.0 环境变量与模式配置方法
2018/11/08 Javascript
es6中比较有用的7个技巧小结
2019/07/12 Javascript
[51:15]2014 DOTA2国际邀请赛中国区预选赛 Orenda VS LGD-GAMING
2014/05/22 DOTA
[46:43]DOTA2上海特级锦标赛D组小组赛#1 EG VS COL第三局
2016/02/28 DOTA
Flask的图形化管理界面搭建框架Flask-Admin的使用教程
2016/06/13 Python
python检查URL是否正常访问的小技巧
2017/02/25 Python
对python中的argv和argc使用详解
2018/12/15 Python
Python图像的增强处理操作示例【基于ImageEnhance类】
2019/01/03 Python
Django Channels 实现点对点实时聊天和消息推送功能
2019/07/17 Python
Python 使用 environs 库定义环境变量的方法
2020/02/25 Python
简单了解python列表和元组的区别
2020/05/14 Python
HTML5 创建canvas元素示例代码
2014/06/04 HTML / CSS
main 主函数执行完毕后,是否可能会再执行一段代码,给出说明
2012/12/05 面试题
HR喜欢的自荐信格式
2013/10/08 职场文书
房屋改造计划书
2014/01/10 职场文书
大学毕业感言50字
2014/02/07 职场文书
家长通知书家长评语
2014/04/17 职场文书
公安机关纪律作风整顿个人剖析材料材料
2014/10/10 职场文书
观看建国大业观后感
2015/06/01 职场文书
Python机器学习之基于Pytorch实现猫狗分类
2021/06/08 Python
使用pd.merge表连接出现多余行的问题解决
2022/06/16 Python