PyTorch快速搭建神经网络及其保存提取方法详解


Posted in Python onApril 28, 2018

有时候我们训练了一个模型, 希望保存它下次直接使用,不需要下次再花时间去训练 ,本节我们来讲解一下PyTorch快速搭建神经网络及其保存提取方法详解

一、PyTorch快速搭建神经网络方法

先看实验代码:

import torch 
import torch.nn.functional as F 
 
# 方法1,通过定义一个Net类来建立神经网络 
class Net(torch.nn.Module): 
  def __init__(self, n_feature, n_hidden, n_output): 
    super(Net, self).__init__() 
    self.hidden = torch.nn.Linear(n_feature, n_hidden) 
    self.predict = torch.nn.Linear(n_hidden, n_output) 
 
  def forward(self, x): 
    x = F.relu(self.hidden(x)) 
    x = self.predict(x) 
    return x 
 
net1 = Net(2, 10, 2) 
print('方法1:\n', net1) 
 
# 方法2 通过torch.nn.Sequential快速建立神经网络结构 
net2 = torch.nn.Sequential( 
  torch.nn.Linear(2, 10), 
  torch.nn.ReLU(), 
  torch.nn.Linear(10, 2), 
  ) 
print('方法2:\n', net2) 
# 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 
 
''''' 
方法1: 
 Net ( 
 (hidden): Linear (2 -> 10) 
 (predict): Linear (10 -> 2) 
) 
方法2: 
 Sequential ( 
 (0): Linear (2 -> 10) 
 (1): ReLU () 
 (2): Linear (10 -> 2) 
) 
'''

先前学习了通过定义一个Net类来构建神经网络的方法,classNet中首先通过super函数继承torch.nn.Module模块的构造方法,再通过添加属性的方式搭建神经网络各层的结构信息,在forward方法中完善神经网络各层之间的连接信息,然后再通过定义Net类对象的方式完成对神经网络结构的构建。

构建神经网络的另一个方法,也可以说是快速构建方法,就是通过torch.nn.Sequential,直接完成对神经网络的建立。

两种方法构建得到的神经网络结构完全相同,都可以通过print函数来打印输出网络信息,不过打印结果会有些许不同。

二、PyTorch的神经网络保存和提取

在学习和研究深度学习的时候,当我们通过一定时间的训练,得到了一个比较好的模型的时候,我们当然希望将这个模型及模型参数保存下来,以备后用,所以神经网络的保存和模型参数提取重载是很有必要的。

首先,我们需要在需要保存网路结构及其模型参数的神经网络的定义、训练部分之后通过torch.save()实现对网络结构和模型参数的保存。有两种保存方式:一是保存年整个神经网络的的结构信息和模型参数信息,save的对象是网络net;二是只保存神经网络的训练模型参数,save的对象是net.state_dict(),保存结果都以.pkl文件形式存储。

对应上面两种保存方式,重载方式也有两种。对应第一种完整网络结构信息,重载的时候通过torch.load(‘.pkl')直接初始化新的神经网络对象即可。对应第二种只保存模型参数信息,需要首先搭建相同的神经网络结构,通过net.load_state_dict(torch.load('.pkl'))完成模型参数的重载。在网络比较大的时候,第一种方法会花费较多的时间。

代码实现:

import torch 
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) # 设定随机数种子 
 
# 创建数据 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) 
y = x.pow(2) + 0.2*torch.rand(x.size()) 
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) 
 
# 将待保存的神经网络定义在一个函数中 
def save(): 
  # 神经网络结构 
  net1 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
  optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) 
  loss_function = torch.nn.MSELoss() 
 
  # 训练部分 
  for i in range(300): 
    prediction = net1(x) 
    loss = loss_function(prediction, y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
  # 绘图部分 
  plt.figure(1, figsize=(10, 3)) 
  plt.subplot(131) 
  plt.title('net1') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
  # 保存神经网络 
  torch.save(net1, '7-net.pkl')           # 保存整个神经网络的结构和模型参数 
  torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神经网络的模型参数 
 
# 载入整个神经网络的结构及其模型参数 
def reload_net(): 
  net2 = torch.load('7-net.pkl') 
  prediction = net2(x) 
 
  plt.subplot(132) 
  plt.title('net2') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 只载入神经网络的模型参数,神经网络的结构需要与保存的神经网络相同的结构 
def reload_params(): 
  # 首先搭建相同的神经网络结构 
  net3 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
 
  # 载入神经网络的模型参数 
  net3.load_state_dict(torch.load('7-net_params.pkl')) 
  prediction = net3(x) 
 
  plt.subplot(133) 
  plt.title('net3') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 运行测试 
save() 
reload_net() 
reload_params()

实验结果:

PyTorch快速搭建神经网络及其保存提取方法详解

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之复习if语句
Oct 02 Python
Python入门篇之面向对象
Oct 20 Python
python读写二进制文件的方法
May 09 Python
python检测某个变量是否有定义的方法
May 20 Python
Python实现正弦信号的时域波形和频谱图示例【基于matplotlib】
May 04 Python
Python中函数参数调用方式分析
Aug 09 Python
对Django url的几种使用方式详解
Aug 06 Python
Django使用uwsgi部署时的配置以及django日志文件的处理方法
Aug 30 Python
python实现网站微信登录的示例代码
Sep 18 Python
ubuntu上安装python的实例方法
Sep 30 Python
Anaconda和ipython环境适配的实现
Apr 22 Python
MoviePy常用剪辑类及Python视频剪辑自动化
Dec 18 Python
对Python中type打开文件的方式介绍
Apr 28 #Python
PyTorch上搭建简单神经网络实现回归和分类的示例
Apr 28 #Python
TensorFlow实现非线性支持向量机的实现方法
Apr 28 #Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
You might like
PHP与MySQL开发中页面乱码的产生与解决
2008/03/27 PHP
重新封装zend_soap实现http连接安全认证的php代码
2011/01/12 PHP
php更新mysql后获取影响的行数发生异常解决方法
2013/03/28 PHP
php实现的日历程序
2015/06/18 PHP
php使用Header函数,PHP_AUTH_PW和PHP_AUTH_USER做用户验证
2016/05/04 PHP
PHP面向对象程序设计重载(overloading)操作详解
2019/06/13 PHP
Javascript技术技巧大全(五)
2007/01/22 Javascript
三种取消选中单选框radio的方法
2014/09/09 Javascript
Nodejs中读取中文文件编码问题、发送邮件和定时任务实例
2015/01/01 NodeJs
Prototype框架详解
2015/11/25 Javascript
JavaScript实现无刷新上传预览图片功能
2017/08/02 Javascript
jQuery实现的页面详情展开收起功能示例
2018/06/11 jQuery
使用vue-router为每个路由配置各自的title
2018/07/30 Javascript
在axios中使用params传参的时候传入数组的方法
2018/09/25 Javascript
深入理解Angularjs 脏值检测
2018/10/12 Javascript
微信小程序全选多选效果实现代码解析
2020/01/21 Javascript
[01:06:42]VP vs NewBee Supermajor 胜者组 BO3 第二场 6.5
2018/06/06 DOTA
python中实现php的var_dump函数功能
2015/01/21 Python
使用graphics.py实现2048小游戏
2015/03/10 Python
操作Windows注册表的简单的Python程序制作教程
2015/04/07 Python
python实现中文转换url编码的方法
2016/06/14 Python
深入理解Python中变量赋值的问题
2017/01/12 Python
Python3.5 Json与pickle实现数据序列化与反序列化操作示例
2019/04/29 Python
浅谈Python编程中3个常用的数据结构和算法
2019/04/30 Python
Python绘制堆叠柱状图的实例
2019/07/09 Python
学python安装的软件总结
2019/10/12 Python
PYQT5 vscode联合操作qtdesigner的方法
2020/03/24 Python
python支持多继承吗
2020/06/19 Python
安装不同版本的tensorflow与models方法实现
2021/02/20 Python
用纯CSS3实现网页中常见的小箭头
2017/10/16 HTML / CSS
html5中的input新属性range使用记录
2014/09/05 HTML / CSS
法国发饰品牌:Alexandre De Paris
2018/12/04 全球购物
编程实现去掉XML的重复结点
2014/05/28 面试题
法学专业毕业生求职信
2014/06/12 职场文书
导游词之贵州百里杜鹃
2019/10/29 职场文书
教你怎么用Python生成九宫格照片
2021/05/20 Python