Python如何加载模型并查看网络


Posted in Python onJuly 15, 2022

加载模型并查看网络

加载模型,以vgg19为例。

打开终端

> python
Python 3.7.2 (tags/v3.7.2:9a3ffc0492, Dec 23 2018, 23:09:28) [MSC v.1916 64 bit
(AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> from torchvision import models
>>> model = models.vgg19(pretrained=True) #此时如果是第一次加载会开始下载模型的pth文件
>>> print(model.model)

结果:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

注意,直接打印模型是没有办法看到模型结构的,只能看到带模型参数的pth文件内容;需要打印model.model才可以看到模型本身。

神经网络_模型的保存,模型的加载

模型的保存(torch.save)

方式1(模型结构+模型参数)

参数:保存位置

# 创建模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1——模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

方式2(模型参数)

# 保存方式2  模型参数(官方推荐)。保存成字典,只保存网络模型中的一些参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

模型的加载(torch.load)

对应保存方式1

参数:模型路径

# 方式1 --》 保存方式1
model1 = torch.load("vgg16_method1.pth")

对应保存方式2

vgg16.load_state_dict("vgg16_method2.pth")

输出为字典形式。若要回复网络,采用以下形式:

model2 = torch.load("vgg16_method2.pth")  #输出是字典形式
# 恢复网络结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(model2)

方式1存储,加载时需注意事项

新建自己的网络:

class test(nn.Module):
    def __init__(self):
        super(lh, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

保存自己的网络:

Test = test()
# 保存自己定义的网络
torch.save(Test, "Test_method1.pth")

加载自己的网络:

model3 = torch.load("Test_method1.pth")

会报错!!!!!!

Python如何加载模型并查看网络

解决办法(需要注意):

将定义的网络复制到加载的python文件中:

class test(nn.Module):
    def __init__(self):
        super(test, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x
model3 = torch.load("Test_method1.pth")

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python入门之modf()方法的使用
May 15 Python
Python中import机制详解
Nov 14 Python
numpy的文件存储.npy .npz 文件详解
Jul 09 Python
python 从文件夹抽取图片另存的方法
Dec 04 Python
python求最大值最小值方法总结
Jun 25 Python
python3中类的继承以及self和super的区别详解
Jun 26 Python
python 杀死自身进程的实现方法
Jul 01 Python
Python scipy的二维图像卷积运算与图像模糊处理操作示例
Sep 06 Python
django项目中新增app的2种实现方法
Apr 01 Python
Python实现SMTP邮件发送
Jun 16 Python
python制作微博图片爬取工具
Jan 16 Python
Django中celery的使用项目实例
Jul 07 Python
Python绘制散点图之可视化神器pyecharts
Jul 07 #Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 #Python
Python通用验证码识别OCR库ddddocr的安装使用教程
Jul 07 #Python
Django数据库(SQlite)基本入门使用教程
Jul 07 #Python
Python可视化神器pyecharts之绘制地理图表练习
Django中celery的使用项目实例
Python可视化神器pyecharts绘制地理图表
You might like
php file_exists 检查文件或目录是否存在的函数
2010/05/10 PHP
PHP安全性漫谈
2012/06/28 PHP
php实现图形显示Ip地址的代码及注释
2014/01/20 PHP
ThinkPHP控制器里javascript代码不能执行的解决方法
2014/11/22 PHP
PHPStrom 新建FTP项目以及在线操作教程
2016/10/16 PHP
jQuery getJSON 处理json数据的代码
2010/07/26 Javascript
jquery判断复选框是否选中进行答题提示特效
2015/12/10 Javascript
深入解析AngularJS框架中$scope的作用与生命周期
2016/03/05 Javascript
jQuery animate和CSS3相结合实现缓动追逐效果附源码下载
2016/04/18 Javascript
微信小程序 首页制作简单实例
2017/04/07 Javascript
详解tween.js的使用教程
2017/09/14 Javascript
AngularJS基于MVC的复杂操作实例讲解
2017/12/31 Javascript
使用watch监听路由变化和watch监听对象的实例
2018/02/24 Javascript
基于rollup的组件库打包体积优化小结
2018/06/18 Javascript
webpack+vue+express(hot)热启动调试简单配置方法
2018/09/19 Javascript
优雅地使用loading(推荐)
2019/04/20 Javascript
微信小程序封装的HTTP请求示例【附升级版】
2019/05/11 Javascript
JS实现简单的文字无缝上下滚动功能示例
2019/06/22 Javascript
layer关闭当前窗口页面以及确认取消按钮的方法
2019/09/09 Javascript
vue中js判断长时间不操作界面自动退出登录(推荐)
2020/01/22 Javascript
python3图片转换二进制存入mysql
2013/12/06 Python
python编程开发之日期操作实例分析
2015/11/13 Python
Python中的字符串类型基本知识学习教程
2016/02/04 Python
Python的Tornado框架实现异步非阻塞访问数据库的示例
2016/06/30 Python
python实现的多线程端口扫描功能示例
2017/01/21 Python
python如何定义带参数的装饰器
2018/03/20 Python
python3使用smtplib实现发送邮件功能
2018/05/22 Python
Matplotlib 生成不同大小的subplots实例
2018/05/25 Python
Python MongoDB 插入数据时已存在则不执行,不存在则插入的解决方法
2019/09/24 Python
python编写softmax函数、交叉熵函数实例
2020/06/11 Python
如何使用Python进行PDF图片识别OCR
2021/01/22 Python
基于canvas使用贝塞尔曲线平滑拟合折线段的方法
2018/01/10 HTML / CSS
史蒂夫·马登加拿大官网:Steve Madden加拿大
2017/11/18 全球购物
软件测试面试题
2015/10/21 面试题
网络教育自我鉴定
2013/11/01 职场文书
2019年自助餐厅创业计划书模板
2019/08/22 职场文书