Python如何加载模型并查看网络


Posted in Python onJuly 15, 2022

加载模型并查看网络

加载模型,以vgg19为例。

打开终端

> python
Python 3.7.2 (tags/v3.7.2:9a3ffc0492, Dec 23 2018, 23:09:28) [MSC v.1916 64 bit
(AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> from torchvision import models
>>> model = models.vgg19(pretrained=True) #此时如果是第一次加载会开始下载模型的pth文件
>>> print(model.model)

结果:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

注意,直接打印模型是没有办法看到模型结构的,只能看到带模型参数的pth文件内容;需要打印model.model才可以看到模型本身。

神经网络_模型的保存,模型的加载

模型的保存(torch.save)

方式1(模型结构+模型参数)

参数:保存位置

# 创建模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1——模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

方式2(模型参数)

# 保存方式2  模型参数(官方推荐)。保存成字典,只保存网络模型中的一些参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

模型的加载(torch.load)

对应保存方式1

参数:模型路径

# 方式1 --》 保存方式1
model1 = torch.load("vgg16_method1.pth")

对应保存方式2

vgg16.load_state_dict("vgg16_method2.pth")

输出为字典形式。若要回复网络,采用以下形式:

model2 = torch.load("vgg16_method2.pth")  #输出是字典形式
# 恢复网络结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(model2)

方式1存储,加载时需注意事项

新建自己的网络:

class test(nn.Module):
    def __init__(self):
        super(lh, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

保存自己的网络:

Test = test()
# 保存自己定义的网络
torch.save(Test, "Test_method1.pth")

加载自己的网络:

model3 = torch.load("Test_method1.pth")

会报错!!!!!!

Python如何加载模型并查看网络

解决办法(需要注意):

将定义的网络复制到加载的python文件中:

class test(nn.Module):
    def __init__(self):
        super(test, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x
model3 = torch.load("Test_method1.pth")

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python executemany的使用及注意事项
Mar 13 Python
python3爬取淘宝信息代码分析
Feb 10 Python
python实现数据导出到excel的示例--普通格式
May 03 Python
解决Python获取字典dict中不存在的值时出错问题
Oct 17 Python
浅谈python中拼接路径os.path.join斜杠的问题
Oct 23 Python
Python完成毫秒级抢淘宝大单功能
Jun 06 Python
python3实现斐波那契数列(4种方法)
Jul 15 Python
pytorch加载自定义网络权重的实现
Jan 07 Python
python3 字符串知识点学习笔记
Feb 08 Python
简单了解如何封装自己的Python包
Jul 08 Python
Python实现FTP文件定时自动下载的步骤
Dec 19 Python
Python机器学习算法之决策树算法的实现与优缺点
May 13 Python
Python绘制散点图之可视化神器pyecharts
Jul 07 #Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 #Python
Python通用验证码识别OCR库ddddocr的安装使用教程
Jul 07 #Python
Django数据库(SQlite)基本入门使用教程
Jul 07 #Python
Python可视化神器pyecharts之绘制地理图表练习
Django中celery的使用项目实例
Python可视化神器pyecharts绘制地理图表
You might like
什么是短波收听SWL
2021/03/01 无线电
分享PHP入门的学习方法
2007/01/02 PHP
PHP统计目录下的文件总数及代码行数(去除注释及空行)
2011/01/17 PHP
迁移PHP版本到PHP7
2015/02/06 PHP
Ajax和PHP正则表达式验证表单及验证码
2016/09/24 PHP
PHP laravel中的多对多关系实例详解
2017/06/07 PHP
Laravel5.1 框架数据库操作DB运行原生SQL的方法分析
2020/01/07 PHP
一个刚完成的layout(拖动流畅,不受iframe影响)
2007/08/17 Javascript
javascript操作cookie_获取与修改代码
2009/05/21 Javascript
javascript中的五种基本数据类型
2015/08/26 Javascript
JavaScript 七大技巧(一)
2015/12/13 Javascript
Function.prototype.apply()与Function.prototype.call()小结
2016/04/27 Javascript
JS组件系列之Bootstrap table表格组件神器【终结篇】
2016/05/10 Javascript
bootstrap监听滚动实现头部跟随滚动
2016/11/08 Javascript
vue2.x 父组件监听子组件事件并传回信息的方法
2017/07/17 Javascript
使用vue构建移动应用实战代码
2017/08/02 Javascript
three.js实现3D模型展示的示例代码
2017/12/31 Javascript
浅析vue插槽和作用域插槽的理解
2019/04/22 Javascript
在layui tab控件中载入外部html页面的方法
2019/09/04 Javascript
Windows系统配置python脚本开机启动的3种方法分享
2015/03/10 Python
python实现的用于搜索文件并进行内容替换的类实例
2015/06/28 Python
Python学习之用pygal画世界地图实例
2017/12/07 Python
基于python OpenCV实现动态人脸检测
2018/05/25 Python
python实现推箱子游戏
2020/03/25 Python
Django项目后台不挂断运行的方法
2019/08/31 Python
安装python及pycharm的教程图解
2019/10/10 Python
使用pyshp包进行shapefile文件修改的例子
2019/12/06 Python
Django如何使用asyncio协程和ThreadPoolExecutor多线程
2020/10/12 Python
python中把元组转换为namedtuple方法
2020/12/09 Python
加拿大领先的优质厨具产品在线购物网站:Golda’s Kitchen
2017/11/17 全球购物
Under Armour安德玛中国官网:美国高端运动科技品牌
2018/03/09 全球购物
书法培训心得体会
2014/01/05 职场文书
2015中学政教处工作总结
2015/07/22 职场文书
SQL Server中常用截取字符串函数介绍
2022/03/16 SQL Server
基于Python实现射击小游戏的制作
2022/04/06 Python
纯CSS实现一个简单步骤条的示例代码
2022/07/15 HTML / CSS