Building a ResNet18 network yourself and loading torchvision's pretrained weights


Posted in Python on May 13, 2021

A network you build yourself must match the weights that ship with torchvision, i.e. the pth file, exactly in structure, tensor shapes, and parameter names, otherwise the weight file cannot be loaded.

In that case you can compare the two state dicts and copy the weights one entry at a time:

import torch
import torchvision
import cv2 as cv
from utils.utils import letter_box
from model.backbone import ResNet18

model1 = ResNet18(1)                                   # custom ResNet18 with a single output
model2 = torchvision.models.resnet18(progress=False)   # torchvision reference model
fc = model2.fc
model2.fc = torch.nn.Linear(512, 1)                    # replace the 1000-class head with a single output
# print(model)
model_dict1 = model1.state_dict()           # state dict of the custom model
model_dict2 = torch.load('resnet18.pth')    # weight file to copy from
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    # copy parameters position by position, skipping any whose shapes differ
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model1.load_state_dict(model_dict1)
missing, unexpected = model2.load_state_dict(model_dict2)
image = cv.imread('zhn1.jpg')
image = letter_box(image, 224)                # letterbox-resize to 224x224
image = image[:, :, ::-1].transpose(2, 0, 1)  # BGR -> RGB, HWC -> CHW
print('Network loading complete.')
model1.eval()
model2.eval()
with torch.no_grad():
    image = torch.tensor(image/256, dtype=torch.float32).unsqueeze(0)  # scale to roughly [0, 1] and add a batch dimension
    predict1 = model1(image)
    predict2 = model2(image)
print('finished')
# torch.save(model.state_dict(), 'resnet18.pth')

That is the whole program; at the end you can test whether the original model and the custom model loaded with the pretrained weights produce the same output.
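To make that final comparison concrete, a short continuation of the script above (my addition; it assumes the predict1 and predict2 tensors computed there) could be:

# continuation of the script above: check that the custom model and the
# torchvision model produce numerically identical outputs
print('outputs equal:', torch.allclose(predict1, predict2, atol=1e-6))
print('max abs difference:', (predict1 - predict2).abs().max().item())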

Supplement: building a ResNet classification network with PyTorch and training it via transfer learning

With stride=1 and padding=1, a 3x3 convolution does not change the height and width of the feature map.
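This follows from the output-size formula out = (in - kernel + 2*padding) / stride + 1, which gives (56 - 3 + 2) / 1 + 1 = 56. A quick sanity check (my own sketch, not part of the original code):

import torch
import torch.nn as nn

conv = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)
x = torch.randn(1, 64, 56, 56)
print(conv(x).shape)  # torch.Size([1, 128, 56, 56]): height and width unchanged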

When a BN layer is used, set bias=False in the convolution (the BN output is the same with or without the bias) and place the BN layer between the conv layer and the ReLU layer.
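The bias is redundant because BN subtracts the per-channel mean, which cancels any constant offset the convolution adds. A small verification sketch (my addition, not from the original post; the tolerance is an arbitrary choice):

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8)

conv_no_bias = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
conv_bias = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=True)
conv_bias.weight.data.copy_(conv_no_bias.weight.data)  # same kernels, plus a random bias

bn = nn.BatchNorm2d(16)
bn.train()  # normalize with batch statistics

out1 = bn(conv_no_bias(x))
out2 = bn(conv_bias(x))
print(torch.allclose(out1, out2, atol=1e-5))  # True: the bias is cancelled by BN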

A review of Batch Norm:

A Batch Norm layer normalizes the data of each layer and then applies a linear transformation to improve the data distribution; that linear transformation is learnable.

Advantages of Batch Norm: it reduces overfitting; it improves gradient propagation (weights do not become excessively large or small); it allows higher learning rates, which speeds up training; it reduces the strong dependence on weight initialization; it keeps the data in the non-saturating region of the activation function, which alleviates vanishing gradients to some extent; and, as a form of regularization, it reduces the need for dropout to some degree.

Where to place the BN layer: there is no firm consensus on whether it should go before or after the activation (e.g. ReLU).

BN together with dropout: the introduction of Batch Norm has reduced the use of dropout, but it cannot completely replace it; keeping a small dropout rate such as 0.2 may work better.

Why normalize first and then use the γ, β linear transform to restore something close to the original distribution? Isn't that redundant?

Under certain conditions it corrects the original data distribution (the standard deviation and mean become the new values γ and β); when the original distribution is already good enough, it reduces to an identity mapping and does not change the distribution. Without BN, the mean and variance depend on the parameters of all preceding layers in a complex, nonlinear way. In the new parameterization γH′ + β they are determined only by γ and β, independent of the earlier layers' parameters, so these new parameters are easy to learn by gradient descent and a good distribution can be learned.
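To make the two steps concrete, here is a rough standalone sketch (my addition, not from the original post) of what BN does to one mini-batch: normalize each channel to zero mean and unit variance, then apply the learnable scale γ and shift β:

import torch

x = torch.randn(32, 8, 14, 14) * 5 + 3                     # input with arbitrary mean/variance
mean = x.mean(dim=(0, 2, 3), keepdim=True)                  # per-channel mean
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)    # per-channel variance
x_hat = (x - mean) / torch.sqrt(var + 1e-5)                 # normalized: mean 0, std 1

gamma = torch.full((1, 8, 1, 1), 2.0)                       # learnable scale
beta = torch.full((1, 8, 1, 1), 0.5)                        # learnable shift
y = gamma * x_hat + beta                                    # new distribution: std about 2, mean about 0.5
print(y.mean().item(), y.std().item())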

Importing and downloading weights for transfer learning:

import torch
import torch.nn as nn
import torchvision.models.resnet  # Ctrl + left-click the module in the IDE to find the weight download URL
from model import resnet34

net = resnet34()  # do not set the desired number of output classes yet; load the pretrained parameters first, then replace the fc layer
# official way to load a pretrained model
model_weight_path = "./resnet34-pre.pth"  # path to the weights
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)  # load the model weights
inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)  # redefine the fully connected layer

Complete code:

The model part:

import torch.nn as nn
import torch
class BasicBlock(nn.Module):  # residual block for the 18- and 34-layer networks (covers both the solid-line and the dashed-line shortcut)
    expansion = 1  # channel ratio between the convs on the main branch: 1 when they all match, 4 when the last conv has four times as many channels
    def __init__(self, in_channel, out_channel, stride=1, downsample=None):  # downsample implements the dashed-line (projection) shortcut
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)  # output of the shortcut branch
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out  # final output of the residual block

class Bottleneck(nn.Module):  # residual block for the 50-, 101- and 152-layer networks
    expansion = 4  # the third conv has four times as many kernels as the first and second
    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out += identity
        out = self.relu(out)
        return out

class ResNet(nn.Module):  # the overall network skeleton
    # blocks_num: list with the number of residual blocks per stage; block: which residual block class to use
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64  # depth of the feature map after the first max-pooling layer
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    def _make_layer(self, block, channel, block_num, stride=1):  # channel: number of kernels in the first conv layer of the residual block
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:  # for ResNet-18/34, layer1 skips this branch (no projection shortcut needed)
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))
        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion
        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))
        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        if self.include_top:  # default is True
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
        return x

def resnet34(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

def resnet101(num_classes=1000, include_top=True):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

The training part:

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import matplotlib.pyplot as plt
import os
import torch.optim as optim
from model import resnet34, resnet101
import torchvision.models.resnet  # Ctrl + left-click the module in the IDE to find the weight download URL
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),  # same normalization as the official pretrained weights
    "val": transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
image_path = data_root + "/data_set/flower_data/"  # flower data set path
train_dataset = datasets.ImageFolder(root=image_path+"train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)
# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)
batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)
validate_dataset = datasets.ImageFolder(root=image_path + "val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)
net = resnet34()  # do not set the desired number of output classes yet; load the pretrained parameters first, then replace the fc layer
# official way to load a pretrained model
model_weight_path = "./resnet34-pre.pth"  # path to the weights
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)  # load the model weights
inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)  # redefine the fully connected layer
net.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)
best_acc = 0.0
save_path = './resNet34.pth'
for epoch in range(3):
    # train
    net.train()  # training mode: BN uses batch statistics
    running_loss = 0.0
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        # print train process
        rate = (step+1)/len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")
    print()
    # validate
    net.eval()  # eval mode: BN uses running statistics
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))  # eval model only have last output layer
            # loss = loss_function(outputs, test_labels)
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / len(train_loader), val_accurate))
print('Finished Training')

The prediction part:

import torch
from model import resnet34
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])  # same normalization as used for training
# load image
img = Image.open("../aa.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)
# read class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)
# create model
model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path, map_location=device))  # load the trained model parameters
model.eval()  # switch to eval mode
with torch.no_grad():  # do not track gradients
    # predict class
    output = torch.squeeze(model(img))  # remove the batch dimension
    predict = torch.softmax(output, dim=0)  # softmax to get a probability distribution
    predict_cla = torch.argmax(predict).numpy()  # index of the largest probability
print(class_indict[str(predict_cla)], predict[predict_cla].numpy())  # print the predicted class and its probability
plt.show()

The above is my personal experience; I hope it can serve as a reference, and I hope everyone will keep supporting 三水点靠木.
