PyTorch快速搭建神经网络及其保存提取方法详解


Posted in Python onApril 28, 2018

有时候我们训练了一个模型, 希望保存它下次直接使用,不需要下次再花时间去训练 ,本节我们来讲解一下PyTorch快速搭建神经网络及其保存提取方法详解

一、PyTorch快速搭建神经网络方法

先看实验代码:

import torch 
import torch.nn.functional as F 
 
# 方法1,通过定义一个Net类来建立神经网络 
class Net(torch.nn.Module): 
  def __init__(self, n_feature, n_hidden, n_output): 
    super(Net, self).__init__() 
    self.hidden = torch.nn.Linear(n_feature, n_hidden) 
    self.predict = torch.nn.Linear(n_hidden, n_output) 
 
  def forward(self, x): 
    x = F.relu(self.hidden(x)) 
    x = self.predict(x) 
    return x 
 
net1 = Net(2, 10, 2) 
print('方法1:\n', net1) 
 
# 方法2 通过torch.nn.Sequential快速建立神经网络结构 
net2 = torch.nn.Sequential( 
  torch.nn.Linear(2, 10), 
  torch.nn.ReLU(), 
  torch.nn.Linear(10, 2), 
  ) 
print('方法2:\n', net2) 
# 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 
 
''''' 
方法1: 
 Net ( 
 (hidden): Linear (2 -> 10) 
 (predict): Linear (10 -> 2) 
) 
方法2: 
 Sequential ( 
 (0): Linear (2 -> 10) 
 (1): ReLU () 
 (2): Linear (10 -> 2) 
) 
'''

先前学习了通过定义一个Net类来构建神经网络的方法,classNet中首先通过super函数继承torch.nn.Module模块的构造方法,再通过添加属性的方式搭建神经网络各层的结构信息,在forward方法中完善神经网络各层之间的连接信息,然后再通过定义Net类对象的方式完成对神经网络结构的构建。

构建神经网络的另一个方法,也可以说是快速构建方法,就是通过torch.nn.Sequential,直接完成对神经网络的建立。

两种方法构建得到的神经网络结构完全相同,都可以通过print函数来打印输出网络信息,不过打印结果会有些许不同。

二、PyTorch的神经网络保存和提取

在学习和研究深度学习的时候,当我们通过一定时间的训练,得到了一个比较好的模型的时候,我们当然希望将这个模型及模型参数保存下来,以备后用,所以神经网络的保存和模型参数提取重载是很有必要的。

首先,我们需要在需要保存网路结构及其模型参数的神经网络的定义、训练部分之后通过torch.save()实现对网络结构和模型参数的保存。有两种保存方式:一是保存年整个神经网络的的结构信息和模型参数信息,save的对象是网络net;二是只保存神经网络的训练模型参数,save的对象是net.state_dict(),保存结果都以.pkl文件形式存储。

对应上面两种保存方式,重载方式也有两种。对应第一种完整网络结构信息,重载的时候通过torch.load(‘.pkl')直接初始化新的神经网络对象即可。对应第二种只保存模型参数信息,需要首先搭建相同的神经网络结构,通过net.load_state_dict(torch.load('.pkl'))完成模型参数的重载。在网络比较大的时候,第一种方法会花费较多的时间。

代码实现:

import torch 
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) # 设定随机数种子 
 
# 创建数据 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) 
y = x.pow(2) + 0.2*torch.rand(x.size()) 
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) 
 
# 将待保存的神经网络定义在一个函数中 
def save(): 
  # 神经网络结构 
  net1 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
  optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) 
  loss_function = torch.nn.MSELoss() 
 
  # 训练部分 
  for i in range(300): 
    prediction = net1(x) 
    loss = loss_function(prediction, y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
  # 绘图部分 
  plt.figure(1, figsize=(10, 3)) 
  plt.subplot(131) 
  plt.title('net1') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
  # 保存神经网络 
  torch.save(net1, '7-net.pkl')           # 保存整个神经网络的结构和模型参数 
  torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神经网络的模型参数 
 
# 载入整个神经网络的结构及其模型参数 
def reload_net(): 
  net2 = torch.load('7-net.pkl') 
  prediction = net2(x) 
 
  plt.subplot(132) 
  plt.title('net2') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 只载入神经网络的模型参数,神经网络的结构需要与保存的神经网络相同的结构 
def reload_params(): 
  # 首先搭建相同的神经网络结构 
  net3 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
 
  # 载入神经网络的模型参数 
  net3.load_state_dict(torch.load('7-net_params.pkl')) 
  prediction = net3(x) 
 
  plt.subplot(133) 
  plt.title('net3') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 运行测试 
save() 
reload_net() 
reload_params()

实验结果:

PyTorch快速搭建神经网络及其保存提取方法详解

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python和Ruby中each循环引用变量问题(一个隐秘BUG?)
Jun 04 Python
Python扫描IP段查看指定端口是否开放的方法
Jun 09 Python
Python selenium文件上传方法汇总
Nov 19 Python
Win7 64位下python3.6.5安装配置图文教程
Oct 27 Python
python写入已存在的excel数据实例
May 03 Python
python实现遍历文件夹修改文件后缀
Aug 28 Python
python 执行文件时额外参数获取的实例
Dec 18 Python
详解pandas使用drop_duplicates去除DataFrame重复项参数
Aug 01 Python
Python matplotlib画曲线例题解析
Feb 07 Python
解决Keras 自定义层时遇到版本的问题
Jun 16 Python
Python过滤序列元素的方法
Jul 31 Python
Python echarts实现数据可视化实例详解
Mar 03 Python
对Python中type打开文件的方式介绍
Apr 28 #Python
PyTorch上搭建简单神经网络实现回归和分类的示例
Apr 28 #Python
TensorFlow实现非线性支持向量机的实现方法
Apr 28 #Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
You might like
一个取得文件扩展名的函数
2006/10/09 PHP
简单的PHP图片上传程序
2008/03/27 PHP
PHP中改变图片的尺寸大小的代码
2011/07/17 PHP
Python中使用django form表单验证的方法
2017/01/16 PHP
Laravel框架下载,安装及路由操作图文详解
2019/12/04 PHP
PHP执行系统命令函数实例讲解
2021/03/03 PHP
JS OOP包机制,类创建的方法定义
2009/11/02 Javascript
用jQuery打造TabPanel效果代码
2010/05/22 Javascript
JavaScript中几种常见排序算法小结
2011/02/22 Javascript
使用隐藏的new来创建对象
2011/03/29 Javascript
js淡入淡出的图片轮播效果代码分享
2015/08/24 Javascript
Angular.js跨controller实现参数传递的两种方法
2017/02/20 Javascript
80%应聘者都不及格的JS面试题
2017/03/21 Javascript
three.js中文文档学习之如何本地运行详解
2017/11/20 Javascript
[47:08]OG vs INfamous 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/17 DOTA
[01:00:13]完美世界DOTA2联赛 LBZS vs Forest 第一场 11.07
2020/11/09 DOTA
一张图带我们入门Python基础教程
2017/02/05 Python
python安装oracle扩展及数据库连接方法
2017/02/21 Python
Android 兼容性问题:java.lang.UnsupportedOperationException解决办法
2017/03/19 Python
Python根据指定日期计算后n天,前n天是哪一天的方法
2018/05/29 Python
Python利用递归实现文件的复制方法
2018/10/27 Python
python实现浪漫的烟花秀
2019/01/30 Python
Django自定义模板过滤器和标签的实现方法
2019/08/21 Python
Python decorator拦截器代码实例解析
2020/04/04 Python
Python HTMLTestRunner可视化报告实现过程解析
2020/04/10 Python
真正了解CSS3背景下的@font face规则
2017/05/04 HTML / CSS
里程积分管理买卖交换平台:Points.com
2017/01/13 全球购物
python re模块和正则表达式
2021/03/24 Python
JavaScript实现前端网页版倒计时
2021/03/24 Javascript
采购部部门职责
2013/12/15 职场文书
应届毕业生求职自荐书
2014/01/03 职场文书
委托书范本
2014/04/02 职场文书
自查自纠工作总结
2014/10/15 职场文书
出差报告格式模板
2014/11/06 职场文书
2015大学生求职信范文
2015/03/20 职场文书
MySQL的Query Cache图文详解
2021/07/01 MySQL