keras打印loss对权重的导数方式


Posted in Python onJune 10, 2020

Notes

怀疑模型梯度爆炸,想打印模型 loss 对各权重的导数看看。如果如果fit来训练的话,可以用keras.callbacks.TensorBoard实现。

但此次使用train_on_batch来训练的,用K.gradients和K.function实现。

Codes

以一份 VAE 代码为例

# -*- coding: utf8 -*-
import keras
from keras.models import Model
from keras.layers import Input, Lambda, Conv2D, MaxPooling2D, Flatten, Dense, Reshape
from keras.losses import binary_crossentropy
from keras.datasets import mnist, fashion_mnist
import keras.backend as K
from scipy.stats import norm
import numpy as np
import matplotlib.pyplot as plt

BATCH = 128
N_CLASS = 10
EPOCH = 5
IN_DIM = 28 * 28
H_DIM = 128
Z_DIM = 2

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train = x_train.reshape(len(x_train), -1).astype('float32') / 255.
x_test = x_test.reshape(len(x_test), -1).astype('float32') / 255.

def sampleing(args):
  """reparameterize"""
  mu, logvar = args
  eps = K.random_normal([K.shape(mu)[0], Z_DIM], mean=0.0, stddev=1.0)
  return mu + eps * K.exp(logvar / 2.)

# encode
x_in = Input([IN_DIM])
h = Dense(H_DIM, activation='relu')(x_in)
z_mu = Dense(Z_DIM)(h) # mean,不用激活
z_logvar = Dense(Z_DIM)(h) # log variance,不用激活
z = Lambda(sampleing, output_shape=[Z_DIM])([z_mu, z_logvar]) # 只能有一个参数
encoder = Model(x_in, [z_mu, z_logvar, z], name='encoder')

# decode
z_in = Input([Z_DIM])
h_hat = Dense(H_DIM, activation='relu')(z_in)
x_hat = Dense(IN_DIM, activation='sigmoid')(h_hat)
decoder = Model(z_in, x_hat, name='decoder')

# VAE
x_in = Input([IN_DIM])
x = x_in
z_mu, z_logvar, z = encoder(x)
x = decoder(z)
out = x
vae = Model(x_in, [out, out], name='vae')

# loss_kl = 0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1)
# loss_recon = binary_crossentropy(K.reshape(vae_in, [-1, IN_DIM]), vae_out) * IN_DIM
# loss_vae = K.mean(loss_kl + loss_recon)

def loss_kl(y_true, y_pred):
  return 0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1)


# vae.add_loss(loss_vae)
vae.compile(optimizer='rmsprop',
      loss=[loss_kl, 'binary_crossentropy'],
      loss_weights=[1, IN_DIM])
vae.summary()

# 获取模型权重 variable
w = vae.trainable_weights
print(w)

# 打印 KL 对权重的导数
# KL 要是 Tensor,不能是上面的函数 `loss_kl`
grad = K.gradients(0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1),
          w)
print(grad) # 有些是 None 的
grad = grad[grad is not None] # 去掉 None,不然报错

# 打印梯度的函数
# K.function 的输入和输出必要是 list!就算只有一个
show_grad = K.function([vae.input], [grad])

# vae.fit(x_train, # y_train, # 不能传 y_train
#     batch_size=BATCH,
#     epochs=EPOCH,
#     verbose=1,
#     validation_data=(x_test, None))

''' 以 train_on_batch 方式训练 '''
for epoch in range(EPOCH):
  for b in range(x_train.shape[0] // BATCH):
    idx = np.random.choice(x_train.shape[0], BATCH)
    x = x_train[idx]
    l = vae.train_on_batch([x], [x, x])

  # 计算梯度
  gd = show_grad([x])
  # 打印梯度
  print(gd)

# show manifold
PIXEL = 28
N_PICT = 30
grid_x = norm.ppf(np.linspace(0.05, 0.95, N_PICT))
grid_y = grid_x

figure = np.zeros([N_PICT * PIXEL, N_PICT * PIXEL])
for i, xi in enumerate(grid_x):
  for j, yj in enumerate(grid_y):
    noise = np.array([[xi, yj]]) # 必须秩为 2,两层中括号
    x_gen = decoder.predict(noise)
    # print('x_gen shape:', x_gen.shape)
    x_gen = x_gen[0].reshape([PIXEL, PIXEL])
    figure[i * PIXEL: (i+1) * PIXEL,
        j * PIXEL: (j+1) * PIXEL] = x_gen

fig = plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
fig.savefig('./variational_autoencoder.png')
plt.show()

补充知识:keras 自定义损失 自动求导时出现None

问题记录,keras 自定义损失 自动求导时出现None,后来想到是因为传入的变量没有使用,所以keras无法求出偏导,修改后问题解决。就是不愿使用的变量×0,求导后还是0就可以了。

def my_complex_loss_graph(y_label, emb_uid, lstm_out,y_true_1,y_true_2,y_true_3,out_1,out_2,out_3):
 
  mse_out_1 = mean_squared_error(y_true_1, out_1)
  mse_out_2 = mean_squared_error(y_true_2, out_2)
  mse_out_3 = mean_squared_error(y_true_3, out_3)
  # emb_uid= K.reshape(emb_uid, [-1, 32])
  cosine_sim = tf.reduce_sum(0.5*tf.square(emb_uid-lstm_out))
 
  cost=0*cosine_sim+K.sum([0.5*mse_out_1 , 0.25*mse_out_2,0.25*mse_out_3],axis=1,keepdims=True)
  # print(mse_out_1)
  final_loss = cost
 
  return K.mean(final_loss)

以上这篇keras打印loss对权重的导数方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用py2exe在Windows下将Python程序转为exe文件
Mar 04 Python
Python如何生成树形图案
Jan 03 Python
CentOS7.3编译安装Python3.6.2的方法
Jan 22 Python
Python实现合并同一个文件夹下所有txt文件的方法示例
Apr 26 Python
Python中垃圾回收和del语句详解
Nov 15 Python
Python用字典构建多级菜单功能
Jul 11 Python
python输入错误后删除的方法
Oct 12 Python
python实现替换word中的关键文字(使用通配符)
Feb 13 Python
python GUI库图形界面开发之PyQt5状态栏控件QStatusBar详细使用方法实例
Feb 28 Python
Pandas将列表(List)转换为数据框(Dataframe)
Apr 24 Python
使用PyCharm安装pytest及requests的问题
Jul 31 Python
Python中正则表达式对单个字符,多个字符和匹配边界等使用
Jan 27 Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
You might like
php生成百度sitemap站点地图类函数实例
2014/10/17 PHP
PHP获取指定时间段之间的 年,月,天,时,分,秒
2016/06/05 PHP
PHP实现腾讯短网址生成api接口实例
2020/12/08 PHP
Aster vs Newbee BO3 第三场2.18
2021/03/10 DOTA
基于jquery编写的横向自适应幻灯片切换特效的实例代码
2013/08/06 Javascript
ff下JQuery无法监听input的keyup事件的解决方法
2013/12/12 Javascript
Javascript数据结构与算法之列表详解
2015/03/12 Javascript
jQuery实现鼠标划过修改样式的方法
2015/04/14 Javascript
js实现仿百度汽车频道选择汽车图片展示实例
2015/05/06 Javascript
Vue通过input筛选数据
2020/10/26 Javascript
详解Vue demo实现商品列表的展示
2019/05/07 Javascript
Vue监听页面刷新和关闭功能
2019/06/20 Javascript
Nodejs实现WebSocket代码实例
2020/05/19 NodeJs
JavaScript语句错误throw、try及catch实例解析
2020/08/18 Javascript
关于JavaScript中异步/等待的用法与理解
2020/11/18 Javascript
Python中还原JavaScript的escape函数编码后字符串的方法
2014/08/22 Python
Python列表list操作符实例分析【标准类型操作符、切片、连接字符、列表解析、重复操作等】
2017/07/24 Python
python使用openpyxl库修改excel表格数据方法
2018/05/03 Python
对Python生成汉字字库文字,以及转换为文字图片的实例详解
2019/01/29 Python
Python操作qml对象过程详解
2019/09/26 Python
TensorFlow tf.nn.max_pool实现池化操作方式
2020/01/04 Python
python实现时间序列自相关图(acf)、偏自相关图(pacf)教程
2020/06/03 Python
django 将自带的数据库sqlite3改成mysql实例
2020/07/09 Python
美国花布包包品牌:Vera Bradley
2017/08/11 全球购物
舞蹈兴趣小组活动总结
2014/07/07 职场文书
做一个有道德的人活动方案
2014/08/25 职场文书
党的群众路线教育实践活动总结材料
2014/10/30 职场文书
贪污受贿检讨书范文
2014/11/19 职场文书
2014年客房部工作总结
2014/11/22 职场文书
2014年行政人事工作总结
2014/12/09 职场文书
军训阅兵新闻稿
2015/07/17 职场文书
同学联谊会邀请函
2019/06/24 职场文书
Nginx tp3.2.3 404问题解决方案
2021/03/31 Servers
go select编译期的优化处理逻辑使用场景分析
2021/06/28 Golang
golang实现一个简单的websocket聊天室功能
2021/10/05 Golang
MySQL外键约束(Foreign Key)案例详解
2022/06/28 MySQL