keras打印loss对权重的导数方式


Posted in Python onJune 10, 2020

Notes

怀疑模型梯度爆炸,想打印模型 loss 对各权重的导数看看。如果如果fit来训练的话,可以用keras.callbacks.TensorBoard实现。

但此次使用train_on_batch来训练的,用K.gradients和K.function实现。

Codes

以一份 VAE 代码为例

# -*- coding: utf8 -*-
import keras
from keras.models import Model
from keras.layers import Input, Lambda, Conv2D, MaxPooling2D, Flatten, Dense, Reshape
from keras.losses import binary_crossentropy
from keras.datasets import mnist, fashion_mnist
import keras.backend as K
from scipy.stats import norm
import numpy as np
import matplotlib.pyplot as plt

BATCH = 128
N_CLASS = 10
EPOCH = 5
IN_DIM = 28 * 28
H_DIM = 128
Z_DIM = 2

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train = x_train.reshape(len(x_train), -1).astype('float32') / 255.
x_test = x_test.reshape(len(x_test), -1).astype('float32') / 255.

def sampleing(args):
  """reparameterize"""
  mu, logvar = args
  eps = K.random_normal([K.shape(mu)[0], Z_DIM], mean=0.0, stddev=1.0)
  return mu + eps * K.exp(logvar / 2.)

# encode
x_in = Input([IN_DIM])
h = Dense(H_DIM, activation='relu')(x_in)
z_mu = Dense(Z_DIM)(h) # mean,不用激活
z_logvar = Dense(Z_DIM)(h) # log variance,不用激活
z = Lambda(sampleing, output_shape=[Z_DIM])([z_mu, z_logvar]) # 只能有一个参数
encoder = Model(x_in, [z_mu, z_logvar, z], name='encoder')

# decode
z_in = Input([Z_DIM])
h_hat = Dense(H_DIM, activation='relu')(z_in)
x_hat = Dense(IN_DIM, activation='sigmoid')(h_hat)
decoder = Model(z_in, x_hat, name='decoder')

# VAE
x_in = Input([IN_DIM])
x = x_in
z_mu, z_logvar, z = encoder(x)
x = decoder(z)
out = x
vae = Model(x_in, [out, out], name='vae')

# loss_kl = 0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1)
# loss_recon = binary_crossentropy(K.reshape(vae_in, [-1, IN_DIM]), vae_out) * IN_DIM
# loss_vae = K.mean(loss_kl + loss_recon)

def loss_kl(y_true, y_pred):
  return 0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1)


# vae.add_loss(loss_vae)
vae.compile(optimizer='rmsprop',
      loss=[loss_kl, 'binary_crossentropy'],
      loss_weights=[1, IN_DIM])
vae.summary()

# 获取模型权重 variable
w = vae.trainable_weights
print(w)

# 打印 KL 对权重的导数
# KL 要是 Tensor,不能是上面的函数 `loss_kl`
grad = K.gradients(0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1),
          w)
print(grad) # 有些是 None 的
grad = grad[grad is not None] # 去掉 None,不然报错

# 打印梯度的函数
# K.function 的输入和输出必要是 list!就算只有一个
show_grad = K.function([vae.input], [grad])

# vae.fit(x_train, # y_train, # 不能传 y_train
#     batch_size=BATCH,
#     epochs=EPOCH,
#     verbose=1,
#     validation_data=(x_test, None))

''' 以 train_on_batch 方式训练 '''
for epoch in range(EPOCH):
  for b in range(x_train.shape[0] // BATCH):
    idx = np.random.choice(x_train.shape[0], BATCH)
    x = x_train[idx]
    l = vae.train_on_batch([x], [x, x])

  # 计算梯度
  gd = show_grad([x])
  # 打印梯度
  print(gd)

# show manifold
PIXEL = 28
N_PICT = 30
grid_x = norm.ppf(np.linspace(0.05, 0.95, N_PICT))
grid_y = grid_x

figure = np.zeros([N_PICT * PIXEL, N_PICT * PIXEL])
for i, xi in enumerate(grid_x):
  for j, yj in enumerate(grid_y):
    noise = np.array([[xi, yj]]) # 必须秩为 2,两层中括号
    x_gen = decoder.predict(noise)
    # print('x_gen shape:', x_gen.shape)
    x_gen = x_gen[0].reshape([PIXEL, PIXEL])
    figure[i * PIXEL: (i+1) * PIXEL,
        j * PIXEL: (j+1) * PIXEL] = x_gen

fig = plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
fig.savefig('./variational_autoencoder.png')
plt.show()

补充知识:keras 自定义损失 自动求导时出现None

问题记录,keras 自定义损失 自动求导时出现None,后来想到是因为传入的变量没有使用,所以keras无法求出偏导,修改后问题解决。就是不愿使用的变量×0,求导后还是0就可以了。

def my_complex_loss_graph(y_label, emb_uid, lstm_out,y_true_1,y_true_2,y_true_3,out_1,out_2,out_3):
 
  mse_out_1 = mean_squared_error(y_true_1, out_1)
  mse_out_2 = mean_squared_error(y_true_2, out_2)
  mse_out_3 = mean_squared_error(y_true_3, out_3)
  # emb_uid= K.reshape(emb_uid, [-1, 32])
  cosine_sim = tf.reduce_sum(0.5*tf.square(emb_uid-lstm_out))
 
  cost=0*cosine_sim+K.sum([0.5*mse_out_1 , 0.25*mse_out_2,0.25*mse_out_3],axis=1,keepdims=True)
  # print(mse_out_1)
  final_loss = cost
 
  return K.mean(final_loss)

以上这篇keras打印loss对权重的导数方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
多线程爬虫批量下载pcgame图片url 保存为xml的实现代码
Jan 17 Python
Python实现的飞速中文网小说下载脚本
Apr 23 Python
Python读取csv文件分隔符设置方法
Jan 14 Python
python 获取等间隔的数组实例
Jul 04 Python
Python实现一个带权无回置随机抽选函数的方法
Jul 24 Python
Python3实现配置文件差异对比脚本
Nov 18 Python
基于Django实现日志记录报错信息
Dec 17 Python
python中return的返回和执行实例
Dec 24 Python
tensorflow之自定义神经网络层实例
Feb 07 Python
Python如何使用turtle库绘制图形
Feb 26 Python
Python日志处理模块logging用法解析
May 19 Python
Python连接HDFS实现文件上传下载及Pandas转换文本文件到CSV操作
Jun 06 Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
You might like
轻松修复Discuz!数据库
2008/05/03 PHP
PHP+AJAX实现投票功能的方法
2015/09/28 PHP
ZendFramework框架实现连接两个或多个数据库的方法
2016/12/08 PHP
PHP7创建COOKIE和销毁COOKIE的实例方法
2020/02/03 PHP
详解JavaScript的表达式与运算符
2015/11/30 Javascript
详解JavaScript基本类型和引用类型
2015/12/09 Javascript
JavaScript定义数组的三种方法(new Array(),new Array('x','y')
2016/10/04 Javascript
Vue 使用中的小技巧
2018/04/26 Javascript
关于在vue 中使用百度ueEditor编辑器的方法实例代码
2018/09/14 Javascript
jQuery/JS监听input输入框值变化实例
2019/10/17 jQuery
javascript实现支付宝滑块验证码效果
2020/07/24 Javascript
JS实现简易日历效果
2021/01/25 Javascript
详尽讲述用Python的Django框架测试驱动开发的教程
2015/04/22 Python
详解Python中的type()方法的使用
2015/05/21 Python
Python中的ctime()方法使用教程
2015/05/22 Python
Python的Django框架中自定义模版标签的示例
2015/07/20 Python
python使用opencv进行人脸识别
2017/04/07 Python
python使用xslt提取网页数据的方法
2018/02/23 Python
python3+mysql查询数据并通过邮件群发excel附件
2018/02/24 Python
详解Python中的type和object
2018/08/15 Python
使用Python横向合并excel文件的实例
2018/12/11 Python
使用pyinstaller打包PyQt4程序遇到的问题及解决方法
2019/06/24 Python
Python数据分析pandas模块用法实例详解
2019/11/20 Python
PyTorch和Keras计算模型参数的例子
2020/01/02 Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
2020/01/18 Python
python3.x中安装web.py步骤方法
2020/06/23 Python
Python操作Elasticsearch处理timeout超时
2020/07/17 Python
django创建css文件夹的具体方法
2020/07/31 Python
Pycharm操作Git及GitHub的步骤详解
2020/10/27 Python
Selenium环境变量配置(火狐浏览器)及验证实现
2020/12/07 Python
C#如何进行LDAP用户校验
2012/11/21 面试题
Java中的异常处理机制的简单原理和应用
2013/04/27 面试题
就业自荐书
2013/12/05 职场文书
市委常委班子党的群众路线教育实践活动整改措施
2014/10/02 职场文书
2016年小学优秀班主任事迹材料
2016/02/29 职场文书
Win11任务栏太宽了怎么办?一招解决Win11任务栏太宽问题
2021/11/21 数码科技