pytorch实现加载保存查看checkpoint文件


Posted in Python onJuly 15, 2022

1.保存加载checkpoint文件

# 方式一:保存加载整个state_dict(推荐)
# 保存
torch.save(model.state_dict(), PATH)
# 加载
model.load_state_dict(torch.load(PATH))
# 测试时不启用 BatchNormalization 和 Dropout
model.eval()
# 方式二:保存加载整个模型
# 保存
torch.save(model, PATH)
# 加载
model = torch.load(PATH)
model.eval()
# 方式三:保存用于继续训练的checkpoint或者多个模型
# 保存
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            ...
            }, PATH)
# 加载
checkpoint = torch.load(PATH)
start_epoch=checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
# 测试时
model.eval()
# 或者训练时
model.train()

2.跨gpu和cpu

# GPU上保存,CPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device('cpu')
model.load_state_dict(torch.load(PATH, map_location=device))
# 如果是多gpu保存,需要去除关键字中的module,见第4部分
# GPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
model.load_state_dict(torch.load(PATH))
model.to(device)
# CPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
# 选择希望使用的GPU
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  
model.to(device)

3.查看checkpoint文件内容

# 打印模型的 state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

4.常见问题

多gpu

报错为KeyError: ‘unexpected key “module.conv1.weight” in state_dict’

原因:当使用多gpu时,会使用torch.nn.DataParallel,所以checkpoint中有module字样

#解决1:加载时将module去掉

# 创建一个不包含`module.`的新OrderedDict
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # 去掉 `module.`
    new_state_dict[name] = v
# 加载参数
model.load_state_dict(new_state_dict)
# 解决2:保存checkpoint时不保存module
torch.save(model.module.state_dict(), PATH)

pytorch保存和加载文件的方法,从断点处继续训练

'''本文件用于举例说明pytorch保存和加载文件的方法''' 
import torch as torch
import torchvision as tv
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import os
  
# 参数声明
batch_size = 32
epochs = 10
WORKERS = 0  # dataloder线程数
test_flag = False  # 测试标志,True时加载保存好的模型进行测试
ROOT = '/home/pxt/pytorch/cifar'  # MNIST数据集保存路径
log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # 模型保存路径
# 加载MNIST数据集
transform = tv.transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
 
train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)
 
train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)
 
 
# 构造模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 10)
 
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))
        x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
  
model = Net().cpu()
 
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
 
 
# 模型训练
def train(model, train_loader, epoch):
    model.train()
    train_loss = 0
    for i, data in enumerate(train_loader, 0):
        x, y = data
        x = x.cpu()
        y = y.cpu()
 
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        train_loss += loss
        print('正在进行第{}个epoch中的第{}次循环'.format(epoch,i))
 
    loss_mean = train_loss / (i + 1)
    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item()))
 
 
# 模型测试
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            x, y = data
            x = x.cpu()
            y = y.cpu()
 
            optimizer.zero_grad()
            y_hat = model(x)
            test_loss += criterion(y_hat, y).item()
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
        test_loss /= (i + 1)
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_data), 100. * correct / len(test_data)))
  
def main():
    # 如果test_flag=True,则加载已保存的模型并进行测试,测试以后不进行此模块以后的步骤
    if test_flag:
        # 加载保存的模型直接进行测试机验证
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        test(model, test_load)
        return
 
    # 如果有保存的模型,则加载模型,并在其基础上继续训练
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存了的模型,将从头开始训练!')
 
    for epoch in range(start_epoch+1, epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # 保存模型
        state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
        torch.save(state, log_dir)
 
if __name__ == '__main__':
    main()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Cython 三分钟入门教程
Sep 17 Python
pydev使用wxpython找不到路径的解决方法
Feb 10 Python
开源软件包和环境管理系统Anaconda的安装使用
Sep 04 Python
利用TensorFlow训练简单的二分类神经网络模型的方法
Mar 05 Python
pandas apply 函数 实现多进程的示例讲解
Apr 20 Python
Python实现迭代时使用索引的方法示例
Jun 05 Python
Django model update的多种用法介绍
Mar 28 Python
Django中URL的参数传递的实现
Aug 04 Python
基于Python中的yield表达式介绍
Nov 19 Python
selenium+python配置chrome浏览器的选项的实现
Mar 18 Python
python3实现名片管理系统(控制台版)
Nov 29 Python
python批量创建变量并赋值操作
Jun 03 Python
pytest实现多进程与多线程运行超好用的插件
Jul 15 #Python
python如何将mat文件转为png
Jul 15 #Python
python读取mat文件生成h5文件的实现
Jul 15 #Python
全网非常详细的pytest配置文件
Jul 15 #Python
Python如何加载模型并查看网络
Jul 15 #Python
Python绘制散点图之可视化神器pyecharts
Jul 07 #Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 #Python
You might like
PHP与javascript的两种交互方式
2006/10/09 PHP
iOS10推送通知开发教程
2016/09/19 PHP
PHP Socket网络操作类定义与用法示例
2017/08/30 PHP
PHP信号处理机制的操作代码讲解
2019/04/19 PHP
List Installed Software Features
2007/06/11 Javascript
JQuery 学习笔记 选择器之一
2009/07/23 Javascript
js点击事件链接的问题解决
2014/04/25 Javascript
针对初学者的jQuery入门指南
2015/08/15 Javascript
Extjs4.0 ComboBox如何实现三级联动
2016/05/11 Javascript
ionic js 复选框 与普通的 HTML 复选框到底有没区别
2016/06/06 Javascript
BootStrap下拉菜单和滚动监听插件实现代码
2016/09/26 Javascript
JS冒泡事件与事件捕获实例详解
2016/11/25 Javascript
Node.js自定义实现文件路由功能
2017/09/22 Javascript
vue form check 表单验证的实现代码
2018/12/09 Javascript
js实现倒计时器自定义时间和暂停
2019/02/25 Javascript
微信公众号服务器验证Token步骤图解
2019/12/30 Javascript
javascript操作向表格中动态加载数据
2020/08/27 Javascript
windows如何把已安装的nodejs高版本降级为低版本(图文教程)
2020/12/14 NodeJs
[02:56]《DAC最前线》之国外战队抵达上海备战亚洲邀请赛
2015/01/28 DOTA
[01:14:30]TNC vs VG 2019国际邀请赛淘汰赛 胜者组赛BO3 第二场 8.20.mp4
2019/08/22 DOTA
Python continue继续循环用法总结
2018/06/10 Python
python自动化测试之如何解析excel文件
2019/06/27 Python
TensorFlow基于MNIST数据集实现车牌识别(初步演示版)
2019/08/05 Python
Python编程学习之如何判断3个数的大小
2019/08/07 Python
基于python实现雪花算法过程详解
2019/11/16 Python
python爬虫爬取淘宝商品比价(附淘宝反爬虫机制解决小办法)
2020/12/03 Python
必须要使用游标的SQL语句有那些
2012/05/07 面试题
工程力学专业毕业生求职信
2013/10/06 职场文书
技能竞赛活动方案
2014/02/21 职场文书
办公室主任竞聘演讲稿
2014/05/15 职场文书
毕业生代领毕业材料的授权委托书
2014/09/29 职场文书
大学生暑假实习总结
2015/07/13 职场文书
2016年秋季运动会加油稿
2015/12/21 职场文书
大学军训口号大全
2015/12/24 职场文书
2016年第十九届推普周活动总结
2016/04/06 职场文书
OpenCV-Python模板匹配人眼的实例
2021/06/08 Python