深入理解Pytorch微调torchvision模型


Posted in Python onNovember 11, 2021

一、简介

在本小节,深入探讨如何对torchvision进行微调和特征提取。所有模型都已经预先在1000类的magenet数据集上训练完成。 本节将深入介绍如何使用几个现代的CNN架构,并将直观展示如何微调任意的PyTorch模型。
本节将执行两种类型的迁移学习:

  • 微调:从预训练模型开始,更新我们新任务的所有模型参数,实质上是重新训练整个模型。
  • 特征提取:从预训练模型开始,仅更新从中导出预测的最终图层权重。它被称为特征提取,因为我们使用预训练的CNN作为固定 的特征提取器,并且仅改变输出层。

通常这两种迁移学习方法都会遵循一下步骤:

  • 初始化预训练模型
  • 重组最后一层,使其具有与新数据集类别数相同的输出数
  • 为优化算法定义想要的训练期间更新的参数
  • 运行训练步骤

二、导入相关包

from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision 
from torchvision import datasets,models,transforms
import matplotlib.pyplot as plt
import time
import os
import copy
print("Pytorch version:",torch.__version__)
print("torchvision version:",torchvision.__version__)

运行结果

深入理解Pytorch微调torchvision模型

三、数据输入

数据集——>我在这里

链接:https://pan.baidu.com/s/1G3yRfKTQf9sIq1iCSoymWQ
提取码:1234

#%%输入
data_dir="D:\Python\Pytorch\data\hymenoptera_data"
# 从[resnet,alexnet,vgg,squeezenet,desenet,inception]
model_name='squeezenet'
# 数据集中类别数量
num_classes=2
# 训练的批量大小
batch_size=8
# 训练epoch数
num_epochs=15
# 用于特征提取的标志。为FALSE,微调整个模型,为TRUE只更新图层参数
feature_extract=True

四、辅助函数

1、模型训练和验证

  • train_model函数处理给定模型的训练和验证。作为输入,它需要PyTorch模型、数据加载器字典、损失函数、优化器、用于训练和验 证epoch数,以及当模型是初始模型时的布尔标志。
  • is_inception标志用于容纳 Inception v3 模型,因为该体系结构使用辅助输出, 并且整体模型损失涉及辅助输出和最终输出,如此处所述。 这个函数训练指定数量的epoch,并且在每个epoch之后运行完整的验证步骤。它还跟踪最佳性能的模型(从验证准确率方面),并在训练 结束时返回性能最好的模型。在每个epoch之后,打印训练和验证正确率。
#%%模型训练和验证
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False):
    since=time.time()
    val_acc_history=[]
    best_model_wts=copy.deepcopy(model.state_dict())
    best_acc=0.0
    for epoch in range(num_epochs):
        print('Epoch{}/{}'.format(epoch, num_epochs-1))
        print('-'*10)
        # 每个epoch都有一个训练和验证阶段
        for phase in['train','val']:
            if phase=='train':
                model.train()
            else:
                model.eval()
                
            running_loss=0.0
            running_corrects=0
            # 迭代数据
            for inputs,labels in dataloaders[phase]:
                inputs=inputs.to(device)
                labels=labels.to(device)
                # 梯度置零
                optimizer.zero_grad()
                # 向前传播
                with torch.set_grad_enabled(phase=='train'):
                    # 获取模型输出并计算损失,开始的特殊情况在训练中他有一个辅助输出
                    # 在训练模式下,通过将最终输出和辅助输出相加来计算损耗,在测试中值考虑最终输出
                    if is_inception and phase=='train':
                        outputs,aux_outputs=model(inputs)
                        loss1=criterion(outputs,labels)
                        loss2=criterion(aux_outputs,labels)
                        loss=loss1+0.4*loss2
                    else:
                        outputs=model(inputs)
                        loss=criterion(outputs,labels)
                        
                    _,preds=torch.max(outputs,1)
                    
                    if phase=='train':
                        loss.backward()
                        optimizer.step()
                        
                # 添加
                running_loss+=loss.item()*inputs.size(0)
                running_corrects+=torch.sum(preds==labels.data)
                
            epoch_loss=running_loss/len(dataloaders[phase].dataset)
            epoch_acc=running_corrects.double()/len(dataloaders[phase].dataset)
            
            print('{}loss : {:.4f} acc:{:.4f}'.format(phase, epoch_loss,epoch_acc))
            
            if phase=='train' and epoch_acc>best_acc:
                best_acc=epoch_acc
                best_model_wts=copy.deepcopy(model.state_dict())
            if phase=='val':
                val_acc_history.append(epoch_acc)
            
        print()

    time_elapsed=time.time()-since
    print('training complete in {:.0f}s'.format(time_elapsed//60, time_elapsed%60))
    print('best val acc:{:.4f}'.format(best_acc))
    
    model.load_state_dict(best_model_wts)
    return model,val_acc_history

2、设置模型参数的'.requires_grad属性'

当我们进行特征提取时,此辅助函数将模型中参数的 .requires_grad 属性设置为False。
默认情况下,当我们加载一个预训练模型时,所有参数都是 .requires_grad = True,如果我们从头开始训练或微调,这种设置就没问题。
但是,如果我们要运行特征提取并且只想为新初始化的层计算梯度,那么我们希望所有其他参数不需要梯度变化。

#%%设置模型参数的.require——grad属性
def set_parameter_requires_grad(model,feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.require_grad=False

靓仔今天先去跑步了,再不跑来不及了,先更这么多,后续明天继续~(感谢有人没有催更!感谢监督!希望继续监督!)

以上就是深入理解Pytorch微调torchvision模型的详细内容,更多关于Pytorch torchvision模型的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python编写百度贴吧的简单爬虫
Apr 02 Python
在Python中操作字典之clear()方法的使用
May 21 Python
使用11行Python代码盗取了室友的U盘内容
Oct 23 Python
Django数据库连接丢失问题的解决方法
Dec 29 Python
Python OpenCV 使用滑动条来调整函数参数的方法
Jul 08 Python
Python提取PDF内容的方法(文本、图像、线条等)
Sep 25 Python
python 发送json数据操作实例分析
Oct 15 Python
基于Python解密仿射密码
Oct 21 Python
python3 常见解密加密算法实例分析【base64、MD5等】
Dec 19 Python
window环境pip切换国内源(pip安装异常缓慢的问题)
Dec 31 Python
Python操作注册表详细步骤介绍
Feb 05 Python
Python基础之操作MySQL数据库
May 06 Python
Python 中 Shutil 模块详情
Nov 11 #Python
django 认证类配置实现
Nov 11 #Python
Python Pandas数据分析之iloc和loc的用法详解
据Python爬虫不靠谱预测可知今年双十一销售额将超过6000亿元
Python 详解通过Scrapy框架实现爬取百度新冠疫情数据流程
python中tkinter复选框使用操作
Nov 11 #Python
Python中的变量与常量
Nov 11 #Python
You might like
PHP编程实现阳历转换为阴历的方法实例
2017/08/08 PHP
PHP设计模式之工厂方法设计模式实例分析
2018/04/25 PHP
PHP xpath()函数讲解
2019/02/11 PHP
js传值 判断
2006/10/26 Javascript
PNG背景在不同浏览器下的应用
2009/06/22 Javascript
Json和Jsonp理论实例代码详解
2013/11/15 Javascript
通过js获取上传的图片信息(临时保存路径,名称,大小)然后通过ajax传递给后端的方法
2015/10/01 Javascript
javascript鼠标右键菜单自定义效果
2020/12/08 Javascript
浅谈JS继承_借用构造函数 & 组合式继承
2016/08/16 Javascript
easyui-edatagrid.js实现回车键结束编辑功能的实例
2017/04/12 Javascript
详解ECMAScript6入门--Class对象
2017/04/27 Javascript
vuejs+element-ui+laravel5.4上传文件的示例代码
2017/08/12 Javascript
JSON数据中存在单个转义字符“\”的处理方法
2018/07/11 Javascript
Nuxt.js SSR与权限验证的实现
2018/11/21 Javascript
Vue过渡效果之CSS过渡详解(结合transition,animation,animate.css)
2020/02/05 Javascript
解决vue安装less报错Failed to compile with 1 errors的问题
2020/10/22 Javascript
[58:42]DOTA2上海特级锦标赛C组败者赛 Newbee VS Archon第一局
2016/02/27 DOTA
在Python的Flask框架中实现单元测试的教程
2015/04/20 Python
Python中Class类用法实例分析
2015/11/12 Python
浅谈Python中用datetime包进行对时间的一些操作
2016/06/23 Python
python 网络编程常用代码段
2016/08/28 Python
Diango + uwsgi + nginx项目部署的全过程(可外网访问)
2018/04/22 Python
django 删除数据库表后重新同步的方法
2018/05/27 Python
python GUI库图形界面开发之PyQt5切换按钮控件QPushButton详细使用方法与实例
2020/02/28 Python
django配置app中的静态文件步骤
2020/03/27 Python
python 通过文件夹导入包的操作
2020/06/01 Python
Django form表单与请求的生命周期步骤详解
2020/06/07 Python
python如何导出微信公众号文章方法详解
2020/08/31 Python
分享一个页面平滑滚动小技巧(推荐)
2019/10/23 HTML / CSS
合作协议书范本
2014/04/17 职场文书
租车协议书范本
2014/04/22 职场文书
教师四风问题整改措施
2014/09/25 职场文书
领导班子个人对照检查剖析材料
2014/09/29 职场文书
房屋认购协议书
2015/01/29 职场文书
云服务器部署 Web 项目的实现步骤
2022/06/28 Servers
MySQL的意向共享锁、意向排它锁和死锁
2022/07/15 MySQL