自己搭建resnet18网络并加载torchvision自带权重的操作


Posted in Python onMay 13, 2021

直接搭建网络必须与torchvision自带的网络的权重也就是pth文件的结构、尺寸和变量命名完全一致,否则无法加载权重文件。

此时可比较2个字典逐一加载

import torch
import torchvision
import cv2 as cv
from utils.utils import letter_box
from model.backbone import ResNet18

model1 = ResNet18(1)
model2 = torchvision.models.resnet18(progress=False)
fc = model2.fc
model2.fc = torch.nn.Linear(512, 1)
# print(model)
model_dict1 = model1.state_dict()
model_dict2 = torch.load('resnet18.pth')
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model1.load_state_dict(model_dict1)
missing, unspected = model2.load_state_dict(model_dict2)
image = cv.imread('zhn1.jpg')
image = letter_box(image, 224)
image = image[:, :, ::-1].transpose(2, 0, 1)
print('Network loading complete.')
model1.eval()
model2.eval()
with torch.no_grad():
    image = torch.tensor(image/256, dtype=torch.float32).unsqueeze(0)
    predict1 = model1(image)
    predict2 = model2(image)
print('finished')
# torch.save(model.state_dict(), 'resnet18.pth')

以上为全部程序,最终可测试原模型与加载了自带权重的自定义模型的输出是否相等。

补充:使用Pytorch搭建ResNet分类网络并基于迁移学习训练

如果stride=1,padding=1

卷积处理是不会改变特征矩阵的高和宽

使用BN层时

卷积中的参数bias置为False(有无偏置BN层的输出都相同),BN层放在conv层和relu层的中间

复习BN层:

Batch Norm 层是对每层数据归一化后再进行线性变换改善数据分布, 其中的线性变换是可学习的.

Batch Norm优点:减轻过拟合;改善梯度传播(权重不会过高或过低)容许较高的学习率,能够提高训练速度。减轻对初始化权重的强依赖,使得数据分布在激活函数的非饱和区域,一定程度上解决梯度消失问题。作为一种正则化的方式,在某种程度上减少对dropout的使用。

Batch Norm层摆放位置:在激活层(如 ReLU )之前还是之后,没有一个统一的定论。

BN层与 Dropout 合作:Batch Norm的提出使得dropout的使用减少,但是Batch Norm不能完全取代dropout,保留较小的dropout率,如0.2可能效果更佳。

为什么要先normalize再通过γ,β线性变换恢复接近原来的样子,这不是多此一举吗?

在一定条件下可以纠正原始数据的分布(方差,均值变为新值γ,β),当原始数据分布足够好时就是恒等映射,不改变分布。如果不做BN,方差和均值对前面网络的参数有复杂的关联依赖,具有复杂的非线性。在新参数 γH′ + β 中仅由 γ,β 确定,与前边网络的参数无关,因此新参数很容易通过梯度下降来学习,能够学习到较好的分布。

迁移学习导入权重和下载权重:

import torchvision.models.resnet#ctrl+鼠标左键点击即可下载权重
net = resnet34()#一开始不能设置全连接层的输出种类为自己想要的,必须先将模型参数载入,再修改全连接层
# 官方提供载入预训练模型的方法
model_weight_path = "./resnet34-pre.pth"#权重路径
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)#载入模型权重
inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)#重新确定全连接层

完整代码:

model部分:

import torch.nn as nn
import torch
class BasicBlock(nn.Module):#对应18层和34层所对应的残差结构(既要有实线残差结构功能,也要有虚线残差结构功能)
    expansion = 1#残差结构主分支上的三个卷积层是否相同,相同为1,第三层是一二层四倍则为4
    def __init__(self, in_channel, out_channel, stride=1, downsample=None):#downsample代表虚线残差结构选项
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)#得到捷径分支的输出
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out#得到残差结构的最终输出

class Bottleneck(nn.Module):#对应50层、101层和152层所对应的残差结构
    expansion = 4#第三层卷积核个数是第一层和第二层的四倍
    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out += identity
        out = self.relu(out)
        return out

class ResNet(nn.Module):#定义整个网络的框架部分
#blocks_num是残差结构的数目,是一个列表参数,block对应哪个残差模块
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64#通过第一个池化层后所得到的特征矩阵的深度
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    def _make_layer(self, block, channel, block_num, stride=1):#channel:残差结构中,第一个卷积层所使用的卷积核的个数
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:#18层和34层会直接跳过这个if语句
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))
        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion
        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))
        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        if self.include_top:#默认是true
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
        return x

def resnet34(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

def resnet101(num_classes=1000, include_top=True):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

训练部分:

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import matplotlib.pyplot as plt
import os
import torch.optim as optim
from model import resnet34, resnet101
import torchvision.models.resnet#ctrl+鼠标左键点击即可下载权重
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#和官网初始化方法保持一致
    "val": transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
image_path = data_root + "/data_set/flower_data/"  # flower data set path
train_dataset = datasets.ImageFolder(root=image_path+"train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)
# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)
batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)
validate_dataset = datasets.ImageFolder(root=image_path + "val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)
net = resnet34()#一开始不能设置全连接层的输出种类为自己想要的,必须先将模型参数载入,再修改全连接层
# 官方提供载入预训练模型的方法
model_weight_path = "./resnet34-pre.pth"#权重路径
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)#载入模型权重
inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)#重新确定全连接层
net.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)
best_acc = 0.0
save_path = './resNet34.pth'
for epoch in range(3):
    # train
    net.train()#控制BN层状态
    running_loss = 0.0
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        # print train process
        rate = (step+1)/len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")
    print()
    # validate
    net.eval()#控制BN层状态
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))  # eval model only have last output layer
            # loss = loss_function(outputs, test_labels)
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))
print('Finished Training')

预测部分:

import torch
from model import resnet34
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])#采用和训练方法一样的标准化处理
# load image
img = Image.open("../aa.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)
# read class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)
# create model
model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path, map_location=device))#载入训练好的模型参数
model.eval()#使用eval()模式
with torch.no_grad():#不跟踪损失梯度
    # predict class
    output = torch.squeeze(model(img))#压缩batch维度
    predict = torch.softmax(output, dim=0)#通过softmax得到概率分布
    predict_cla = torch.argmax(predict).numpy()#寻找最大值所对应的索引
print(class_indict[str(predict_cla)], predict[predict_cla].numpy())#打印类别信息和概率
plt.show()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python数据结构之图的实现方法
Jul 08 Python
简单谈谈Python中的闭包
Nov 30 Python
Python实现mysql数据库更新表数据接口的功能
Nov 19 Python
Python高级用法总结
May 26 Python
python配置grpc环境
Jan 01 Python
django主动抛出403异常的方法详解
Jan 04 Python
深入浅析python3中的unicode和bytes问题
Jul 03 Python
python实现飞船游戏的纵向移动
Apr 24 Python
Python如何把十进制数转换成ip地址
May 25 Python
Keras 快速解决OOM超内存的问题
Jun 11 Python
python使用matplotlib:subplot绘制多个子图的示例
Sep 24 Python
python实现快速文件格式批量转换的方法
Oct 16 Python
如何使用flask将模型部署为服务
May 13 #Python
教你用python控制安卓手机
Python数据分析入门之数据读取与存储
May 13 #Python
python执行js代码的方法
pytorch加载预训练模型与自己模型不匹配的解决方案
May 13 #Python
Python数据分析入门之教你怎么搭建环境
Pytorch 统计模型参数量的操作 param.numel()
May 13 #Python
You might like
为什么《星际争霸》是测试人工智能的理想战场
2019/12/03 星际争霸
php文档工具PHP Documentor安装与使用方法
2016/01/25 PHP
PHP将二维数组某一个字段相同的数组合并起来的方法
2016/02/26 PHP
php实现URL加密解密的方法
2016/11/17 PHP
laravel-admin的图片删除实例
2019/09/30 PHP
js判断变量是否未定义的代码
2020/03/28 Javascript
使用jquery实现图文切换效果另加特效
2013/01/20 Javascript
跟我学Node.js(四)---Node.js的模块载入方式与机制
2014/06/04 Javascript
js实现的牛顿摆效果
2015/03/31 Javascript
jQuery 1.9.1源码分析系列(十三)之位置大小操作
2015/12/02 Javascript
EasyUi中的Combogrid 实现分页和动态搜索远程数据
2016/04/01 Javascript
jQuery添加options点击事件并传值实例代码
2016/05/18 Javascript
利用python分析access日志的方法
2016/10/26 Javascript
微信小程序侧边栏滑动特效(左右滑动)
2017/01/23 Javascript
微信小程序 动态绑定事件并实现事件修改样式
2017/04/13 Javascript
vue脚手架中配置Sass的方法
2018/01/04 Javascript
Nodejs异步流程框架async的方法
2019/06/07 NodeJs
layui表单验证select下拉框实现验证的方法
2019/09/05 Javascript
pycharm 使用心得(八)如何调用另一文件中的函数
2014/06/06 Python
Linux下编译安装MySQL-Python教程
2015/02/02 Python
利用Python实现命令行版的火车票查看器
2016/08/05 Python
详解python中的线程
2018/02/10 Python
python 并发编程 多路复用IO模型详解
2019/08/20 Python
移动端开发HTML5页面点击按钮后出现闪烁或黑色背景的解决办法
2018/09/19 HTML / CSS
C语言基础笔试题
2013/04/27 面试题
linux面试题参考答案(5)
2016/11/05 面试题
办公室前台岗位职责范本
2013/12/10 职场文书
医院实习介绍信
2014/01/12 职场文书
优秀班干部事迹材料
2014/01/26 职场文书
森林病虫害防治方案
2014/06/02 职场文书
设计专业自荐信
2014/06/19 职场文书
庆七一活动总结
2014/08/27 职场文书
高中生毕业评语
2014/12/30 职场文书
pytorch fine-tune 预训练的模型操作
2021/06/03 Python
25张裸眼3D图片,带你重温童年的记忆,感受3D的魅力
2022/02/06 杂记
如何使用注解方式实现 Redis 分布式锁
2022/07/23 Redis