Python如何加载模型并查看网络


Posted in Python onJuly 15, 2022

加载模型并查看网络

加载模型,以vgg19为例。

打开终端

> python
Python 3.7.2 (tags/v3.7.2:9a3ffc0492, Dec 23 2018, 23:09:28) [MSC v.1916 64 bit
(AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> from torchvision import models
>>> model = models.vgg19(pretrained=True) #此时如果是第一次加载会开始下载模型的pth文件
>>> print(model.model)

结果:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

注意,直接打印模型是没有办法看到模型结构的,只能看到带模型参数的pth文件内容;需要打印model.model才可以看到模型本身。

神经网络_模型的保存,模型的加载

模型的保存(torch.save)

方式1(模型结构+模型参数)

参数:保存位置

# 创建模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1——模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

方式2(模型参数)

# 保存方式2  模型参数(官方推荐)。保存成字典,只保存网络模型中的一些参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

模型的加载(torch.load)

对应保存方式1

参数:模型路径

# 方式1 --》 保存方式1
model1 = torch.load("vgg16_method1.pth")

对应保存方式2

vgg16.load_state_dict("vgg16_method2.pth")

输出为字典形式。若要回复网络,采用以下形式:

model2 = torch.load("vgg16_method2.pth")  #输出是字典形式
# 恢复网络结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(model2)

方式1存储,加载时需注意事项

新建自己的网络:

class test(nn.Module):
    def __init__(self):
        super(lh, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

保存自己的网络:

Test = test()
# 保存自己定义的网络
torch.save(Test, "Test_method1.pth")

加载自己的网络:

model3 = torch.load("Test_method1.pth")

会报错!!!!!!

Python如何加载模型并查看网络

解决办法(需要注意):

将定义的网络复制到加载的python文件中:

class test(nn.Module):
    def __init__(self):
        super(test, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x
model3 = torch.load("Test_method1.pth")

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
自己编程中遇到的Python错误和解决方法汇总整理
Jun 03 Python
Python使用Srapy框架爬虫模拟登陆并抓取知乎内容
Jul 02 Python
Django中cookie的基本使用方法示例
Feb 03 Python
实例讲解Python脚本成为Windows中运行的exe文件
Jan 24 Python
Python实现最大子序和的方法示例
Jul 05 Python
Django中使用MySQL5.5的教程
Dec 18 Python
Python如何通过Flask-Mail发送电子邮件
Jan 29 Python
使用python的pyplot绘制函数实例
Feb 13 Python
PyQt5 文本输入框自动补全QLineEdit的实现示例
May 13 Python
Python进行特征提取的示例代码
Oct 15 Python
Python字符串查找基本操作代码案例
Oct 27 Python
Python实现京东抢秒杀功能
Jan 25 Python
Python绘制散点图之可视化神器pyecharts
Jul 07 #Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 #Python
Python通用验证码识别OCR库ddddocr的安装使用教程
Jul 07 #Python
Django数据库(SQlite)基本入门使用教程
Jul 07 #Python
Python可视化神器pyecharts之绘制地理图表练习
Django中celery的使用项目实例
Python可视化神器pyecharts绘制地理图表
You might like
数据库查询记录php 多行多列显示
2009/08/15 PHP
PHP依赖注入(DI)和控制反转(IoC)详解
2017/06/12 PHP
如何修改yii2.0自带的user表为其它的表
2017/08/01 PHP
用JQuery 实现的自定义对话框
2007/03/24 Javascript
jQuery 可以拖动的div实现代码 脚本之家修正版
2009/06/26 Javascript
通过javascript的匿名函数来分析几段简单有趣的代码
2010/06/29 Javascript
火狐textarea输入法的bug的触发及解决
2013/07/24 Javascript
JS表格组件神器bootstrap table详解(基础版)
2015/12/08 Javascript
一种新的javascript对象创建方式Object.create()
2015/12/28 Javascript
javascript this详细介绍
2016/09/19 Javascript
js文件中直接alert()中文出来的是乱码的解决方法
2016/11/01 Javascript
JS 实现可停顿的垂直滚动实例代码
2016/11/23 Javascript
Vue常用指令V-model用法
2017/03/08 Javascript
vue项目开发中setTimeout等定时器的管理问题
2018/09/13 Javascript
详解javascript 变量提升(Hoisting)
2019/03/12 Javascript
JavaScript面向对象程序设计中对象的定义和继承详解
2019/07/29 Javascript
JavaScript 装逼指南(js另类写法)
2020/05/10 Javascript
微信小程序向Java后台传输参数的方法实现
2020/12/10 Javascript
python遍历数组的方法小结
2015/04/30 Python
python在指定目录下查找gif文件的方法
2015/05/04 Python
Python使用xlrd模块操作Excel数据导入的方法
2015/05/26 Python
Python中的id()函数指的什么
2017/10/17 Python
TensorFlow数据输入的方法示例
2018/06/19 Python
kaggle+mnist实现手写字体识别
2018/07/26 Python
对python中的float除法和整除法的实例详解
2019/07/20 Python
Django使用Channels实现WebSocket的方法
2019/07/28 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
2020/05/25 Python
解析Tensorflow之MNIST的使用
2020/06/30 Python
乡下人家教学反思
2014/02/01 职场文书
领导干部群众路线教育实践活动剖析材料
2014/10/10 职场文书
党的群众路线教育实践活动个人对照检查材料(乡镇)
2014/11/05 职场文书
高中生毕业评语
2014/12/30 职场文书
总账会计岗位职责
2015/04/02 职场文书
SQL写法--行行比较
2021/08/23 SQL Server
redis 解决库存并发问题实现数量控制
2022/04/08 Redis
Python OpenGL基本配置方式
2022/05/20 Python