PyTorch快速搭建神经网络及其保存提取方法详解


Posted in Python onApril 28, 2018

有时候我们训练了一个模型, 希望保存它下次直接使用,不需要下次再花时间去训练 ,本节我们来讲解一下PyTorch快速搭建神经网络及其保存提取方法详解

一、PyTorch快速搭建神经网络方法

先看实验代码:

import torch 
import torch.nn.functional as F 
 
# 方法1,通过定义一个Net类来建立神经网络 
class Net(torch.nn.Module): 
  def __init__(self, n_feature, n_hidden, n_output): 
    super(Net, self).__init__() 
    self.hidden = torch.nn.Linear(n_feature, n_hidden) 
    self.predict = torch.nn.Linear(n_hidden, n_output) 
 
  def forward(self, x): 
    x = F.relu(self.hidden(x)) 
    x = self.predict(x) 
    return x 
 
net1 = Net(2, 10, 2) 
print('方法1:\n', net1) 
 
# 方法2 通过torch.nn.Sequential快速建立神经网络结构 
net2 = torch.nn.Sequential( 
  torch.nn.Linear(2, 10), 
  torch.nn.ReLU(), 
  torch.nn.Linear(10, 2), 
  ) 
print('方法2:\n', net2) 
# 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 
 
''''' 
方法1: 
 Net ( 
 (hidden): Linear (2 -> 10) 
 (predict): Linear (10 -> 2) 
) 
方法2: 
 Sequential ( 
 (0): Linear (2 -> 10) 
 (1): ReLU () 
 (2): Linear (10 -> 2) 
) 
'''

先前学习了通过定义一个Net类来构建神经网络的方法,classNet中首先通过super函数继承torch.nn.Module模块的构造方法,再通过添加属性的方式搭建神经网络各层的结构信息,在forward方法中完善神经网络各层之间的连接信息,然后再通过定义Net类对象的方式完成对神经网络结构的构建。

构建神经网络的另一个方法,也可以说是快速构建方法,就是通过torch.nn.Sequential,直接完成对神经网络的建立。

两种方法构建得到的神经网络结构完全相同,都可以通过print函数来打印输出网络信息,不过打印结果会有些许不同。

二、PyTorch的神经网络保存和提取

在学习和研究深度学习的时候,当我们通过一定时间的训练,得到了一个比较好的模型的时候,我们当然希望将这个模型及模型参数保存下来,以备后用,所以神经网络的保存和模型参数提取重载是很有必要的。

首先,我们需要在需要保存网路结构及其模型参数的神经网络的定义、训练部分之后通过torch.save()实现对网络结构和模型参数的保存。有两种保存方式:一是保存年整个神经网络的的结构信息和模型参数信息,save的对象是网络net;二是只保存神经网络的训练模型参数,save的对象是net.state_dict(),保存结果都以.pkl文件形式存储。

对应上面两种保存方式,重载方式也有两种。对应第一种完整网络结构信息,重载的时候通过torch.load(‘.pkl')直接初始化新的神经网络对象即可。对应第二种只保存模型参数信息,需要首先搭建相同的神经网络结构,通过net.load_state_dict(torch.load('.pkl'))完成模型参数的重载。在网络比较大的时候,第一种方法会花费较多的时间。

代码实现:

import torch 
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) # 设定随机数种子 
 
# 创建数据 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) 
y = x.pow(2) + 0.2*torch.rand(x.size()) 
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) 
 
# 将待保存的神经网络定义在一个函数中 
def save(): 
  # 神经网络结构 
  net1 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
  optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) 
  loss_function = torch.nn.MSELoss() 
 
  # 训练部分 
  for i in range(300): 
    prediction = net1(x) 
    loss = loss_function(prediction, y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
  # 绘图部分 
  plt.figure(1, figsize=(10, 3)) 
  plt.subplot(131) 
  plt.title('net1') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
  # 保存神经网络 
  torch.save(net1, '7-net.pkl')           # 保存整个神经网络的结构和模型参数 
  torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神经网络的模型参数 
 
# 载入整个神经网络的结构及其模型参数 
def reload_net(): 
  net2 = torch.load('7-net.pkl') 
  prediction = net2(x) 
 
  plt.subplot(132) 
  plt.title('net2') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 只载入神经网络的模型参数,神经网络的结构需要与保存的神经网络相同的结构 
def reload_params(): 
  # 首先搭建相同的神经网络结构 
  net3 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
 
  # 载入神经网络的模型参数 
  net3.load_state_dict(torch.load('7-net_params.pkl')) 
  prediction = net3(x) 
 
  plt.subplot(133) 
  plt.title('net3') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 运行测试 
save() 
reload_net() 
reload_params()

实验结果:

PyTorch快速搭建神经网络及其保存提取方法详解

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈Python实现贪心算法与活动安排问题
Dec 19 Python
浅谈Python中的作用域规则和闭包
Mar 20 Python
python实现简单登陆流程的方法
Apr 22 Python
基于MTCNN/TensorFlow实现人脸检测
May 24 Python
pyQt4实现俄罗斯方块游戏
Jun 26 Python
python 处理string到hex脚本的方法
Oct 26 Python
Python二叉树的遍历操作示例【前序遍历,中序遍历,后序遍历,层序遍历】
Dec 24 Python
对python中xlsx,csv以及json文件的相互转化方法详解
Dec 25 Python
对python tkinter窗口弹出置顶的方法详解
Jun 14 Python
keras处理欠拟合和过拟合的实例讲解
May 25 Python
python在linux环境下安装skimage的示例代码
Oct 14 Python
pytorch中index_select()的用法详解
Jan 06 Python
对Python中type打开文件的方式介绍
Apr 28 #Python
PyTorch上搭建简单神经网络实现回归和分类的示例
Apr 28 #Python
TensorFlow实现非线性支持向量机的实现方法
Apr 28 #Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
You might like
php 清除网页病毒的方法
2008/12/05 PHP
PHP获取163、gmail、126等邮箱联系人地址【已测试2009.10.10】
2009/10/11 PHP
使用Huagepage和PGO来提升PHP7的执行性能
2015/11/30 PHP
php+jQuery+Ajax简单实现页面异步刷新
2016/08/08 PHP
php file_get_contents取文件中数组元素的方法
2017/04/01 PHP
PHP用PDO如何封装简单易用的DB类详解
2017/07/30 PHP
php进程daemon化的正确实现方法
2018/09/06 PHP
一个js写的日历(代码部分网摘)
2009/09/20 Javascript
基于jquery的direction图片渐变动画效果
2010/05/24 Javascript
基于jquery中children()与find()的区别介绍
2013/04/26 Javascript
javascript面向对象快速入门实例
2015/01/13 Javascript
js实现文件上传表单域美化特效
2015/11/02 Javascript
js实现数组冒泡排序、快速排序原理
2016/03/08 Javascript
jQuery中常用动画效果函数(日常整理)
2016/09/17 Javascript
微信小程序 教程之条件渲染
2016/10/18 Javascript
jQuery层级选择器实例代码
2017/02/06 Javascript
初探JavaScript 面向对象(推荐)
2017/09/03 Javascript
JS库之Three.js 简易入门教程(详解之一)
2017/09/13 Javascript
详解webpack loader和plugin编写
2018/10/12 Javascript
JS实现简单的抽奖转盘效果示例
2019/02/16 Javascript
vue中img src 动态加载本地json的图片路径写法
2019/04/25 Javascript
python self,cls,decorator的理解
2009/07/13 Python
Python基于matplotlib实现绘制三维图形功能示例
2018/01/18 Python
Python求解任意闭区间的所有素数
2018/06/10 Python
在python中实现对list求和及求积
2018/11/14 Python
浅谈python函数调用返回两个或多个变量的方法
2019/01/23 Python
python向图片里添加文字
2019/11/26 Python
Python+OpenCV实现将图像转换为二进制格式
2020/01/09 Python
Pytorch中的VGG实现修改最后一层FC
2020/01/15 Python
python爬虫使用requests发送post请求示例详解
2020/08/05 Python
PyCharm上安装Package的实现(以pandas为例)
2020/09/18 Python
HTML5制作酷炫音频播放器插件图文教程
2014/12/30 HTML / CSS
会计工作能力自我评价
2015/03/05 职场文书
初中班主任培训心得体会
2016/01/07 职场文书
Java实现二维数组和稀疏数组之间的转换
2021/06/27 Java/Android
纯 CSS 自定义多行省略的问题(从原理到实现)
2021/11/11 HTML / CSS