pytorch实现加载保存查看checkpoint文件


Posted in Python onJuly 15, 2022

1.保存加载checkpoint文件

# 方式一:保存加载整个state_dict(推荐)
# 保存
torch.save(model.state_dict(), PATH)
# 加载
model.load_state_dict(torch.load(PATH))
# 测试时不启用 BatchNormalization 和 Dropout
model.eval()
# 方式二:保存加载整个模型
# 保存
torch.save(model, PATH)
# 加载
model = torch.load(PATH)
model.eval()
# 方式三:保存用于继续训练的checkpoint或者多个模型
# 保存
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            ...
            }, PATH)
# 加载
checkpoint = torch.load(PATH)
start_epoch=checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
# 测试时
model.eval()
# 或者训练时
model.train()

2.跨gpu和cpu

# GPU上保存,CPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device('cpu')
model.load_state_dict(torch.load(PATH, map_location=device))
# 如果是多gpu保存,需要去除关键字中的module,见第4部分
# GPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
model.load_state_dict(torch.load(PATH))
model.to(device)
# CPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
# 选择希望使用的GPU
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  
model.to(device)

3.查看checkpoint文件内容

# 打印模型的 state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

4.常见问题

多gpu

报错为KeyError: ‘unexpected key “module.conv1.weight” in state_dict’

原因:当使用多gpu时,会使用torch.nn.DataParallel,所以checkpoint中有module字样

#解决1:加载时将module去掉

# 创建一个不包含`module.`的新OrderedDict
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # 去掉 `module.`
    new_state_dict[name] = v
# 加载参数
model.load_state_dict(new_state_dict)
# 解决2:保存checkpoint时不保存module
torch.save(model.module.state_dict(), PATH)

pytorch保存和加载文件的方法,从断点处继续训练

'''本文件用于举例说明pytorch保存和加载文件的方法''' 
import torch as torch
import torchvision as tv
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import os
  
# 参数声明
batch_size = 32
epochs = 10
WORKERS = 0  # dataloder线程数
test_flag = False  # 测试标志,True时加载保存好的模型进行测试
ROOT = '/home/pxt/pytorch/cifar'  # MNIST数据集保存路径
log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # 模型保存路径
# 加载MNIST数据集
transform = tv.transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
 
train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)
 
train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)
 
 
# 构造模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 10)
 
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))
        x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
  
model = Net().cpu()
 
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
 
 
# 模型训练
def train(model, train_loader, epoch):
    model.train()
    train_loss = 0
    for i, data in enumerate(train_loader, 0):
        x, y = data
        x = x.cpu()
        y = y.cpu()
 
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        train_loss += loss
        print('正在进行第{}个epoch中的第{}次循环'.format(epoch,i))
 
    loss_mean = train_loss / (i + 1)
    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item()))
 
 
# 模型测试
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            x, y = data
            x = x.cpu()
            y = y.cpu()
 
            optimizer.zero_grad()
            y_hat = model(x)
            test_loss += criterion(y_hat, y).item()
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
        test_loss /= (i + 1)
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_data), 100. * correct / len(test_data)))
  
def main():
    # 如果test_flag=True,则加载已保存的模型并进行测试,测试以后不进行此模块以后的步骤
    if test_flag:
        # 加载保存的模型直接进行测试机验证
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        test(model, test_load)
        return
 
    # 如果有保存的模型,则加载模型,并在其基础上继续训练
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存了的模型,将从头开始训练!')
 
    for epoch in range(start_epoch+1, epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # 保存模型
        state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
        torch.save(state, log_dir)
 
if __name__ == '__main__':
    main()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python json模块使用实例
Apr 11 Python
Python批量转换文件编码格式
May 17 Python
Python实现监控程序执行时间并将其写入日志的方法
Jun 30 Python
详解Python pygame安装过程笔记
Jun 05 Python
python里使用正则表达式的组嵌套实例详解
Oct 24 Python
两个元祖T1=('a', 'b'),T2=('c', 'd')使用匿名函数将其转变成[{'a': 'c'},{'b': 'd'}]的几种方法
Mar 05 Python
python调用webservice接口的实现
Jul 12 Python
在自动化中用python实现键盘操作的方法详解
Jul 19 Python
Django在admin后台集成TinyMCE富文本编辑器的例子
Aug 09 Python
python程序中的线程操作 concurrent模块使用详解
Sep 23 Python
Python callable内置函数原理解析
Mar 05 Python
python 用pandas实现数据透视表功能
Dec 21 Python
pytest实现多进程与多线程运行超好用的插件
Jul 15 #Python
python如何将mat文件转为png
Jul 15 #Python
python读取mat文件生成h5文件的实现
Jul 15 #Python
全网非常详细的pytest配置文件
Jul 15 #Python
Python如何加载模型并查看网络
Jul 15 #Python
Python绘制散点图之可视化神器pyecharts
Jul 07 #Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 #Python
You might like
php 安全过滤函数代码
2011/05/07 PHP
ThinkPHP表单自动提交验证实例教程
2014/07/18 PHP
php curl 模拟登录并获取数据实例详解
2016/12/22 PHP
php从数据库中获取数据用ajax传送到前台的方法
2018/08/20 PHP
js 学习笔记(三)
2009/12/29 Javascript
浅谈jQuery中对象遍历.eq().first().last().slice()方法
2014/11/26 Javascript
JS实现超简单的仿QQ折叠菜单效果
2015/09/21 Javascript
浅析jQuery移动开发中内联按钮和分组按钮的编写
2015/12/04 Javascript
jQuery模拟360浏览器切屏效果幻灯片(附demo源码下载)
2016/01/29 Javascript
JS 实现倒计时数字时钟效果【附实例代码】
2016/03/30 Javascript
JavaScript中push(),join() 函数 实例详解
2016/09/06 Javascript
详解用原生JavaScript实现jQuery的某些简单功能
2016/12/19 Javascript
jquery封装插件时匿名函数形参和实参的写法解释
2017/02/14 Javascript
详解angular中的作用域及继承
2017/05/31 Javascript
微信小程序实现上传图片功能
2018/05/28 Javascript
8 个有用的JS技巧(推荐)
2019/07/03 Javascript
基于JavaScript实现十五拼图代码实例
2020/04/26 Javascript
[03:02]安得倚天剑,跨海斩长鲸——中国军团出征DOTA2国际邀请赛
2018/08/14 DOTA
Python实现压缩和解压缩ZIP文件的方法分析
2017/09/28 Python
用python处理图片之打开\显示\保存图像的方法
2018/05/04 Python
python使用__slots__让你的代码更加节省内存
2018/09/05 Python
Python找出微信上删除你好友的人脚本写法
2018/11/01 Python
在python中使用xlrd获取合并单元格的方法
2018/12/26 Python
Tensorflow中tf.ConfigProto()的用法详解
2020/02/06 Python
对Matlab中共轭、转置和共轭装置的区别说明
2020/05/11 Python
PyCharm+PyQt5+QtDesigner配置详解
2020/08/12 Python
购买一个高级域名:BuyDomains
2018/03/11 全球购物
Spartoo瑞典:鞋子、包包和衣服
2018/09/15 全球购物
德国最新街头服饰网上商店:BODYCHECK
2019/09/15 全球购物
白俄罗斯女装和针织品网上商店:Presli.by
2019/10/13 全球购物
全民健身日活动方案
2014/01/29 职场文书
关于青春的演讲稿800字
2014/08/22 职场文书
关于成立领导小组的通知
2015/04/23 职场文书
甲午大海战观后感
2015/06/02 职场文书
Keras在mnist上的CNN实践,并且自定义loss函数曲线图操作
2021/05/25 Python
Oracle中日期的使用方法实例
2022/07/07 Oracle