keras打印loss对权重的导数方式


Posted in Python onJune 10, 2020

Notes

怀疑模型梯度爆炸,想打印模型 loss 对各权重的导数看看。如果如果fit来训练的话,可以用keras.callbacks.TensorBoard实现。

但此次使用train_on_batch来训练的,用K.gradients和K.function实现。

Codes

以一份 VAE 代码为例

# -*- coding: utf8 -*-
import keras
from keras.models import Model
from keras.layers import Input, Lambda, Conv2D, MaxPooling2D, Flatten, Dense, Reshape
from keras.losses import binary_crossentropy
from keras.datasets import mnist, fashion_mnist
import keras.backend as K
from scipy.stats import norm
import numpy as np
import matplotlib.pyplot as plt

BATCH = 128
N_CLASS = 10
EPOCH = 5
IN_DIM = 28 * 28
H_DIM = 128
Z_DIM = 2

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train = x_train.reshape(len(x_train), -1).astype('float32') / 255.
x_test = x_test.reshape(len(x_test), -1).astype('float32') / 255.

def sampleing(args):
  """reparameterize"""
  mu, logvar = args
  eps = K.random_normal([K.shape(mu)[0], Z_DIM], mean=0.0, stddev=1.0)
  return mu + eps * K.exp(logvar / 2.)

# encode
x_in = Input([IN_DIM])
h = Dense(H_DIM, activation='relu')(x_in)
z_mu = Dense(Z_DIM)(h) # mean,不用激活
z_logvar = Dense(Z_DIM)(h) # log variance,不用激活
z = Lambda(sampleing, output_shape=[Z_DIM])([z_mu, z_logvar]) # 只能有一个参数
encoder = Model(x_in, [z_mu, z_logvar, z], name='encoder')

# decode
z_in = Input([Z_DIM])
h_hat = Dense(H_DIM, activation='relu')(z_in)
x_hat = Dense(IN_DIM, activation='sigmoid')(h_hat)
decoder = Model(z_in, x_hat, name='decoder')

# VAE
x_in = Input([IN_DIM])
x = x_in
z_mu, z_logvar, z = encoder(x)
x = decoder(z)
out = x
vae = Model(x_in, [out, out], name='vae')

# loss_kl = 0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1)
# loss_recon = binary_crossentropy(K.reshape(vae_in, [-1, IN_DIM]), vae_out) * IN_DIM
# loss_vae = K.mean(loss_kl + loss_recon)

def loss_kl(y_true, y_pred):
  return 0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1)


# vae.add_loss(loss_vae)
vae.compile(optimizer='rmsprop',
      loss=[loss_kl, 'binary_crossentropy'],
      loss_weights=[1, IN_DIM])
vae.summary()

# 获取模型权重 variable
w = vae.trainable_weights
print(w)

# 打印 KL 对权重的导数
# KL 要是 Tensor,不能是上面的函数 `loss_kl`
grad = K.gradients(0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1),
          w)
print(grad) # 有些是 None 的
grad = grad[grad is not None] # 去掉 None,不然报错

# 打印梯度的函数
# K.function 的输入和输出必要是 list!就算只有一个
show_grad = K.function([vae.input], [grad])

# vae.fit(x_train, # y_train, # 不能传 y_train
#     batch_size=BATCH,
#     epochs=EPOCH,
#     verbose=1,
#     validation_data=(x_test, None))

''' 以 train_on_batch 方式训练 '''
for epoch in range(EPOCH):
  for b in range(x_train.shape[0] // BATCH):
    idx = np.random.choice(x_train.shape[0], BATCH)
    x = x_train[idx]
    l = vae.train_on_batch([x], [x, x])

  # 计算梯度
  gd = show_grad([x])
  # 打印梯度
  print(gd)

# show manifold
PIXEL = 28
N_PICT = 30
grid_x = norm.ppf(np.linspace(0.05, 0.95, N_PICT))
grid_y = grid_x

figure = np.zeros([N_PICT * PIXEL, N_PICT * PIXEL])
for i, xi in enumerate(grid_x):
  for j, yj in enumerate(grid_y):
    noise = np.array([[xi, yj]]) # 必须秩为 2,两层中括号
    x_gen = decoder.predict(noise)
    # print('x_gen shape:', x_gen.shape)
    x_gen = x_gen[0].reshape([PIXEL, PIXEL])
    figure[i * PIXEL: (i+1) * PIXEL,
        j * PIXEL: (j+1) * PIXEL] = x_gen

fig = plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
fig.savefig('./variational_autoencoder.png')
plt.show()

补充知识:keras 自定义损失 自动求导时出现None

问题记录,keras 自定义损失 自动求导时出现None,后来想到是因为传入的变量没有使用,所以keras无法求出偏导,修改后问题解决。就是不愿使用的变量×0,求导后还是0就可以了。

def my_complex_loss_graph(y_label, emb_uid, lstm_out,y_true_1,y_true_2,y_true_3,out_1,out_2,out_3):
 
  mse_out_1 = mean_squared_error(y_true_1, out_1)
  mse_out_2 = mean_squared_error(y_true_2, out_2)
  mse_out_3 = mean_squared_error(y_true_3, out_3)
  # emb_uid= K.reshape(emb_uid, [-1, 32])
  cosine_sim = tf.reduce_sum(0.5*tf.square(emb_uid-lstm_out))
 
  cost=0*cosine_sim+K.sum([0.5*mse_out_1 , 0.25*mse_out_2,0.25*mse_out_3],axis=1,keepdims=True)
  # print(mse_out_1)
  final_loss = cost
 
  return K.mean(final_loss)

以上这篇keras打印loss对权重的导数方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中endswith()函数的基本使用
Apr 07 Python
python实现的简单文本类游戏实例
Apr 28 Python
介绍Python中内置的itertools模块
Apr 29 Python
利用Python脚本生成sitemap.xml的实现方法
Jan 31 Python
浅谈Pandas中map, applymap and apply的区别
Apr 10 Python
Python subprocess库的使用详解
Oct 26 Python
在python中使用requests 模拟浏览器发送请求数据的方法
Dec 26 Python
Python实现定期检查源目录与备份目录的差异并进行备份功能示例
Feb 27 Python
django中使用POST方法获取POST数据
Aug 20 Python
python 图片二值化处理(处理后为纯黑白的图片)
Nov 01 Python
使用PyTorch将文件夹下的图片分为训练集和验证集实例
Jan 08 Python
python Scrapy爬虫框架的使用
Jan 21 Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
You might like
php遍历文件夹和文件列表示例分享
2014/03/11 PHP
php设置静态内容缓存时间的方法
2014/12/01 PHP
浅谈thinkphp的nginx配置,以及重写隐藏index.php入口文件方法
2019/10/12 PHP
js下获取div中的数据的原理分析
2010/04/07 Javascript
javascript Array数组对象的扩展函数代码
2010/05/22 Javascript
基于jquery的设置页面文本框 只能输入数字的实现代码
2011/04/19 Javascript
基于JQuery实现CheckBox全选全不选
2011/06/27 Javascript
js借助ActiveXObject实现创建文件
2013/09/29 Javascript
js和css写一个可以自动隐藏的悬浮框
2014/03/05 Javascript
通过Javascript读取本地Excel文件内容的代码示例
2014/04/08 Javascript
javascript实现类似超链接的效果
2014/12/26 Javascript
jquery捕捉回车键及获取checkbox值与异步请求的方法
2015/12/24 Javascript
一览画面点击复选框后获取多个id值的方法
2016/05/30 Javascript
浅谈JavaScript中面向对象的的深拷贝和浅拷贝
2016/08/01 Javascript
使用ionic在首页新闻中应用到的跑马灯效果的实现方法
2017/02/13 Javascript
Mac系统下Webstorm快捷键整理大全
2017/05/28 Javascript
JS module的导出和导入的实现代码
2019/02/25 Javascript
layui表格内容溢出的解决方法
2019/09/06 Javascript
从零搭一个自用的前端脚手架的方法步骤
2019/09/23 Javascript
使用 Jest 和 Supertest 进行接口端点测试实例详解
2020/04/25 Javascript
[01:33:59]真人秀《加油 DOTA》 第六期
2014/09/09 DOTA
python使用rsa加密算法模块模拟新浪微博登录
2014/01/22 Python
Pandas:DataFrame对象的基础操作方法
2018/06/07 Python
Python通过paramiko远程下载Linux服务器上的文件实例
2018/12/27 Python
python模拟登陆,用session维持回话的实例
2018/12/27 Python
简单了解Java Netty Reactor三种线程模型
2020/04/26 Python
sealed修饰符是干什么的
2012/10/23 面试题
vue+django实现下载文件的示例
2021/03/24 Vue.js
ktv中秋节活动方案
2014/01/30 职场文书
文明城市标语
2014/06/16 职场文书
2014年幼儿园教研工作总结
2014/12/04 职场文书
2015年社区服务活动总结
2015/03/25 职场文书
病危通知单
2015/04/17 职场文书
导游词之河姆渡遗址博物馆
2019/10/10 职场文书
Python echarts实现数据可视化实例详解
2022/03/03 Python
mysql5.5中文乱码问题解决的有用方法
2022/05/30 MySQL