Python如何加载模型并查看网络


Posted in Python onJuly 15, 2022

加载模型并查看网络

加载模型,以vgg19为例。

打开终端

> python
Python 3.7.2 (tags/v3.7.2:9a3ffc0492, Dec 23 2018, 23:09:28) [MSC v.1916 64 bit
(AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> from torchvision import models
>>> model = models.vgg19(pretrained=True) #此时如果是第一次加载会开始下载模型的pth文件
>>> print(model.model)

结果:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

注意,直接打印模型是没有办法看到模型结构的,只能看到带模型参数的pth文件内容;需要打印model.model才可以看到模型本身。

神经网络_模型的保存,模型的加载

模型的保存(torch.save)

方式1(模型结构+模型参数)

参数:保存位置

# 创建模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1——模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

方式2(模型参数)

# 保存方式2  模型参数(官方推荐)。保存成字典,只保存网络模型中的一些参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

模型的加载(torch.load)

对应保存方式1

参数:模型路径

# 方式1 --》 保存方式1
model1 = torch.load("vgg16_method1.pth")

对应保存方式2

vgg16.load_state_dict("vgg16_method2.pth")

输出为字典形式。若要回复网络,采用以下形式:

model2 = torch.load("vgg16_method2.pth")  #输出是字典形式
# 恢复网络结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(model2)

方式1存储,加载时需注意事项

新建自己的网络:

class test(nn.Module):
    def __init__(self):
        super(lh, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

保存自己的网络:

Test = test()
# 保存自己定义的网络
torch.save(Test, "Test_method1.pth")

加载自己的网络:

model3 = torch.load("Test_method1.pth")

会报错!!!!!!

Python如何加载模型并查看网络

解决办法(需要注意):

将定义的网络复制到加载的python文件中:

class test(nn.Module):
    def __init__(self):
        super(test, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x
model3 = torch.load("Test_method1.pth")

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python协程用法实例分析
Jun 04 Python
Python实现的基数排序算法原理与用法实例分析
Nov 23 Python
Python变量赋值的秘密分享
Apr 03 Python
为什么str(float)在Python 3中比Python 2返回更多的数字
Oct 16 Python
Python常见的pandas用法demo示例
Mar 16 Python
在VS2017中用C#调用python脚本的实现
Jul 31 Python
TFRecord格式存储数据与队列读取实例
Jan 21 Python
tensorflow 实现打印pb模型的所有节点
Jan 23 Python
Python PyQt5运行程序把输出信息展示到GUI图形界面上
Apr 27 Python
Django中使用Celery的方法步骤
Dec 07 Python
为2021年的第一场雪锦上添花:用matplotlib绘制雪花和雪景
Jan 05 Python
python如何读取和存储dict()与.json格式文件
Jun 25 Python
Python绘制散点图之可视化神器pyecharts
Jul 07 #Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 #Python
Python通用验证码识别OCR库ddddocr的安装使用教程
Jul 07 #Python
Django数据库(SQlite)基本入门使用教程
Jul 07 #Python
Python可视化神器pyecharts之绘制地理图表练习
Django中celery的使用项目实例
Python可视化神器pyecharts绘制地理图表
You might like
Yii2 RESTful中api的使用及开发实例详解
2016/07/06 PHP
php将服务端的文件读出来显示在web页面实例
2016/10/31 PHP
PHP获取对象属性的三种方法实例分析
2019/01/03 PHP
PHP面向对象程序设计中的self、static、parent关键字用法分析
2019/08/14 PHP
javascript下查找父节点的简单方法
2007/08/13 Javascript
一些技巧性实用js代码小结
2009/10/14 Javascript
JavaScript DOM 学习第三章 内容表格
2010/02/19 Javascript
jWiard 基于JQuery的强大的向导控件介绍
2011/10/28 Javascript
js实现鼠标悬浮给图片加边框的方法
2015/01/30 Javascript
jquery validate表单验证的基本用法入门
2016/01/18 Javascript
浅析BootStrap栅格系统
2016/06/07 Javascript
全面解析jQuery $(document).ready()和JavaScript onload事件
2016/06/08 Javascript
解决bootstrap导航栏navbar在IE8上存在缺陷的方法
2016/07/01 Javascript
javascript中Date对象应用之简易日历实现
2016/07/12 Javascript
完全深入学习Bootstrap表单
2016/11/28 Javascript
JS实现的走迷宫小游戏完整实例
2017/07/19 Javascript
详解nodejs中express搭建权限管理系统
2017/09/15 NodeJs
详解Vue.js Mixins 混入使用
2017/09/15 Javascript
原生JS实现的多个彩色小球跟随鼠标移动动画效果示例
2018/02/01 Javascript
Vue中 key keep-alive的实现原理
2018/09/18 Javascript
解决Vue中使用keepAlive不缓存问题
2020/08/04 Javascript
vc6编写python扩展的方法分享
2014/01/17 Python
Python3.x对JSON的一些操作示例
2017/09/01 Python
python 实现求解字符串集的最长公共前缀方法
2018/07/20 Python
对python中的 os.mkdir和os.mkdirs详解
2018/10/16 Python
Pycharm中安装Pygal并使用Pygal模拟掷骰子(推荐)
2020/04/08 Python
CSS3实现鼠标悬停显示扩展内容
2016/08/24 HTML / CSS
美国老牌主机服务商:iPage
2016/07/22 全球购物
美国女性运动零售品牌:Lady Foot Locker
2017/05/12 全球购物
The North Face北面荷兰官网:美国著名户外品牌
2019/10/16 全球购物
社区志愿者培训方案
2014/06/10 职场文书
教师年度考核个人总结
2015/02/12 职场文书
关于运动会的广播稿
2015/08/19 职场文书
回门宴新娘答谢词
2015/09/29 职场文书
严以修身专题学习研讨会发言材料
2015/11/09 职场文书
Golang实现可重入锁的示例代码
2022/05/25 Golang