pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法


Posted in Python onAugust 17, 2019

如下所示:

#获取模型权重
for k, v in model_2.state_dict().iteritems():
 print("Layer {}".format(k))
 print(v)
#获取模型权重
for layer in model_2.modules():
 if isinstance(layer, nn.Linear):
  print(layer.weight)
#将一个模型权重载入另一个模型
model = VGG(make_layers(cfg['E']), **kwargs)
if pretrained:
 load = torch.load('/home/huangqk/.torch/models/vgg19-dcbb9e9d.pth')
 load_state = {k: v for k, v in load.items() if k not in ['classifier.0.weight', 'classifier.0.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias']}
 model_state = model.state_dict()
 model_state.update(load_state)
 model.load_state_dict(model_state)
return model
# 对特定层注入hook
def hook_layers(model):
 def hook_function(module, inputs, outputs):
  recreate_image(inputs[0])

 print(model.features._modules)
 first_layer = list(model.features._modules.items())[0][1]
 first_layer.register_forward_hook(hook_function)
#获取层
x = someinput
for l in vgg.features.modules():
 x = l(x)
modulelist = list(vgg.features.modules())
for l in modulelist[:5]:
 x = l(x)
keep = x
for l in modulelist[5:]:
 x = l(x)
# 提取vgg模型的中间层输出
# coding:utf8
import torch
import torch.nn as nn
from torchvision.models import vgg16
from collections import namedtuple


class Vgg16(torch.nn.Module):
 def __init__(self):
  super(Vgg16, self).__init__()
  features = list(vgg16(pretrained=True).features)[:23]
  # features的第3,8,15,22层分别是: relu1_2,relu2_2,relu3_3,relu4_3
  self.features = nn.ModuleList(features).eval()

 def forward(self, x):
  results = []
  for ii, model in enumerate(self.features):
   x = model(x)
   if ii in {3, 8, 15, 22}:
    results.append(x)

  vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
  return vgg_outputs(*results)

以上这篇pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python下使用Txt2Html实现网页过滤代理的教程
Apr 11 Python
使用Python下载歌词并嵌入歌曲文件中的实现代码
Nov 13 Python
Python2.7编程中SQLite3基本操作方法示例
Aug 09 Python
Python使用装饰器进行django开发实例代码
Feb 06 Python
Python使用pymongo模块操作MongoDB的方法示例
Jul 20 Python
django解决跨域请求的问题详解
Jan 20 Python
11个Python3字典内置方法大全与示例汇总
May 13 Python
pandas dataframe的合并实现(append, merge, concat)
Jun 24 Python
python 使用pdfminer3k 读取PDF文档的例子
Aug 27 Python
Python对接 xray 和微信实现自动告警
Sep 17 Python
Python图像处理库PIL的ImageGrab模块介绍详解
Feb 26 Python
详解python中的异常和文件读写
Jan 03 Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
对django的User模型和四种扩展/重写方法小结
Aug 17 #Python
python3.6编写的单元测试示例
Aug 17 #Python
python3 实现的对象与json相互转换操作示例
Aug 17 #Python
You might like
PHP 动态随机生成验证码类代码
2010/04/09 PHP
ueditor 1.2.6 使用方法说明
2013/07/24 PHP
php另类上传图片的方法(PHP用Socket上传图片)
2013/10/30 PHP
详解php语言最牛掰的Laravel框架
2017/11/20 PHP
JS鼠标事件大全 推荐收藏
2011/11/01 Javascript
javascript简单实现命名空间效果
2014/03/06 Javascript
js子页面获取父页面数据示例
2014/05/15 Javascript
原生js实现百叶窗效果及原理介绍
2016/04/12 Javascript
switch语句的妙用(必看篇)
2016/10/03 Javascript
微信小程序 缓存(本地缓存、异步缓存、同步缓存)详解
2017/01/17 Javascript
react.js 翻页插件实例代码
2017/01/19 Javascript
详解vue-router 2.0 常用基础知识点之router.push()
2017/05/10 Javascript
JS事件流与事件处理程序实例分析
2019/08/16 Javascript
LayUi数据表格自定义赋值方式
2019/10/26 Javascript
vue中watch和computed为什么能监听到数据的改变以及不同之处
2019/12/27 Javascript
JS几个常用的函数和对象定义与用法示例
2020/01/15 Javascript
javascript设计模式 ? 装饰模式原理与应用实例分析
2020/04/14 Javascript
Vue实现附件上传功能
2020/05/28 Javascript
vue实现列表滚动的过渡动画
2020/06/29 Javascript
nuxt.js服务端渲染中axios和proxy代理的配置操作
2020/11/06 Javascript
[01:07:47]Secret vs Optic Supermajor 胜者组 BO3 第一场 6.4
2018/06/05 DOTA
[52:14]VG vs Serenity 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/20 DOTA
python计算圆周率pi的方法
2015/07/11 Python
使用CSS3制作饼状旋转载入效果的实例
2015/06/23 HTML / CSS
体育纪念品、亲笔签名的体育收藏品:Steiner Sports
2020/07/31 全球购物
JAVA和C++区别都有哪些
2015/03/30 面试题
软件测试面试题
2014/01/05 面试题
转预备党员政审材料
2014/02/06 职场文书
财务科科长岗位职责
2014/03/10 职场文书
企业管理毕业生求职信
2014/03/11 职场文书
2014年心理健康教育工作总结
2014/12/06 职场文书
公司会议开幕词
2015/01/29 职场文书
PyTorch 如何自动计算梯度
2021/05/23 Python
分析ZooKeeper分布式锁的实现
2021/06/30 Java/Android
Python中使用Opencv开发停车位计数器功能
2022/04/04 Python
方法汇总:Python 安装第三方库常用
2022/04/26 Python