自己搭建resnet18网络并加载torchvision自带权重的操作


Posted in Python onMay 13, 2021

直接搭建网络必须与torchvision自带的网络的权重也就是pth文件的结构、尺寸和变量命名完全一致,否则无法加载权重文件。

此时可比较2个字典逐一加载

import torch
import torchvision
import cv2 as cv
from utils.utils import letter_box
from model.backbone import ResNet18

model1 = ResNet18(1)
model2 = torchvision.models.resnet18(progress=False)
fc = model2.fc
model2.fc = torch.nn.Linear(512, 1)
# print(model)
model_dict1 = model1.state_dict()
model_dict2 = torch.load('resnet18.pth')
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model1.load_state_dict(model_dict1)
missing, unspected = model2.load_state_dict(model_dict2)
image = cv.imread('zhn1.jpg')
image = letter_box(image, 224)
image = image[:, :, ::-1].transpose(2, 0, 1)
print('Network loading complete.')
model1.eval()
model2.eval()
with torch.no_grad():
    image = torch.tensor(image/256, dtype=torch.float32).unsqueeze(0)
    predict1 = model1(image)
    predict2 = model2(image)
print('finished')
# torch.save(model.state_dict(), 'resnet18.pth')

以上为全部程序,最终可测试原模型与加载了自带权重的自定义模型的输出是否相等。

补充:使用Pytorch搭建ResNet分类网络并基于迁移学习训练

如果stride=1,padding=1

卷积处理是不会改变特征矩阵的高和宽

使用BN层时

卷积中的参数bias置为False(有无偏置BN层的输出都相同),BN层放在conv层和relu层的中间

复习BN层:

Batch Norm 层是对每层数据归一化后再进行线性变换改善数据分布, 其中的线性变换是可学习的.

Batch Norm优点:减轻过拟合;改善梯度传播(权重不会过高或过低)容许较高的学习率,能够提高训练速度。减轻对初始化权重的强依赖,使得数据分布在激活函数的非饱和区域,一定程度上解决梯度消失问题。作为一种正则化的方式,在某种程度上减少对dropout的使用。

Batch Norm层摆放位置:在激活层(如 ReLU )之前还是之后,没有一个统一的定论。

BN层与 Dropout 合作:Batch Norm的提出使得dropout的使用减少,但是Batch Norm不能完全取代dropout,保留较小的dropout率,如0.2可能效果更佳。

为什么要先normalize再通过γ,β线性变换恢复接近原来的样子,这不是多此一举吗?

在一定条件下可以纠正原始数据的分布(方差,均值变为新值γ,β),当原始数据分布足够好时就是恒等映射,不改变分布。如果不做BN,方差和均值对前面网络的参数有复杂的关联依赖,具有复杂的非线性。在新参数 γH′ + β 中仅由 γ,β 确定,与前边网络的参数无关,因此新参数很容易通过梯度下降来学习,能够学习到较好的分布。

迁移学习导入权重和下载权重:

import torchvision.models.resnet#ctrl+鼠标左键点击即可下载权重
net = resnet34()#一开始不能设置全连接层的输出种类为自己想要的,必须先将模型参数载入,再修改全连接层
# 官方提供载入预训练模型的方法
model_weight_path = "./resnet34-pre.pth"#权重路径
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)#载入模型权重
inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)#重新确定全连接层

完整代码:

model部分:

import torch.nn as nn
import torch
class BasicBlock(nn.Module):#对应18层和34层所对应的残差结构(既要有实线残差结构功能,也要有虚线残差结构功能)
    expansion = 1#残差结构主分支上的三个卷积层是否相同,相同为1,第三层是一二层四倍则为4
    def __init__(self, in_channel, out_channel, stride=1, downsample=None):#downsample代表虚线残差结构选项
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)#得到捷径分支的输出
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out#得到残差结构的最终输出

class Bottleneck(nn.Module):#对应50层、101层和152层所对应的残差结构
    expansion = 4#第三层卷积核个数是第一层和第二层的四倍
    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out += identity
        out = self.relu(out)
        return out

class ResNet(nn.Module):#定义整个网络的框架部分
#blocks_num是残差结构的数目,是一个列表参数,block对应哪个残差模块
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64#通过第一个池化层后所得到的特征矩阵的深度
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    def _make_layer(self, block, channel, block_num, stride=1):#channel:残差结构中,第一个卷积层所使用的卷积核的个数
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:#18层和34层会直接跳过这个if语句
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))
        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion
        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))
        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        if self.include_top:#默认是true
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
        return x

def resnet34(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

def resnet101(num_classes=1000, include_top=True):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

训练部分:

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import matplotlib.pyplot as plt
import os
import torch.optim as optim
from model import resnet34, resnet101
import torchvision.models.resnet#ctrl+鼠标左键点击即可下载权重
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#和官网初始化方法保持一致
    "val": transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
image_path = data_root + "/data_set/flower_data/"  # flower data set path
train_dataset = datasets.ImageFolder(root=image_path+"train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)
# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)
batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)
validate_dataset = datasets.ImageFolder(root=image_path + "val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)
net = resnet34()#一开始不能设置全连接层的输出种类为自己想要的,必须先将模型参数载入,再修改全连接层
# 官方提供载入预训练模型的方法
model_weight_path = "./resnet34-pre.pth"#权重路径
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)#载入模型权重
inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)#重新确定全连接层
net.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)
best_acc = 0.0
save_path = './resNet34.pth'
for epoch in range(3):
    # train
    net.train()#控制BN层状态
    running_loss = 0.0
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        # print train process
        rate = (step+1)/len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")
    print()
    # validate
    net.eval()#控制BN层状态
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))  # eval model only have last output layer
            # loss = loss_function(outputs, test_labels)
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))
print('Finished Training')

预测部分:

import torch
from model import resnet34
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])#采用和训练方法一样的标准化处理
# load image
img = Image.open("../aa.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)
# read class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)
# create model
model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path, map_location=device))#载入训练好的模型参数
model.eval()#使用eval()模式
with torch.no_grad():#不跟踪损失梯度
    # predict class
    output = torch.squeeze(model(img))#压缩batch维度
    predict = torch.softmax(output, dim=0)#通过softmax得到概率分布
    predict_cla = torch.argmax(predict).numpy()#寻找最大值所对应的索引
print(class_indict[str(predict_cla)], predict[predict_cla].numpy())#打印类别信息和概率
plt.show()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用新浪微博api上传图片到微博示例
Jan 10 Python
python中xrange和range的区别
May 13 Python
python文件和目录操作函数小结
Jul 11 Python
Python实现线程池代码分享
Jun 21 Python
python实现Decorator模式实例代码
Feb 09 Python
python 中的list和array的不同之处及转换问题
Mar 13 Python
在Django中URL正则表达式匹配的方法
Dec 20 Python
对python:循环定义多个变量的实例详解
Jan 20 Python
Python read函数按字节(字符)读取文件的实现
Jul 03 Python
django rest framework 实现用户登录认证详解
Jul 29 Python
Python hashlib加密模块常用方法解析
Dec 18 Python
Python接口测试环境搭建过程详解
Jun 29 Python
如何使用flask将模型部署为服务
May 13 #Python
教你用python控制安卓手机
Python数据分析入门之数据读取与存储
May 13 #Python
python执行js代码的方法
pytorch加载预训练模型与自己模型不匹配的解决方案
May 13 #Python
Python数据分析入门之教你怎么搭建环境
Pytorch 统计模型参数量的操作 param.numel()
May 13 #Python
You might like
php通过sort()函数给数组排序的方法
2015/03/18 PHP
php中preg_match的isU代表什么意思
2015/10/01 PHP
apache和PHP如何整合在一起
2015/10/12 PHP
PHP基于双向链表与排序操作实现的会员排名功能示例
2017/12/26 PHP
javascript 面向对象 function类
2010/05/13 Javascript
jquery 模板的应用示例
2013/11/12 Javascript
Visual Studio中js调试的方法图解
2014/06/30 Javascript
JavaScript检测上传文件大小的方法
2015/07/22 Javascript
JavaScript常用正则验证函数实例小结【年龄,数字,Email,手机,URL,日期等】
2017/01/23 Javascript
原生js实现吸顶效果
2017/03/13 Javascript
jQuery插件HighCharts实现的2D面积图效果示例【附demo源码下载】
2017/03/15 Javascript
JavaScript切换搜索引擎的导航网页搜索框实例代码
2017/06/11 Javascript
基于Vue实现拖拽效果
2018/04/27 Javascript
基于Vue 2.0 监听文本框内容变化及ref的使用说明介绍
2018/08/24 Javascript
Webpack3+React16代码分割的实现
2021/03/03 Javascript
python获取各操作系统硬件信息的方法
2015/06/03 Python
python 实时得到cpu和内存的使用情况方法
2018/06/11 Python
python 实现A*算法的示例代码
2018/08/13 Python
Windows系统下PhantomJS的安装和基本用法
2018/10/21 Python
python实现Virginia无密钥解密
2019/03/20 Python
Django Rest framework频率原理与限制
2019/07/26 Python
用Python徒手撸一个股票回测框架搭建【推荐】
2019/08/05 Python
python的faker库用法
2019/11/28 Python
Python列表倒序输出及其效率详解
2020/03/04 Python
Django-simple-captcha验证码包使用方法详解
2020/11/28 Python
HTML5本地存储之Web Storage应用介绍
2013/01/06 HTML / CSS
教师自荐书
2013/10/08 职场文书
文员岗位职责范本
2014/03/08 职场文书
目标管理责任书
2014/04/15 职场文书
新闻学专业求职信
2014/07/28 职场文书
乡镇干部先进性教育活动个人整改措施
2014/09/16 职场文书
国庆节慰问信
2015/02/15 职场文书
2015年八一建军节慰问信
2015/03/23 职场文书
2016个人先进事迹材料范文
2016/03/01 职场文书
启迪人心的励志语录:脾气永远不要大于本事
2020/01/02 职场文书
详解CSS不定宽溢出文本适配滚动
2021/05/24 HTML / CSS