Python如何加载模型并查看网络


Posted in Python onJuly 15, 2022

加载模型并查看网络

加载模型,以vgg19为例。

打开终端

> python
Python 3.7.2 (tags/v3.7.2:9a3ffc0492, Dec 23 2018, 23:09:28) [MSC v.1916 64 bit
(AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> from torchvision import models
>>> model = models.vgg19(pretrained=True) #此时如果是第一次加载会开始下载模型的pth文件
>>> print(model.model)

结果:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

注意,直接打印模型是没有办法看到模型结构的,只能看到带模型参数的pth文件内容;需要打印model.model才可以看到模型本身。

神经网络_模型的保存,模型的加载

模型的保存(torch.save)

方式1(模型结构+模型参数)

参数:保存位置

# 创建模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1——模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

方式2(模型参数)

# 保存方式2  模型参数(官方推荐)。保存成字典,只保存网络模型中的一些参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

模型的加载(torch.load)

对应保存方式1

参数:模型路径

# 方式1 --》 保存方式1
model1 = torch.load("vgg16_method1.pth")

对应保存方式2

vgg16.load_state_dict("vgg16_method2.pth")

输出为字典形式。若要回复网络,采用以下形式:

model2 = torch.load("vgg16_method2.pth")  #输出是字典形式
# 恢复网络结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(model2)

方式1存储,加载时需注意事项

新建自己的网络:

class test(nn.Module):
    def __init__(self):
        super(lh, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

保存自己的网络:

Test = test()
# 保存自己定义的网络
torch.save(Test, "Test_method1.pth")

加载自己的网络:

model3 = torch.load("Test_method1.pth")

会报错!!!!!!

Python如何加载模型并查看网络

解决办法(需要注意):

将定义的网络复制到加载的python文件中:

class test(nn.Module):
    def __init__(self):
        super(test, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x
model3 = torch.load("Test_method1.pth")

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中SOAP项目的介绍及其在web开发中的应用
Apr 14 Python
遗传算法之Python实现代码
Oct 10 Python
Django ORM框架的定时任务如何使用详解
Oct 19 Python
Python3实现发送QQ邮件功能(html)
Dec 15 Python
matplotlib绘图实例演示标记路径
Jan 23 Python
selenium设置proxy、headers的方法(phantomjs、Chrome、Firefox)
Nov 29 Python
对Python强大的可变参数传递机制详解
Jun 13 Python
Python之pymysql的使用小结
Jul 01 Python
基于 Django 的手机管理系统实现过程详解
Aug 16 Python
matplotlib绘制鼠标的十字光标的实现(自定义方式,官方实例)
Jan 10 Python
Python爬虫入门教程01之爬取豆瓣Top电影
Jan 24 Python
Python 中的 copy()和deepcopy()
Nov 07 Python
Python绘制散点图之可视化神器pyecharts
Jul 07 #Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 #Python
Python通用验证码识别OCR库ddddocr的安装使用教程
Jul 07 #Python
Django数据库(SQlite)基本入门使用教程
Jul 07 #Python
Python可视化神器pyecharts之绘制地理图表练习
Django中celery的使用项目实例
Python可视化神器pyecharts绘制地理图表
You might like
数据库的日期格式转换
2006/10/09 PHP
PHP之生成GIF动画的实现方法
2013/06/07 PHP
php输出反斜杠的实例方法
2019/09/19 PHP
php获取是星期几的的一些常用姿势
2019/12/15 PHP
JQuery设置获取下拉菜单某个选项的值(比较全)
2014/08/05 Javascript
JavaScript中的分号插入机制详细介绍
2015/02/11 Javascript
跟我学习javascript的执行上下文
2015/11/18 Javascript
基于jquery实现表格无刷新分页
2016/01/07 Javascript
js判断图片加载完成后获取图片实际宽高的方法
2016/02/25 Javascript
两种JavaScript的AES加密方式(可与Java相互加解密)
2016/08/02 Javascript
node.js学习之交互式解释器REPL详解
2016/12/08 Javascript
Jquery树插件zTree实现菜单树
2017/01/24 Javascript
深入理解nodejs中Express的中间件
2017/05/19 NodeJs
微信小程序之页面拦截器的示例代码
2017/09/07 Javascript
详解webpack3编译兼容IE8的正确姿势
2017/12/21 Javascript
原生JS实现的多个彩色小球跟随鼠标移动动画效果示例
2018/02/01 Javascript
vue 国际化 vue-i18n 双语言 语言包
2018/06/07 Javascript
Echarts之悬浮框中的数据排序问题
2018/11/08 Javascript
vue中eslintrc.js配置最详细介绍
2018/12/21 Javascript
微信小程序云开发之新手环境配置
2019/05/16 Javascript
[01:14:10]2014 DOTA2国际邀请赛中国区预选赛 SPD-GAMING VS Orenda
2014/05/22 DOTA
[06:01]刀塔次级联赛top10第一期
2014/11/07 DOTA
Python之os操作方法(详解)
2017/06/15 Python
利用python实现在微信群刷屏的方法
2019/02/21 Python
Pytorch 实现冻结指定卷积层的参数
2020/01/06 Python
python+selenium定时爬取丁香园的新型冠状病毒数据并制作出类似的地图(部署到云服务器)
2020/02/09 Python
HTML5开发动态音频图的实现
2020/07/02 HTML / CSS
static全局变量与普通的全局变量有什么区别?static局部变量和普通局部变量有什么区别?static函数与普通函数有什么区别?
2015/02/22 面试题
小学生打架检讨书
2014/01/26 职场文书
个人廉洁自律承诺书
2014/03/27 职场文书
社区领导班子四风问题原因分析及整改措施
2014/09/28 职场文书
2014年村计划生育工作总结
2014/11/14 职场文书
2016年习主席讲话学习心得体会
2016/01/20 职场文书
《红领巾真好》教学反思
2016/02/16 职场文书
vue实现列表拖拽排序的示例代码
2022/04/08 Vue.js
Java中的继承、多态以及封装
2022/04/11 Java/Android