keras打印loss对权重的导数方式


Posted in Python onJune 10, 2020

Notes

怀疑模型梯度爆炸,想打印模型 loss 对各权重的导数看看。如果如果fit来训练的话,可以用keras.callbacks.TensorBoard实现。

但此次使用train_on_batch来训练的,用K.gradients和K.function实现。

Codes

以一份 VAE 代码为例

# -*- coding: utf8 -*-
import keras
from keras.models import Model
from keras.layers import Input, Lambda, Conv2D, MaxPooling2D, Flatten, Dense, Reshape
from keras.losses import binary_crossentropy
from keras.datasets import mnist, fashion_mnist
import keras.backend as K
from scipy.stats import norm
import numpy as np
import matplotlib.pyplot as plt

BATCH = 128
N_CLASS = 10
EPOCH = 5
IN_DIM = 28 * 28
H_DIM = 128
Z_DIM = 2

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train = x_train.reshape(len(x_train), -1).astype('float32') / 255.
x_test = x_test.reshape(len(x_test), -1).astype('float32') / 255.

def sampleing(args):
  """reparameterize"""
  mu, logvar = args
  eps = K.random_normal([K.shape(mu)[0], Z_DIM], mean=0.0, stddev=1.0)
  return mu + eps * K.exp(logvar / 2.)

# encode
x_in = Input([IN_DIM])
h = Dense(H_DIM, activation='relu')(x_in)
z_mu = Dense(Z_DIM)(h) # mean,不用激活
z_logvar = Dense(Z_DIM)(h) # log variance,不用激活
z = Lambda(sampleing, output_shape=[Z_DIM])([z_mu, z_logvar]) # 只能有一个参数
encoder = Model(x_in, [z_mu, z_logvar, z], name='encoder')

# decode
z_in = Input([Z_DIM])
h_hat = Dense(H_DIM, activation='relu')(z_in)
x_hat = Dense(IN_DIM, activation='sigmoid')(h_hat)
decoder = Model(z_in, x_hat, name='decoder')

# VAE
x_in = Input([IN_DIM])
x = x_in
z_mu, z_logvar, z = encoder(x)
x = decoder(z)
out = x
vae = Model(x_in, [out, out], name='vae')

# loss_kl = 0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1)
# loss_recon = binary_crossentropy(K.reshape(vae_in, [-1, IN_DIM]), vae_out) * IN_DIM
# loss_vae = K.mean(loss_kl + loss_recon)

def loss_kl(y_true, y_pred):
  return 0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1)


# vae.add_loss(loss_vae)
vae.compile(optimizer='rmsprop',
      loss=[loss_kl, 'binary_crossentropy'],
      loss_weights=[1, IN_DIM])
vae.summary()

# 获取模型权重 variable
w = vae.trainable_weights
print(w)

# 打印 KL 对权重的导数
# KL 要是 Tensor,不能是上面的函数 `loss_kl`
grad = K.gradients(0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1),
          w)
print(grad) # 有些是 None 的
grad = grad[grad is not None] # 去掉 None,不然报错

# 打印梯度的函数
# K.function 的输入和输出必要是 list!就算只有一个
show_grad = K.function([vae.input], [grad])

# vae.fit(x_train, # y_train, # 不能传 y_train
#     batch_size=BATCH,
#     epochs=EPOCH,
#     verbose=1,
#     validation_data=(x_test, None))

''' 以 train_on_batch 方式训练 '''
for epoch in range(EPOCH):
  for b in range(x_train.shape[0] // BATCH):
    idx = np.random.choice(x_train.shape[0], BATCH)
    x = x_train[idx]
    l = vae.train_on_batch([x], [x, x])

  # 计算梯度
  gd = show_grad([x])
  # 打印梯度
  print(gd)

# show manifold
PIXEL = 28
N_PICT = 30
grid_x = norm.ppf(np.linspace(0.05, 0.95, N_PICT))
grid_y = grid_x

figure = np.zeros([N_PICT * PIXEL, N_PICT * PIXEL])
for i, xi in enumerate(grid_x):
  for j, yj in enumerate(grid_y):
    noise = np.array([[xi, yj]]) # 必须秩为 2,两层中括号
    x_gen = decoder.predict(noise)
    # print('x_gen shape:', x_gen.shape)
    x_gen = x_gen[0].reshape([PIXEL, PIXEL])
    figure[i * PIXEL: (i+1) * PIXEL,
        j * PIXEL: (j+1) * PIXEL] = x_gen

fig = plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
fig.savefig('./variational_autoencoder.png')
plt.show()

补充知识:keras 自定义损失 自动求导时出现None

问题记录,keras 自定义损失 自动求导时出现None,后来想到是因为传入的变量没有使用,所以keras无法求出偏导,修改后问题解决。就是不愿使用的变量×0,求导后还是0就可以了。

def my_complex_loss_graph(y_label, emb_uid, lstm_out,y_true_1,y_true_2,y_true_3,out_1,out_2,out_3):
 
  mse_out_1 = mean_squared_error(y_true_1, out_1)
  mse_out_2 = mean_squared_error(y_true_2, out_2)
  mse_out_3 = mean_squared_error(y_true_3, out_3)
  # emb_uid= K.reshape(emb_uid, [-1, 32])
  cosine_sim = tf.reduce_sum(0.5*tf.square(emb_uid-lstm_out))
 
  cost=0*cosine_sim+K.sum([0.5*mse_out_1 , 0.25*mse_out_2,0.25*mse_out_3],axis=1,keepdims=True)
  # print(mse_out_1)
  final_loss = cost
 
  return K.mean(final_loss)

以上这篇keras打印loss对权重的导数方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python交换两个变量的值方法
Jan 12 Python
python实现websocket的客户端压力测试
Jun 25 Python
python实现发送form-data数据的方法详解
Sep 27 Python
python的mysql数据库建立表与插入数据操作示例
Sep 30 Python
python 实现绘制整齐的表格
Nov 18 Python
妙用itchat! python实现久坐提醒功能
Nov 25 Python
python使用多线程+socket实现端口扫描
May 28 Python
Django用户登录与注册系统的实现示例
Jun 03 Python
解决Python paramiko 模块远程执行ssh 命令 nohup 不生效的问题
Jul 14 Python
通过代码实例了解Python sys模块
Sep 14 Python
python 操作excel表格的方法
Dec 05 Python
Python第三方库安装缓慢的解决方法
Feb 06 Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
You might like
PHP输出数组中重名的元素的几种处理方法
2012/09/05 PHP
如何解决CI框架的Disallowed Key Characters错误提示
2013/07/05 PHP
Yii2.0中使用js异步删除示例
2017/03/10 PHP
使用PHPWord生成word文档的方法详解
2019/06/06 PHP
javascript 写类方式之七
2009/07/05 Javascript
JavaScript中SQL语句的应用实现
2010/05/04 Javascript
javascript中RegExp保留小数点后几位数的方法分享
2013/08/13 Javascript
jquery实现的美女拼图游戏实例
2015/05/04 Javascript
喜大普奔!jQuery发布 3.0 最终版
2016/06/12 Javascript
AngularJs IE Compatibility 兼容老版本IE
2016/09/01 Javascript
JS中数组重排序方法
2016/11/11 Javascript
Three.js 再探 - 写一个微信跳一跳极简版游戏
2018/01/04 Javascript
vue里面父组件修改子组件样式的方法
2018/02/03 Javascript
解决vue-cli + webpack 新建项目出错的问题
2018/03/20 Javascript
微信小程序中weui用法解析
2019/10/21 Javascript
微信小程序8种数据通信的方式小结
2020/02/03 Javascript
[04:49]期待西雅图之战 2016国际邀请赛中国区预选赛WINGS战队赛后采访
2016/06/29 DOTA
matplotlib绘图实例演示标记路径
2018/01/23 Python
python 定义给定初值或长度的list方法
2018/06/23 Python
python 监听salt job状态,并任务数据推送到redis中的方法
2019/01/14 Python
python3.4爬虫demo
2019/01/22 Python
python实现在cmd窗口显示彩色文字
2019/06/24 Python
Pandas中resample方法详解
2019/07/02 Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
2019/08/19 Python
python @propert装饰器使用方法原理解析
2019/12/25 Python
使用Python爬虫库requests发送请求、传递URL参数、定制headers
2020/01/25 Python
TensorFlow2.X使用图片制作简单的数据集训练模型
2020/04/08 Python
Python解析微信dat文件的方法
2020/11/30 Python
HTML5+css3:3D旋转木马效果相册
2017/01/03 HTML / CSS
大一军训感言
2014/01/09 职场文书
会计系毕业生求职信
2014/05/28 职场文书
关于感恩的演讲稿800字
2014/08/26 职场文书
介绍信的格式
2015/01/30 职场文书
保送生自荐信范文
2015/03/26 职场文书
Python Pandas数据分析之iloc和loc的用法详解
2021/11/11 Python
在Docker容器中部署SQL Server
2022/04/11 Servers