pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法


Posted in Python onAugust 17, 2019

如下所示:

#获取模型权重
for k, v in model_2.state_dict().iteritems():
 print("Layer {}".format(k))
 print(v)
#获取模型权重
for layer in model_2.modules():
 if isinstance(layer, nn.Linear):
  print(layer.weight)
#将一个模型权重载入另一个模型
model = VGG(make_layers(cfg['E']), **kwargs)
if pretrained:
 load = torch.load('/home/huangqk/.torch/models/vgg19-dcbb9e9d.pth')
 load_state = {k: v for k, v in load.items() if k not in ['classifier.0.weight', 'classifier.0.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias']}
 model_state = model.state_dict()
 model_state.update(load_state)
 model.load_state_dict(model_state)
return model
# 对特定层注入hook
def hook_layers(model):
 def hook_function(module, inputs, outputs):
  recreate_image(inputs[0])

 print(model.features._modules)
 first_layer = list(model.features._modules.items())[0][1]
 first_layer.register_forward_hook(hook_function)
#获取层
x = someinput
for l in vgg.features.modules():
 x = l(x)
modulelist = list(vgg.features.modules())
for l in modulelist[:5]:
 x = l(x)
keep = x
for l in modulelist[5:]:
 x = l(x)
# 提取vgg模型的中间层输出
# coding:utf8
import torch
import torch.nn as nn
from torchvision.models import vgg16
from collections import namedtuple


class Vgg16(torch.nn.Module):
 def __init__(self):
  super(Vgg16, self).__init__()
  features = list(vgg16(pretrained=True).features)[:23]
  # features的第3,8,15,22层分别是: relu1_2,relu2_2,relu3_3,relu4_3
  self.features = nn.ModuleList(features).eval()

 def forward(self, x):
  results = []
  for ii, model in enumerate(self.features):
   x = model(x)
   if ii in {3, 8, 15, 22}:
    results.append(x)

  vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
  return vgg_outputs(*results)

以上这篇pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
学习python (1)
Oct 31 Python
CentOS6.5设置Django开发环境
Oct 13 Python
tensorflow中next_batch的具体使用
Feb 02 Python
Python基于dom操作xml数据的方法示例
May 12 Python
基于wxPython的GUI实现输入对话框(2)
Feb 27 Python
python获取点击的坐标画图形的方法
Jul 09 Python
对python while循环和双重循环的实例详解
Aug 23 Python
Python大数据之网络爬虫的post请求、get请求区别实例分析
Nov 16 Python
Tensorflow: 从checkpoint文件中读取tensor方式
Feb 10 Python
Python进程间通信multiprocess代码实例
Mar 18 Python
python的reverse函数翻转结果为None的问题
May 11 Python
Python如何识别银行卡卡号?
Jun 10 Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
对django的User模型和四种扩展/重写方法小结
Aug 17 #Python
python3.6编写的单元测试示例
Aug 17 #Python
python3 实现的对象与json相互转换操作示例
Aug 17 #Python
You might like
PHP+APACHE实现用户论证的方法
2006/10/09 PHP
destoon调用自定义模板及样式的公告栏
2014/06/21 PHP
使用ThinkPHP的自动完成实现无限级分类实例详解
2016/09/02 PHP
Centos 6.5下PHP 5.3安装ffmpeg扩展的步骤详解
2017/03/02 PHP
JS 图片缩放效果代码
2010/06/09 Javascript
ie6下png图片背景不透明的解决办法使用js实现
2013/01/11 Javascript
jquery 滚动条事件简单实例
2013/07/12 Javascript
使用js正则控制input标签只允许输入的值
2013/07/29 Javascript
悬浮数字的实现案例
2014/02/19 Javascript
jQuery获取checkboxlist的value值的方法
2015/09/27 Javascript
浅谈JavaScript 执行环境、作用域及垃圾回收
2016/05/31 Javascript
bootstrap使用validate实现简单校验功能
2016/12/02 Javascript
vue图片加载与显示默认图片实例代码
2017/03/16 Javascript
vue.js 获取当前自定义属性值
2017/06/01 Javascript
jQuery UI 实例讲解 - 日期选择器(Datepicker)
2017/09/18 jQuery
基于vue+axios+lrz.js微信端图片压缩上传方法
2019/06/25 Javascript
Vue 实现一个简单的鼠标拖拽滚动效果插件
2020/12/10 Vue.js
树莓派中python获取GY-85九轴模块信息示例
2013/12/05 Python
理解Python中的With语句
2015/02/02 Python
Python读写文件方法总结
2015/06/09 Python
Python unittest单元测试框架总结
2018/09/08 Python
Python如何使用argparse模块处理命令行参数
2019/12/11 Python
美国最大的在线生存商店:Survival Frog
2020/12/13 全球购物
环境科学专业个人求职信
2013/09/26 职场文书
大学毕业生通用求职信
2013/09/28 职场文书
机械电子工程专业推荐信范文
2013/11/20 职场文书
大学毕业感言
2014/01/10 职场文书
《三个小伙伴》教学反思
2014/04/11 职场文书
民族学专业职业生涯规划范文:积跬步以至千里
2014/09/11 职场文书
诉讼授权委托书范本
2014/10/05 职场文书
军人离婚协议书样本
2014/10/21 职场文书
小学生作文批改评语
2014/12/25 职场文书
单位考核聘任报告
2015/03/02 职场文书
python 利用 PIL 将数组值转成图片的实现
2021/04/12 Python
基于Redis位图实现用户签到功能
2021/05/08 Redis
Python可视化神器pyecharts绘制地理图表
2022/07/07 Python