keras打印loss对权重的导数方式


Posted in Python onJune 10, 2020

Notes

怀疑模型梯度爆炸,想打印模型 loss 对各权重的导数看看。如果如果fit来训练的话,可以用keras.callbacks.TensorBoard实现。

但此次使用train_on_batch来训练的,用K.gradients和K.function实现。

Codes

以一份 VAE 代码为例

# -*- coding: utf8 -*-
import keras
from keras.models import Model
from keras.layers import Input, Lambda, Conv2D, MaxPooling2D, Flatten, Dense, Reshape
from keras.losses import binary_crossentropy
from keras.datasets import mnist, fashion_mnist
import keras.backend as K
from scipy.stats import norm
import numpy as np
import matplotlib.pyplot as plt

BATCH = 128
N_CLASS = 10
EPOCH = 5
IN_DIM = 28 * 28
H_DIM = 128
Z_DIM = 2

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train = x_train.reshape(len(x_train), -1).astype('float32') / 255.
x_test = x_test.reshape(len(x_test), -1).astype('float32') / 255.

def sampleing(args):
  """reparameterize"""
  mu, logvar = args
  eps = K.random_normal([K.shape(mu)[0], Z_DIM], mean=0.0, stddev=1.0)
  return mu + eps * K.exp(logvar / 2.)

# encode
x_in = Input([IN_DIM])
h = Dense(H_DIM, activation='relu')(x_in)
z_mu = Dense(Z_DIM)(h) # mean,不用激活
z_logvar = Dense(Z_DIM)(h) # log variance,不用激活
z = Lambda(sampleing, output_shape=[Z_DIM])([z_mu, z_logvar]) # 只能有一个参数
encoder = Model(x_in, [z_mu, z_logvar, z], name='encoder')

# decode
z_in = Input([Z_DIM])
h_hat = Dense(H_DIM, activation='relu')(z_in)
x_hat = Dense(IN_DIM, activation='sigmoid')(h_hat)
decoder = Model(z_in, x_hat, name='decoder')

# VAE
x_in = Input([IN_DIM])
x = x_in
z_mu, z_logvar, z = encoder(x)
x = decoder(z)
out = x
vae = Model(x_in, [out, out], name='vae')

# loss_kl = 0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1)
# loss_recon = binary_crossentropy(K.reshape(vae_in, [-1, IN_DIM]), vae_out) * IN_DIM
# loss_vae = K.mean(loss_kl + loss_recon)

def loss_kl(y_true, y_pred):
  return 0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1)


# vae.add_loss(loss_vae)
vae.compile(optimizer='rmsprop',
      loss=[loss_kl, 'binary_crossentropy'],
      loss_weights=[1, IN_DIM])
vae.summary()

# 获取模型权重 variable
w = vae.trainable_weights
print(w)

# 打印 KL 对权重的导数
# KL 要是 Tensor,不能是上面的函数 `loss_kl`
grad = K.gradients(0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1),
          w)
print(grad) # 有些是 None 的
grad = grad[grad is not None] # 去掉 None,不然报错

# 打印梯度的函数
# K.function 的输入和输出必要是 list!就算只有一个
show_grad = K.function([vae.input], [grad])

# vae.fit(x_train, # y_train, # 不能传 y_train
#     batch_size=BATCH,
#     epochs=EPOCH,
#     verbose=1,
#     validation_data=(x_test, None))

''' 以 train_on_batch 方式训练 '''
for epoch in range(EPOCH):
  for b in range(x_train.shape[0] // BATCH):
    idx = np.random.choice(x_train.shape[0], BATCH)
    x = x_train[idx]
    l = vae.train_on_batch([x], [x, x])

  # 计算梯度
  gd = show_grad([x])
  # 打印梯度
  print(gd)

# show manifold
PIXEL = 28
N_PICT = 30
grid_x = norm.ppf(np.linspace(0.05, 0.95, N_PICT))
grid_y = grid_x

figure = np.zeros([N_PICT * PIXEL, N_PICT * PIXEL])
for i, xi in enumerate(grid_x):
  for j, yj in enumerate(grid_y):
    noise = np.array([[xi, yj]]) # 必须秩为 2,两层中括号
    x_gen = decoder.predict(noise)
    # print('x_gen shape:', x_gen.shape)
    x_gen = x_gen[0].reshape([PIXEL, PIXEL])
    figure[i * PIXEL: (i+1) * PIXEL,
        j * PIXEL: (j+1) * PIXEL] = x_gen

fig = plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
fig.savefig('./variational_autoencoder.png')
plt.show()

补充知识:keras 自定义损失 自动求导时出现None

问题记录,keras 自定义损失 自动求导时出现None,后来想到是因为传入的变量没有使用,所以keras无法求出偏导,修改后问题解决。就是不愿使用的变量×0,求导后还是0就可以了。

def my_complex_loss_graph(y_label, emb_uid, lstm_out,y_true_1,y_true_2,y_true_3,out_1,out_2,out_3):
 
  mse_out_1 = mean_squared_error(y_true_1, out_1)
  mse_out_2 = mean_squared_error(y_true_2, out_2)
  mse_out_3 = mean_squared_error(y_true_3, out_3)
  # emb_uid= K.reshape(emb_uid, [-1, 32])
  cosine_sim = tf.reduce_sum(0.5*tf.square(emb_uid-lstm_out))
 
  cost=0*cosine_sim+K.sum([0.5*mse_out_1 , 0.25*mse_out_2,0.25*mse_out_3],axis=1,keepdims=True)
  # print(mse_out_1)
  final_loss = cost
 
  return K.mean(final_loss)

以上这篇keras打印loss对权重的导数方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python2.x版本中cmp()方法的使用教程
May 14 Python
python3序列化与反序列化用法实例
May 26 Python
一篇文章入门Python生态系统(Python新手入门指导)
Dec 11 Python
python paramiko模块学习分享
Aug 23 Python
Python实现的维尼吉亚密码算法示例
Apr 12 Python
python2.7实现邮件发送功能
Dec 12 Python
对python多线程中互斥锁Threading.Lock的简单应用详解
Jan 11 Python
PyQt5基本控件使用之消息弹出、用户输入、文件对话框的使用方法
Aug 06 Python
Python小程序 控制鼠标循环点击代码实例
Oct 08 Python
Python Pillow.Image 图像保存和参数选择方式
Jan 09 Python
python实现简单区块链结构
Apr 25 Python
python自动获取微信公众号最新文章的实现代码
Jul 15 Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
You might like
Discuz 6.0+ 批量注册用户名
2009/09/13 PHP
PHP中的Memcache详解
2014/04/05 PHP
[原创]php逐行读取txt文件写入数组的方法
2015/07/02 PHP
php时间函数用法分析
2016/05/28 PHP
PHP与以太坊交互详解
2018/08/24 PHP
基于Jquery的跨域传输数据(JSONP)
2011/03/10 Javascript
40个有创意的jQuery图片、内容滑动及弹出插件收藏集之一
2011/12/31 Javascript
Javascript绝句欣赏 一些经典的js代码
2012/02/22 Javascript
iframe的父子窗口之间的对象相互调用基本用法
2013/09/03 Javascript
js中的onchange和onpropertychange (onchange无效的解决方法)
2014/03/08 Javascript
使用javascript实现简单的选项卡切换
2015/01/09 Javascript
jquery实现Ctrl+Enter提交表单的方法
2015/07/21 Javascript
使用bootstrap实现下拉框搜索功能的实例讲解
2018/08/10 Javascript
小程序测试后台服务的方法(ngrok)
2019/03/08 Javascript
微信小程序picker组件关于objectArray数据类型的绑定方法
2019/03/13 Javascript
vue权限问题的完美解决方案
2019/05/08 Javascript
原生JS与CSS实现软件卸载对话框功能
2019/12/05 Javascript
vue基于v-charts封装双向条形图的实现代码
2019/12/09 Javascript
[04:52]DOTA2亚洲邀请赛附加赛 TOP10精彩集锦
2015/01/29 DOTA
Python的加密模块md5、sha、crypt使用实例
2014/09/28 Python
Python Flask-web表单使用详解
2017/11/18 Python
PyTorch快速搭建神经网络及其保存提取方法详解
2018/04/28 Python
Python爬虫获取图片并下载保存至本地的实例
2018/06/01 Python
详解Python连接MySQL数据库的多种方式
2019/04/16 Python
解决Python3 抓取微信账单信息问题
2019/07/19 Python
详解pandas中MultiIndex和对象实际索引不一致问题
2019/07/23 Python
基于Python爬取爱奇艺资源过程解析
2020/03/02 Python
移动web模拟客户端实现多方框输入密码效果【附代码】
2016/03/25 HTML / CSS
音乐专业应届生教师求职信
2013/11/04 职场文书
老师给学生的表扬信
2014/01/17 职场文书
门前三包责任书
2014/04/15 职场文书
《长江之歌》教学反思
2014/04/17 职场文书
地球一小时宣传标语
2014/06/24 职场文书
“向国旗敬礼”活动策划方案(4篇)
2014/09/27 职场文书
2014年教师学期工作总结
2014/11/08 职场文书
教师网络培训心得体会
2016/01/09 职场文书