keras打印loss对权重的导数方式


Posted in Python onJune 10, 2020

Notes

怀疑模型梯度爆炸,想打印模型 loss 对各权重的导数看看。如果如果fit来训练的话,可以用keras.callbacks.TensorBoard实现。

但此次使用train_on_batch来训练的,用K.gradients和K.function实现。

Codes

以一份 VAE 代码为例

# -*- coding: utf8 -*-
import keras
from keras.models import Model
from keras.layers import Input, Lambda, Conv2D, MaxPooling2D, Flatten, Dense, Reshape
from keras.losses import binary_crossentropy
from keras.datasets import mnist, fashion_mnist
import keras.backend as K
from scipy.stats import norm
import numpy as np
import matplotlib.pyplot as plt

BATCH = 128
N_CLASS = 10
EPOCH = 5
IN_DIM = 28 * 28
H_DIM = 128
Z_DIM = 2

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train = x_train.reshape(len(x_train), -1).astype('float32') / 255.
x_test = x_test.reshape(len(x_test), -1).astype('float32') / 255.

def sampleing(args):
  """reparameterize"""
  mu, logvar = args
  eps = K.random_normal([K.shape(mu)[0], Z_DIM], mean=0.0, stddev=1.0)
  return mu + eps * K.exp(logvar / 2.)

# encode
x_in = Input([IN_DIM])
h = Dense(H_DIM, activation='relu')(x_in)
z_mu = Dense(Z_DIM)(h) # mean,不用激活
z_logvar = Dense(Z_DIM)(h) # log variance,不用激活
z = Lambda(sampleing, output_shape=[Z_DIM])([z_mu, z_logvar]) # 只能有一个参数
encoder = Model(x_in, [z_mu, z_logvar, z], name='encoder')

# decode
z_in = Input([Z_DIM])
h_hat = Dense(H_DIM, activation='relu')(z_in)
x_hat = Dense(IN_DIM, activation='sigmoid')(h_hat)
decoder = Model(z_in, x_hat, name='decoder')

# VAE
x_in = Input([IN_DIM])
x = x_in
z_mu, z_logvar, z = encoder(x)
x = decoder(z)
out = x
vae = Model(x_in, [out, out], name='vae')

# loss_kl = 0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1)
# loss_recon = binary_crossentropy(K.reshape(vae_in, [-1, IN_DIM]), vae_out) * IN_DIM
# loss_vae = K.mean(loss_kl + loss_recon)

def loss_kl(y_true, y_pred):
  return 0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1)


# vae.add_loss(loss_vae)
vae.compile(optimizer='rmsprop',
      loss=[loss_kl, 'binary_crossentropy'],
      loss_weights=[1, IN_DIM])
vae.summary()

# 获取模型权重 variable
w = vae.trainable_weights
print(w)

# 打印 KL 对权重的导数
# KL 要是 Tensor,不能是上面的函数 `loss_kl`
grad = K.gradients(0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1),
          w)
print(grad) # 有些是 None 的
grad = grad[grad is not None] # 去掉 None,不然报错

# 打印梯度的函数
# K.function 的输入和输出必要是 list!就算只有一个
show_grad = K.function([vae.input], [grad])

# vae.fit(x_train, # y_train, # 不能传 y_train
#     batch_size=BATCH,
#     epochs=EPOCH,
#     verbose=1,
#     validation_data=(x_test, None))

''' 以 train_on_batch 方式训练 '''
for epoch in range(EPOCH):
  for b in range(x_train.shape[0] // BATCH):
    idx = np.random.choice(x_train.shape[0], BATCH)
    x = x_train[idx]
    l = vae.train_on_batch([x], [x, x])

  # 计算梯度
  gd = show_grad([x])
  # 打印梯度
  print(gd)

# show manifold
PIXEL = 28
N_PICT = 30
grid_x = norm.ppf(np.linspace(0.05, 0.95, N_PICT))
grid_y = grid_x

figure = np.zeros([N_PICT * PIXEL, N_PICT * PIXEL])
for i, xi in enumerate(grid_x):
  for j, yj in enumerate(grid_y):
    noise = np.array([[xi, yj]]) # 必须秩为 2,两层中括号
    x_gen = decoder.predict(noise)
    # print('x_gen shape:', x_gen.shape)
    x_gen = x_gen[0].reshape([PIXEL, PIXEL])
    figure[i * PIXEL: (i+1) * PIXEL,
        j * PIXEL: (j+1) * PIXEL] = x_gen

fig = plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
fig.savefig('./variational_autoencoder.png')
plt.show()

补充知识:keras 自定义损失 自动求导时出现None

问题记录,keras 自定义损失 自动求导时出现None,后来想到是因为传入的变量没有使用,所以keras无法求出偏导,修改后问题解决。就是不愿使用的变量×0,求导后还是0就可以了。

def my_complex_loss_graph(y_label, emb_uid, lstm_out,y_true_1,y_true_2,y_true_3,out_1,out_2,out_3):
 
  mse_out_1 = mean_squared_error(y_true_1, out_1)
  mse_out_2 = mean_squared_error(y_true_2, out_2)
  mse_out_3 = mean_squared_error(y_true_3, out_3)
  # emb_uid= K.reshape(emb_uid, [-1, 32])
  cosine_sim = tf.reduce_sum(0.5*tf.square(emb_uid-lstm_out))
 
  cost=0*cosine_sim+K.sum([0.5*mse_out_1 , 0.25*mse_out_2,0.25*mse_out_3],axis=1,keepdims=True)
  # print(mse_out_1)
  final_loss = cost
 
  return K.mean(final_loss)

以上这篇keras打印loss对权重的导数方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python标准异常和异常处理详解
Feb 02 Python
Python编程实现蚁群算法详解
Nov 13 Python
Python程序员面试题 你必须提前准备!
Jan 16 Python
[原创]Python入门教程2. 字符串基本操作【运算、格式化输出、常用函数】
Oct 29 Python
使用Python开发SQLite代理服务器的方法
Dec 07 Python
Python2与Python3的区别点整理
Dec 12 Python
pycharm激活码有效到2020年11月底
Sep 18 Python
Python3 利用face_recognition实现人脸识别的方法
Mar 13 Python
Tensorflow卷积实现原理+手写python代码实现卷积教程
May 22 Python
python 图像判断,清晰度(明暗),彩色与黑白实例
Jun 04 Python
Python实现封装打包自己写的代码,被python import
Jul 12 Python
pytorch Dropout过拟合的操作
May 27 Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
You might like
用session做客户验证时的注意事项
2006/10/09 PHP
php环境配置 php5 mysql5 apache2 phpmyadmin安装与配置
2006/11/17 PHP
PHP 全角转半角实现代码
2010/05/16 PHP
php数组函数序列之in_array() 查找数组值是否存在
2011/10/29 PHP
关于php 接口问题(php接口主要也就是运用curl,curl函数)
2013/07/01 PHP
Drupal7连接多个数据库及常见问题解决
2014/03/02 PHP
php接口技术实例详解
2016/12/07 PHP
php判断数组是否为空的实例方法
2020/05/10 PHP
JavaScript极简入门教程(二):对象和函数
2014/10/25 Javascript
javascript实现Table间隔色以及选择高亮(和动态切换数据)的方法
2015/05/14 Javascript
JQuery选择器、过滤器大整理
2015/05/26 Javascript
IE8下jQuery改变png图片透明度时出现的黑边
2015/08/30 Javascript
JS实现图片高亮展示效果实例
2015/11/24 Javascript
详解在 Angular 项目中添加 clean-blog 模板
2017/07/04 Javascript
vue2.0 兄弟组件(平级)通讯的实现代码
2018/01/15 Javascript
element-ui 时间选择器限制范围的实现(随动)
2019/01/09 Javascript
layui数据表格实现重载数据表格功能(搜索功能)
2019/07/27 Javascript
HTML+JavaScript实现扫雷小游戏
2019/09/30 Javascript
Vue父子传递实例讲解
2020/02/14 Javascript
JS模拟实现京东快递单号查询
2020/11/30 Javascript
[01:23:45]DOTA2-DPC中国联赛 正赛 CDEC vs Dragon BO3 第一场 1月22日
2021/03/11 DOTA
Python基础入门之seed()方法的使用
2015/05/15 Python
python虚拟环境的安装配置图文教程
2017/10/20 Python
75条笑死人的知乎神回复,用60行代码就爬完了
2019/05/06 Python
Python3 无重复字符的最长子串的实现
2019/10/08 Python
python运用pygame库实现双人弹球小游戏
2019/11/25 Python
TensorFlow打印输出tensor的值
2020/04/19 Python
CSS3解决移动页面上点击链接触发色块的问题
2016/06/03 HTML / CSS
html5之Canvas路径绘图、坐标变换应用实例
2012/12/26 HTML / CSS
计算机本科生自荐信
2013/10/15 职场文书
人事主管岗位职责
2014/01/30 职场文书
政风行风整改方案
2014/10/25 职场文书
介绍信模板
2015/01/31 职场文书
图书馆义工感想
2015/08/07 职场文书
导游词之阆中古城
2019/12/23 职场文书
解决ObjectMapper.convertValue() 遇到的一些问题
2021/06/30 Java/Android