自己搭建resnet18网络并加载torchvision自带权重的操作


Posted in Python onMay 13, 2021

直接搭建网络必须与torchvision自带的网络的权重也就是pth文件的结构、尺寸和变量命名完全一致,否则无法加载权重文件。

此时可比较2个字典逐一加载

import torch
import torchvision
import cv2 as cv
from utils.utils import letter_box
from model.backbone import ResNet18

model1 = ResNet18(1)
model2 = torchvision.models.resnet18(progress=False)
fc = model2.fc
model2.fc = torch.nn.Linear(512, 1)
# print(model)
model_dict1 = model1.state_dict()
model_dict2 = torch.load('resnet18.pth')
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model1.load_state_dict(model_dict1)
missing, unspected = model2.load_state_dict(model_dict2)
image = cv.imread('zhn1.jpg')
image = letter_box(image, 224)
image = image[:, :, ::-1].transpose(2, 0, 1)
print('Network loading complete.')
model1.eval()
model2.eval()
with torch.no_grad():
    image = torch.tensor(image/256, dtype=torch.float32).unsqueeze(0)
    predict1 = model1(image)
    predict2 = model2(image)
print('finished')
# torch.save(model.state_dict(), 'resnet18.pth')

以上为全部程序,最终可测试原模型与加载了自带权重的自定义模型的输出是否相等。

补充:使用Pytorch搭建ResNet分类网络并基于迁移学习训练

如果stride=1,padding=1

卷积处理是不会改变特征矩阵的高和宽

使用BN层时

卷积中的参数bias置为False(有无偏置BN层的输出都相同),BN层放在conv层和relu层的中间

复习BN层:

Batch Norm 层是对每层数据归一化后再进行线性变换改善数据分布, 其中的线性变换是可学习的.

Batch Norm优点:减轻过拟合;改善梯度传播(权重不会过高或过低)容许较高的学习率,能够提高训练速度。减轻对初始化权重的强依赖,使得数据分布在激活函数的非饱和区域,一定程度上解决梯度消失问题。作为一种正则化的方式,在某种程度上减少对dropout的使用。

Batch Norm层摆放位置:在激活层(如 ReLU )之前还是之后,没有一个统一的定论。

BN层与 Dropout 合作:Batch Norm的提出使得dropout的使用减少,但是Batch Norm不能完全取代dropout,保留较小的dropout率,如0.2可能效果更佳。

为什么要先normalize再通过γ,β线性变换恢复接近原来的样子,这不是多此一举吗?

在一定条件下可以纠正原始数据的分布(方差,均值变为新值γ,β),当原始数据分布足够好时就是恒等映射,不改变分布。如果不做BN,方差和均值对前面网络的参数有复杂的关联依赖,具有复杂的非线性。在新参数 γH′ + β 中仅由 γ,β 确定,与前边网络的参数无关,因此新参数很容易通过梯度下降来学习,能够学习到较好的分布。

迁移学习导入权重和下载权重:

import torchvision.models.resnet#ctrl+鼠标左键点击即可下载权重
net = resnet34()#一开始不能设置全连接层的输出种类为自己想要的,必须先将模型参数载入,再修改全连接层
# 官方提供载入预训练模型的方法
model_weight_path = "./resnet34-pre.pth"#权重路径
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)#载入模型权重
inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)#重新确定全连接层

完整代码:

model部分:

import torch.nn as nn
import torch
class BasicBlock(nn.Module):#对应18层和34层所对应的残差结构(既要有实线残差结构功能,也要有虚线残差结构功能)
    expansion = 1#残差结构主分支上的三个卷积层是否相同,相同为1,第三层是一二层四倍则为4
    def __init__(self, in_channel, out_channel, stride=1, downsample=None):#downsample代表虚线残差结构选项
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)#得到捷径分支的输出
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out#得到残差结构的最终输出

class Bottleneck(nn.Module):#对应50层、101层和152层所对应的残差结构
    expansion = 4#第三层卷积核个数是第一层和第二层的四倍
    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out += identity
        out = self.relu(out)
        return out

class ResNet(nn.Module):#定义整个网络的框架部分
#blocks_num是残差结构的数目,是一个列表参数,block对应哪个残差模块
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64#通过第一个池化层后所得到的特征矩阵的深度
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    def _make_layer(self, block, channel, block_num, stride=1):#channel:残差结构中,第一个卷积层所使用的卷积核的个数
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:#18层和34层会直接跳过这个if语句
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))
        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion
        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))
        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        if self.include_top:#默认是true
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
        return x

def resnet34(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

def resnet101(num_classes=1000, include_top=True):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

训练部分:

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import matplotlib.pyplot as plt
import os
import torch.optim as optim
from model import resnet34, resnet101
import torchvision.models.resnet#ctrl+鼠标左键点击即可下载权重
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#和官网初始化方法保持一致
    "val": transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
image_path = data_root + "/data_set/flower_data/"  # flower data set path
train_dataset = datasets.ImageFolder(root=image_path+"train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)
# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)
batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)
validate_dataset = datasets.ImageFolder(root=image_path + "val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)
net = resnet34()#一开始不能设置全连接层的输出种类为自己想要的,必须先将模型参数载入,再修改全连接层
# 官方提供载入预训练模型的方法
model_weight_path = "./resnet34-pre.pth"#权重路径
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)#载入模型权重
inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)#重新确定全连接层
net.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)
best_acc = 0.0
save_path = './resNet34.pth'
for epoch in range(3):
    # train
    net.train()#控制BN层状态
    running_loss = 0.0
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        # print train process
        rate = (step+1)/len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")
    print()
    # validate
    net.eval()#控制BN层状态
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))  # eval model only have last output layer
            # loss = loss_function(outputs, test_labels)
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))
print('Finished Training')

预测部分:

import torch
from model import resnet34
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])#采用和训练方法一样的标准化处理
# load image
img = Image.open("../aa.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)
# read class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)
# create model
model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path, map_location=device))#载入训练好的模型参数
model.eval()#使用eval()模式
with torch.no_grad():#不跟踪损失梯度
    # predict class
    output = torch.squeeze(model(img))#压缩batch维度
    predict = torch.softmax(output, dim=0)#通过softmax得到概率分布
    predict_cla = torch.argmax(predict).numpy()#寻找最大值所对应的索引
print(class_indict[str(predict_cla)], predict[predict_cla].numpy())#打印类别信息和概率
plt.show()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python encode和decode的妙用
Sep 02 Python
python在windows下实现备份程序实例
Jul 04 Python
优化Python代码使其加快作用域内的查找
Mar 30 Python
Django URL传递参数的方法总结
Aug 28 Python
Python正则表达式教程之三:贪婪/非贪婪特性
Mar 02 Python
Python3 适合初学者学习的银行账户登录系统实例
Aug 08 Python
python实现远程通过网络邮件控制计算机重启或关机
Feb 22 Python
利用python循环创建多个文件的方法
Oct 25 Python
Python txt文件加入字典并查询的方法
Jan 15 Python
详解Python Matplot中文显示完美解决方案
Mar 07 Python
python算法与数据结构之单链表的实现代码
Jun 27 Python
解决Django删除migrations文件夹中的文件后出现的异常问题
Aug 31 Python
如何使用flask将模型部署为服务
May 13 #Python
教你用python控制安卓手机
Python数据分析入门之数据读取与存储
May 13 #Python
python执行js代码的方法
pytorch加载预训练模型与自己模型不匹配的解决方案
May 13 #Python
Python数据分析入门之教你怎么搭建环境
Pytorch 统计模型参数量的操作 param.numel()
May 13 #Python
You might like
一个简单实现多条件查询的例子
2006/10/09 PHP
php动态实现表格跨行跨列实现代码
2012/11/06 PHP
PHP中通过HTTP_USER_AGENT判断是否为手机移动终端的函数代码
2013/02/14 PHP
php使用pack处理二进制文件的方法
2014/07/03 PHP
PHP基于单例模式编写PDO类的方法
2016/09/13 PHP
php实现连接access数据库并转txt写入的方法
2017/02/08 PHP
Laravel框架实现利用中间件进行操作日志记录功能
2018/06/06 PHP
php ZipArchive实现多文件打包下载实例
2019/10/31 PHP
js跟随滚动条滚动浮动代码
2009/12/31 Javascript
jquery之empty()与remove()区别说明
2010/09/10 Javascript
eval的两组性能测试数据
2012/08/17 Javascript
javascript中方便增删改cookie的一个类
2012/10/11 Javascript
Jquery的each里用return true或false代替break或continue
2014/05/21 Javascript
js 判断图片是否加载完以及实现图片的预下载
2014/08/14 Javascript
Node.js实现在目录中查找某个字符串及所在文件
2014/09/03 Javascript
详解webpack es6 to es5支持配置
2017/05/04 Javascript
JavaScript实现无穷滚动加载数据
2017/05/06 Javascript
ES6 Generator基本使用方法示例
2020/06/06 Javascript
vue tab滚动到一定高度,固定在顶部,点击tab切换不同的内容操作
2020/07/22 Javascript
使用vue引入maptalks地图及聚合效果的实现
2020/08/10 Javascript
python解析中国天气网的天气数据
2014/03/21 Python
如何将python中的List转化成dictionary
2016/08/15 Python
Python标准库之collections包的使用教程
2017/04/27 Python
Python tornado队列示例-一个并发web爬虫代码分享
2018/01/09 Python
Python实现带参数与不带参数的多重继承示例
2018/01/30 Python
python二维码操作:对QRCode和MyQR入门详解
2019/06/24 Python
Python中 CSV格式清洗与转换的实例代码
2019/08/29 Python
解决python中的幂函数、指数函数问题
2019/11/25 Python
Kaufmann Mercantile官网:家居装饰、配件、户外及更多
2018/09/28 全球购物
Hello Molly美国:女性时尚在线
2019/08/26 全球购物
信息系统专业个人求职信范文
2013/12/07 职场文书
工地资料员岗位职责
2013/12/31 职场文书
2014年元旦促销活动方案
2014/02/22 职场文书
《吃水不忘挖井人》教学反思
2014/04/15 职场文书
常务副总经理任命书
2014/06/05 职场文书
教师反邪教心得体会
2016/01/15 职场文书