pytorch实现加载保存查看checkpoint文件


Posted in Python onJuly 15, 2022

1.保存加载checkpoint文件

# 方式一:保存加载整个state_dict(推荐)
# 保存
torch.save(model.state_dict(), PATH)
# 加载
model.load_state_dict(torch.load(PATH))
# 测试时不启用 BatchNormalization 和 Dropout
model.eval()
# 方式二:保存加载整个模型
# 保存
torch.save(model, PATH)
# 加载
model = torch.load(PATH)
model.eval()
# 方式三:保存用于继续训练的checkpoint或者多个模型
# 保存
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            ...
            }, PATH)
# 加载
checkpoint = torch.load(PATH)
start_epoch=checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
# 测试时
model.eval()
# 或者训练时
model.train()

2.跨gpu和cpu

# GPU上保存,CPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device('cpu')
model.load_state_dict(torch.load(PATH, map_location=device))
# 如果是多gpu保存,需要去除关键字中的module,见第4部分
# GPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
model.load_state_dict(torch.load(PATH))
model.to(device)
# CPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
# 选择希望使用的GPU
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  
model.to(device)

3.查看checkpoint文件内容

# 打印模型的 state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

4.常见问题

多gpu

报错为KeyError: ‘unexpected key “module.conv1.weight” in state_dict’

原因:当使用多gpu时,会使用torch.nn.DataParallel,所以checkpoint中有module字样

#解决1:加载时将module去掉

# 创建一个不包含`module.`的新OrderedDict
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # 去掉 `module.`
    new_state_dict[name] = v
# 加载参数
model.load_state_dict(new_state_dict)
# 解决2:保存checkpoint时不保存module
torch.save(model.module.state_dict(), PATH)

pytorch保存和加载文件的方法,从断点处继续训练

'''本文件用于举例说明pytorch保存和加载文件的方法''' 
import torch as torch
import torchvision as tv
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import os
  
# 参数声明
batch_size = 32
epochs = 10
WORKERS = 0  # dataloder线程数
test_flag = False  # 测试标志,True时加载保存好的模型进行测试
ROOT = '/home/pxt/pytorch/cifar'  # MNIST数据集保存路径
log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # 模型保存路径
# 加载MNIST数据集
transform = tv.transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
 
train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)
 
train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)
 
 
# 构造模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 10)
 
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))
        x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
  
model = Net().cpu()
 
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
 
 
# 模型训练
def train(model, train_loader, epoch):
    model.train()
    train_loss = 0
    for i, data in enumerate(train_loader, 0):
        x, y = data
        x = x.cpu()
        y = y.cpu()
 
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        train_loss += loss
        print('正在进行第{}个epoch中的第{}次循环'.format(epoch,i))
 
    loss_mean = train_loss / (i + 1)
    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item()))
 
 
# 模型测试
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            x, y = data
            x = x.cpu()
            y = y.cpu()
 
            optimizer.zero_grad()
            y_hat = model(x)
            test_loss += criterion(y_hat, y).item()
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
        test_loss /= (i + 1)
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_data), 100. * correct / len(test_data)))
  
def main():
    # 如果test_flag=True,则加载已保存的模型并进行测试,测试以后不进行此模块以后的步骤
    if test_flag:
        # 加载保存的模型直接进行测试机验证
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        test(model, test_load)
        return
 
    # 如果有保存的模型,则加载模型,并在其基础上继续训练
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存了的模型,将从头开始训练!')
 
    for epoch in range(start_epoch+1, epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # 保存模型
        state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
        torch.save(state, log_dir)
 
if __name__ == '__main__':
    main()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的一个火车票转让信息采集器
Jul 09 Python
Python编程使用tkinter模块实现计算器软件完整代码示例
Nov 29 Python
python实现手机通讯录搜索功能
Feb 22 Python
python中实现将多个print输出合成一个数组
Apr 19 Python
PyQt5每天必学之布局管理
Apr 19 Python
Python定义二叉树及4种遍历方法实例详解
Jul 05 Python
在python中实现对list求和及求积
Nov 14 Python
Python图像处理之gif动态图的解析与合成操作详解
Dec 30 Python
如何在django中添加日志功能
Feb 06 Python
Spring http服务远程调用实现过程解析
Jun 11 Python
python中的插入排序的简单用法
Jan 19 Python
Python爬虫进阶之Beautiful Soup库详解
Apr 29 Python
pytest实现多进程与多线程运行超好用的插件
Jul 15 #Python
python如何将mat文件转为png
Jul 15 #Python
python读取mat文件生成h5文件的实现
Jul 15 #Python
全网非常详细的pytest配置文件
Jul 15 #Python
Python如何加载模型并查看网络
Jul 15 #Python
Python绘制散点图之可视化神器pyecharts
Jul 07 #Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 #Python
You might like
使用PHP生成二维码的方法汇总
2015/07/22 PHP
php实现批量删除挂马文件及批量替换页面内容完整实例
2016/07/08 PHP
[原创]PHP正则匹配中英文、数字及下划线的方法【用户名验证】
2017/08/01 PHP
PHP设计模式之工厂模式定义与用法详解
2018/04/03 PHP
微信公众平台开发教程③ PHP实现微信公众号支付功能图文详解
2019/04/10 PHP
PHP实现页面静态化深入讲解
2021/03/04 PHP
多种方法实现JS动态添加事件
2013/11/01 Javascript
jquery制作居中遮罩层效果分享
2014/02/21 Javascript
jQuery解析XML与传统JavaScript方法的差别实例分析
2015/03/05 Javascript
使用jQuery实现图片遮罩半透明坠落遮挡
2015/03/16 Javascript
EasyUi中的Combogrid 实现分页和动态搜索远程数据
2016/04/01 Javascript
Bootstrap select多选下拉框实现代码
2016/12/23 Javascript
jQuery Validate表单验证插件的基本使用方法及功能拓展
2017/01/04 Javascript
详解JavaScript按概率随机生成事件
2017/08/02 Javascript
Vue2仿淘宝实现省市区三级联动
2020/04/15 Javascript
微信小程序template模板与component组件的区别和使用详解
2019/05/22 Javascript
vue中改变滚动条样式的方法
2020/03/03 Javascript
vue开发简单上传图片功能
2020/06/30 Javascript
探究Python的Tornado框架对子域名和泛域名的支持
2015/05/02 Python
浅谈django的render函数的参数问题
2018/10/16 Python
Flask框架实现的前端RSA加密与后端Python解密功能详解
2019/08/13 Python
详解有关PyCharm安装库失败的问题的解决方法
2020/02/02 Python
Django 解决开发自定义抛出异常的问题
2020/05/21 Python
python+excel接口自动化获取token并作为请求参数进行传参操作
2020/11/10 Python
HTML5 Convas APIs方法详解
2015/04/24 HTML / CSS
工程管理造价应届生求职信
2013/11/13 职场文书
决定成败的关键——创业计划书
2014/01/24 职场文书
函授药学自我鉴定
2014/02/07 职场文书
高中军训感想800字
2014/02/23 职场文书
《音乐之都维也纳》教学反思
2014/04/16 职场文书
积极向上的团队口号
2014/06/06 职场文书
党的群众路线教育实践活动党员个人剖析材料
2014/10/08 职场文书
交警正风肃纪剖析材料
2014/10/29 职场文书
和领导吃饭祝酒词
2015/08/11 职场文书
创业者如何撰写出一份打动投资人的商业计划书?
2019/07/02 职场文书
bootstrapv4轮播图去除两侧阴影及线框的方法
2022/02/15 HTML / CSS