Python如何加载模型并查看网络


Posted in Python onJuly 15, 2022

加载模型并查看网络

加载模型,以vgg19为例。

打开终端

> python
Python 3.7.2 (tags/v3.7.2:9a3ffc0492, Dec 23 2018, 23:09:28) [MSC v.1916 64 bit
(AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> from torchvision import models
>>> model = models.vgg19(pretrained=True) #此时如果是第一次加载会开始下载模型的pth文件
>>> print(model.model)

结果:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

注意,直接打印模型是没有办法看到模型结构的,只能看到带模型参数的pth文件内容;需要打印model.model才可以看到模型本身。

神经网络_模型的保存,模型的加载

模型的保存(torch.save)

方式1(模型结构+模型参数)

参数:保存位置

# 创建模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1——模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

方式2(模型参数)

# 保存方式2  模型参数(官方推荐)。保存成字典,只保存网络模型中的一些参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

模型的加载(torch.load)

对应保存方式1

参数:模型路径

# 方式1 --》 保存方式1
model1 = torch.load("vgg16_method1.pth")

对应保存方式2

vgg16.load_state_dict("vgg16_method2.pth")

输出为字典形式。若要回复网络,采用以下形式:

model2 = torch.load("vgg16_method2.pth")  #输出是字典形式
# 恢复网络结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(model2)

方式1存储,加载时需注意事项

新建自己的网络:

class test(nn.Module):
    def __init__(self):
        super(lh, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

保存自己的网络:

Test = test()
# 保存自己定义的网络
torch.save(Test, "Test_method1.pth")

加载自己的网络:

model3 = torch.load("Test_method1.pth")

会报错!!!!!!

Python如何加载模型并查看网络

解决办法(需要注意):

将定义的网络复制到加载的python文件中:

class test(nn.Module):
    def __init__(self):
        super(test, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x
model3 = torch.load("Test_method1.pth")

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用 Selenium 实现网页截图实例
Jul 18 Python
使用Python的PEAK来适配协议的教程
Apr 14 Python
读写json中文ASCII乱码问题的解决方法
Nov 05 Python
Python爬虫_城市公交、地铁站点和线路数据采集实例
Jan 10 Python
python基础教程项目二之画幅好画
Apr 02 Python
python删除不需要的python文件方法
Apr 24 Python
一步步教你用python的scrapy编写一个爬虫
Apr 17 Python
python字符串替换re.sub()方法解析
Sep 18 Python
使用TensorFlow对图像进行随机旋转的实现示例
Jan 20 Python
NumPy排序的实现
Jan 21 Python
详解python第三方库的安装、PyInstaller库、random库
Mar 03 Python
python 爬取天气网卫星图片
Jun 07 Python
Python绘制散点图之可视化神器pyecharts
Jul 07 #Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 #Python
Python通用验证码识别OCR库ddddocr的安装使用教程
Jul 07 #Python
Django数据库(SQlite)基本入门使用教程
Jul 07 #Python
Python可视化神器pyecharts之绘制地理图表练习
Django中celery的使用项目实例
Python可视化神器pyecharts绘制地理图表
You might like
WML,Apache,和 PHP 的介绍
2006/10/09 PHP
php中引用符号(&)的使用详解
2013/11/13 PHP
PHP实现动态添加XML中数据的方法
2018/03/30 PHP
Mac系统下搭建Nginx+php-fpm实例讲解
2020/12/15 PHP
破除网页鼠标右键被禁用的绝招大全
2006/12/27 Javascript
url 特殊字符 传递参数解决方法
2010/01/01 Javascript
js操作二级联动实现代码
2010/07/27 Javascript
基于jquery的拖动布局插件
2011/11/25 Javascript
jQuery LigerUI 使用教程入门篇
2012/01/18 Javascript
SwfUpload在IE10上不出现上传按钮的解决方法
2013/06/25 Javascript
详解Angular.js的$q.defer()服务异步处理
2016/11/06 Javascript
JS简单判断字符在另一个字符串中出现次数的2种常用方法
2017/04/20 Javascript
AngularJs定时器$interval 和 $timeout详解
2017/05/25 Javascript
angularjs性能优化的方法
2018/09/05 Javascript
React Native开发封装Toast与加载Loading组件示例
2018/09/08 Javascript
解决layui中onchange失效以及form动态渲染失效的问题
2019/09/27 Javascript
vue页面切换项目实现转场动画的方法
2019/11/12 Javascript
Vue+Element自定义纵向表格表头教程
2020/10/26 Javascript
vue 图片裁剪上传组件的实现
2020/11/12 Javascript
[01:11]steam端dota2实名认证操作流程视频
2021/03/11 DOTA
python类的继承实例详解
2017/03/30 Python
selenium获取当前页面的url、源码、title的方法
2019/06/12 Python
Pycharm运行加载文本出现错误的解决方法
2019/06/27 Python
python读取raw binary图片并提取统计信息的实例
2020/01/09 Python
使用css3背景渐变中的透明度来设置不同颜色的背景渐变
2014/03/31 HTML / CSS
HTML5在线预览PDF的示例代码
2017/09/14 HTML / CSS
用缩写的指针比较"if(p)" 检查空指针是否可靠?如果空指针的内部表达不是0会怎么样?
2014/01/05 面试题
如何估计一张表的大小(假设该表中有1万条数据)
2016/03/27 面试题
幼儿园校车司机的岗位职责
2014/01/30 职场文书
2014年银行客户经理工作总结
2014/11/12 职场文书
实习单位意见
2015/06/04 职场文书
房贷工资证明范本
2015/06/12 职场文书
Canvas三种动态画圆实现方法说明(小结)
2021/04/16 Javascript
Python机器学习之PCA降维算法详解
2021/05/19 Python
Python四款GUI图形界面库介绍
2022/06/05 Python
微软发布Windows 11今年最大更新22H2(附 ISO 镜像官方下载)
2022/09/23 数码科技