PyTorch快速搭建神经网络及其保存提取方法详解


Posted in Python onApril 28, 2018

有时候我们训练了一个模型, 希望保存它下次直接使用,不需要下次再花时间去训练 ,本节我们来讲解一下PyTorch快速搭建神经网络及其保存提取方法详解

一、PyTorch快速搭建神经网络方法

先看实验代码:

import torch 
import torch.nn.functional as F 
 
# 方法1,通过定义一个Net类来建立神经网络 
class Net(torch.nn.Module): 
  def __init__(self, n_feature, n_hidden, n_output): 
    super(Net, self).__init__() 
    self.hidden = torch.nn.Linear(n_feature, n_hidden) 
    self.predict = torch.nn.Linear(n_hidden, n_output) 
 
  def forward(self, x): 
    x = F.relu(self.hidden(x)) 
    x = self.predict(x) 
    return x 
 
net1 = Net(2, 10, 2) 
print('方法1:\n', net1) 
 
# 方法2 通过torch.nn.Sequential快速建立神经网络结构 
net2 = torch.nn.Sequential( 
  torch.nn.Linear(2, 10), 
  torch.nn.ReLU(), 
  torch.nn.Linear(10, 2), 
  ) 
print('方法2:\n', net2) 
# 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 
 
''''' 
方法1: 
 Net ( 
 (hidden): Linear (2 -> 10) 
 (predict): Linear (10 -> 2) 
) 
方法2: 
 Sequential ( 
 (0): Linear (2 -> 10) 
 (1): ReLU () 
 (2): Linear (10 -> 2) 
) 
'''

先前学习了通过定义一个Net类来构建神经网络的方法,classNet中首先通过super函数继承torch.nn.Module模块的构造方法,再通过添加属性的方式搭建神经网络各层的结构信息,在forward方法中完善神经网络各层之间的连接信息,然后再通过定义Net类对象的方式完成对神经网络结构的构建。

构建神经网络的另一个方法,也可以说是快速构建方法,就是通过torch.nn.Sequential,直接完成对神经网络的建立。

两种方法构建得到的神经网络结构完全相同,都可以通过print函数来打印输出网络信息,不过打印结果会有些许不同。

二、PyTorch的神经网络保存和提取

在学习和研究深度学习的时候,当我们通过一定时间的训练,得到了一个比较好的模型的时候,我们当然希望将这个模型及模型参数保存下来,以备后用,所以神经网络的保存和模型参数提取重载是很有必要的。

首先,我们需要在需要保存网路结构及其模型参数的神经网络的定义、训练部分之后通过torch.save()实现对网络结构和模型参数的保存。有两种保存方式:一是保存年整个神经网络的的结构信息和模型参数信息,save的对象是网络net;二是只保存神经网络的训练模型参数,save的对象是net.state_dict(),保存结果都以.pkl文件形式存储。

对应上面两种保存方式,重载方式也有两种。对应第一种完整网络结构信息,重载的时候通过torch.load(‘.pkl')直接初始化新的神经网络对象即可。对应第二种只保存模型参数信息,需要首先搭建相同的神经网络结构,通过net.load_state_dict(torch.load('.pkl'))完成模型参数的重载。在网络比较大的时候,第一种方法会花费较多的时间。

代码实现:

import torch 
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) # 设定随机数种子 
 
# 创建数据 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) 
y = x.pow(2) + 0.2*torch.rand(x.size()) 
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) 
 
# 将待保存的神经网络定义在一个函数中 
def save(): 
  # 神经网络结构 
  net1 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
  optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) 
  loss_function = torch.nn.MSELoss() 
 
  # 训练部分 
  for i in range(300): 
    prediction = net1(x) 
    loss = loss_function(prediction, y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
  # 绘图部分 
  plt.figure(1, figsize=(10, 3)) 
  plt.subplot(131) 
  plt.title('net1') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
  # 保存神经网络 
  torch.save(net1, '7-net.pkl')           # 保存整个神经网络的结构和模型参数 
  torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神经网络的模型参数 
 
# 载入整个神经网络的结构及其模型参数 
def reload_net(): 
  net2 = torch.load('7-net.pkl') 
  prediction = net2(x) 
 
  plt.subplot(132) 
  plt.title('net2') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 只载入神经网络的模型参数,神经网络的结构需要与保存的神经网络相同的结构 
def reload_params(): 
  # 首先搭建相同的神经网络结构 
  net3 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
 
  # 载入神经网络的模型参数 
  net3.load_state_dict(torch.load('7-net_params.pkl')) 
  prediction = net3(x) 
 
  plt.subplot(133) 
  plt.title('net3') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 运行测试 
save() 
reload_net() 
reload_params()

实验结果:

PyTorch快速搭建神经网络及其保存提取方法详解

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中的文本处理
Apr 11 Python
利用python模拟sql语句对员工表格进行增删改查
Jul 05 Python
python实时监控cpu小工具
Jun 21 Python
Django中数据库的数据关系:一对一,一对多,多对多
Oct 21 Python
只需7行Python代码玩转微信自动聊天
Jan 27 Python
python如何实现视频转代码视频
Jun 17 Python
Django中使用CORS实现跨域请求过程解析
Aug 05 Python
python3 下载网络图片代码实例
Aug 27 Python
Django Serializer HiddenField隐藏字段实例
Mar 31 Python
如何用python处理excel表格
Jun 09 Python
Python+OpenCV图像处理——图像二值化的实现
Oct 24 Python
python实战之一步一步教你绘制小猪佩奇
Apr 22 Python
对Python中type打开文件的方式介绍
Apr 28 #Python
PyTorch上搭建简单神经网络实现回归和分类的示例
Apr 28 #Python
TensorFlow实现非线性支持向量机的实现方法
Apr 28 #Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
You might like
php求正负数数组中连续元素最大值示例
2014/04/11 PHP
php实现文件下载实例分享
2014/06/02 PHP
ThinkPHP入口文件设置及相关注意事项分析
2014/12/05 PHP
PHP Imagick完美实现图片裁切、生成缩略图、添加水印
2016/02/22 PHP
php进行ip地址掩码运算处理的方法
2016/07/11 PHP
JScript中的"this"关键字使用方式补充材料
2007/03/08 Javascript
DIV菜单层实现代码
2010/11/19 Javascript
jQuery ui插件的使用方法代码实例
2013/05/08 Javascript
js中的this关键字详解
2013/09/25 Javascript
js实现的折叠导航示例
2013/11/29 Javascript
javaScript年份下拉列表框内容为当前年份及前后50年
2014/05/28 Javascript
学习Angular中作用域需要注意的坑
2016/08/17 Javascript
json数据处理及数据绑定
2017/01/25 Javascript
layui--select使用以及下拉框实现键盘选择的例子
2019/09/24 Javascript
ES5和ES6中类的区别总结
2020/12/21 Javascript
vue绑定class的三种方法
2020/12/24 Vue.js
[01:34]2016国际邀请赛中国区预选赛IG战队教练采访
2016/06/27 DOTA
Python实现按中文排序的方法示例
2018/04/25 Python
python字符串查找函数的用法详解
2019/07/08 Python
处理Selenium3+python3定位鼠标悬停才显示的元素
2019/07/31 Python
python运用pygame库实现双人弹球小游戏
2019/11/25 Python
获取python运行输出的数据并解析存为dataFrame实例
2020/07/07 Python
浅谈python锁与死锁问题
2020/08/14 Python
Django Auth用户认证组件实现代码
2020/10/13 Python
10 套华丽的CSS3 按钮小结
2012/10/03 HTML / CSS
美国鲜花递送:UrbanStems
2021/01/04 全球购物
产品促销活动策划书
2014/01/15 职场文书
中国文明网签名寄语
2014/01/18 职场文书
个人公开承诺书
2014/03/28 职场文书
市场营销战略计划书
2014/05/06 职场文书
人事主管岗位职责说明书
2014/07/30 职场文书
看上去很美观后感
2015/06/10 职场文书
有关三国演义的读书笔记
2015/06/25 职场文书
《山中访友》教学反思
2016/02/24 职场文书
如何利用Python实现一个论文降重工具
2021/07/09 Python
使用Nginx的访问日志统计PV与UV
2022/05/06 Servers