pytorch实现加载保存查看checkpoint文件


Posted in Python onJuly 15, 2022

1.保存加载checkpoint文件

# 方式一:保存加载整个state_dict(推荐)
# 保存
torch.save(model.state_dict(), PATH)
# 加载
model.load_state_dict(torch.load(PATH))
# 测试时不启用 BatchNormalization 和 Dropout
model.eval()
# 方式二:保存加载整个模型
# 保存
torch.save(model, PATH)
# 加载
model = torch.load(PATH)
model.eval()
# 方式三:保存用于继续训练的checkpoint或者多个模型
# 保存
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            ...
            }, PATH)
# 加载
checkpoint = torch.load(PATH)
start_epoch=checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
# 测试时
model.eval()
# 或者训练时
model.train()

2.跨gpu和cpu

# GPU上保存,CPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device('cpu')
model.load_state_dict(torch.load(PATH, map_location=device))
# 如果是多gpu保存,需要去除关键字中的module,见第4部分
# GPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
model.load_state_dict(torch.load(PATH))
model.to(device)
# CPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
# 选择希望使用的GPU
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  
model.to(device)

3.查看checkpoint文件内容

# 打印模型的 state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

4.常见问题

多gpu

报错为KeyError: ‘unexpected key “module.conv1.weight” in state_dict’

原因:当使用多gpu时,会使用torch.nn.DataParallel,所以checkpoint中有module字样

#解决1:加载时将module去掉

# 创建一个不包含`module.`的新OrderedDict
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # 去掉 `module.`
    new_state_dict[name] = v
# 加载参数
model.load_state_dict(new_state_dict)
# 解决2:保存checkpoint时不保存module
torch.save(model.module.state_dict(), PATH)

pytorch保存和加载文件的方法,从断点处继续训练

'''本文件用于举例说明pytorch保存和加载文件的方法''' 
import torch as torch
import torchvision as tv
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import os
  
# 参数声明
batch_size = 32
epochs = 10
WORKERS = 0  # dataloder线程数
test_flag = False  # 测试标志,True时加载保存好的模型进行测试
ROOT = '/home/pxt/pytorch/cifar'  # MNIST数据集保存路径
log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # 模型保存路径
# 加载MNIST数据集
transform = tv.transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
 
train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)
 
train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)
 
 
# 构造模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 10)
 
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))
        x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
  
model = Net().cpu()
 
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
 
 
# 模型训练
def train(model, train_loader, epoch):
    model.train()
    train_loss = 0
    for i, data in enumerate(train_loader, 0):
        x, y = data
        x = x.cpu()
        y = y.cpu()
 
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        train_loss += loss
        print('正在进行第{}个epoch中的第{}次循环'.format(epoch,i))
 
    loss_mean = train_loss / (i + 1)
    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item()))
 
 
# 模型测试
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            x, y = data
            x = x.cpu()
            y = y.cpu()
 
            optimizer.zero_grad()
            y_hat = model(x)
            test_loss += criterion(y_hat, y).item()
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
        test_loss /= (i + 1)
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_data), 100. * correct / len(test_data)))
  
def main():
    # 如果test_flag=True,则加载已保存的模型并进行测试,测试以后不进行此模块以后的步骤
    if test_flag:
        # 加载保存的模型直接进行测试机验证
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        test(model, test_load)
        return
 
    # 如果有保存的模型,则加载模型,并在其基础上继续训练
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存了的模型,将从头开始训练!')
 
    for epoch in range(start_epoch+1, epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # 保存模型
        state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
        torch.save(state, log_dir)
 
if __name__ == '__main__':
    main()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现2048小游戏
Mar 30 Python
python从网络读取图片并直接进行处理的方法
May 22 Python
Python 搭建Web站点之Web服务器与Web框架
Nov 06 Python
json跨域调用python的方法详解
Jan 11 Python
python3 中文乱码与默认编码格式设定方法
Oct 31 Python
对python中的乘法dot和对应分量相乘multiply详解
Nov 14 Python
在python下读取并展示raw格式的图片实例
Jan 24 Python
Python向excel中写入数据的方法
May 05 Python
Django 配置多站点多域名的实现步骤
May 17 Python
Python3 sys.argv[ ]用法详解
Oct 24 Python
Python实现删除某列中含有空值的行的示例代码
Jul 20 Python
Python使用plt.boxplot()函数绘制箱图、常用方法以及含义详解
Aug 14 Python
pytest实现多进程与多线程运行超好用的插件
Jul 15 #Python
python如何将mat文件转为png
Jul 15 #Python
python读取mat文件生成h5文件的实现
Jul 15 #Python
全网非常详细的pytest配置文件
Jul 15 #Python
Python如何加载模型并查看网络
Jul 15 #Python
Python绘制散点图之可视化神器pyecharts
Jul 07 #Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 #Python
You might like
PHP的类 功能齐全的发送邮件类
2006/10/09 PHP
use jscript Create a SQL Server database
2007/06/16 Javascript
javascript 动态修改样式和层叠样式表代码
2010/04/27 Javascript
jQuery针对各类元素操作基础教程
2014/08/29 Javascript
JS实现网页背景颜色与select框中颜色同时变化的方法
2015/02/27 Javascript
javascript制作幻灯片(360度全景图片)
2015/07/28 Javascript
详解JavaScript正则表达式之分组匹配及反向引用
2016/03/09 Javascript
jQuery grep()方法详解及实例代码
2016/10/30 Javascript
详解angularjs利用ui-route异步加载组件
2017/05/21 Javascript
JS实现下拉菜单列表与登录注册弹窗效果
2017/08/10 Javascript
vue实现留言板todolist功能
2017/08/16 Javascript
在vue项目创建的后初始化首次使用stylus安装方法分享
2018/01/25 Javascript
Vue表单demo v-model双向绑定问题
2018/06/29 Javascript
原生JS forEach()和map()遍历的区别、兼容写法及jQuery $.each、$.map遍历操作
2019/02/27 jQuery
通过实践编写优雅的JavaScript代码
2019/05/30 Javascript
浅谈Vue为什么不能检测数组变动
2019/10/14 Javascript
JQuery实现ul中添加LI和删除指定的Li元素功能完整示例
2019/10/16 jQuery
[51:06]DOTA2-DPC中国联赛 正赛 Elephant vs Aster BO3 第二场 1月26日
2021/03/11 DOTA
对于Python异常处理慎用“except:pass”建议
2015/04/02 Python
利用Python脚本生成sitemap.xml的实现方法
2017/01/31 Python
Python实现比较扑克牌大小程序代码示例
2017/12/06 Python
html5如何及时更新缓存文件(js、css或图片)
2013/06/24 HTML / CSS
html5模拟平抛运动(模拟小球平抛运动过程)
2013/07/25 HTML / CSS
玩具反斗城葡萄牙官方商城:Toys"R"Us葡萄牙
2016/10/21 全球购物
欧舒丹比利时官网:L’OCCITANE比利时
2017/04/25 全球购物
Myprotein台湾官方网站:全球领先的运动营养品牌
2018/12/10 全球购物
时尚孕妇装:HATCH Collection
2019/09/24 全球购物
Tuckernuck官网:经典的美国品质服装、鞋子和配饰
2021/01/11 全球购物
在Ajax应用中信息是如何在浏览器和服务器之间传递的
2016/05/31 面试题
单位实习证明怎么写
2014/01/17 职场文书
股东合作协议书
2014/09/12 职场文书
2014年大学生党员自我评议
2014/09/22 职场文书
六查六看六改心得体会
2014/10/14 职场文书
2015年话务员工作总结
2015/04/29 职场文书
Python基础之元类详解
2021/04/29 Python
mysql联合索引的使用规则
2021/06/23 MySQL