PyTorch快速搭建神经网络及其保存提取方法详解


Posted in Python onApril 28, 2018

有时候我们训练了一个模型, 希望保存它下次直接使用,不需要下次再花时间去训练 ,本节我们来讲解一下PyTorch快速搭建神经网络及其保存提取方法详解

一、PyTorch快速搭建神经网络方法

先看实验代码:

import torch 
import torch.nn.functional as F 
 
# 方法1,通过定义一个Net类来建立神经网络 
class Net(torch.nn.Module): 
  def __init__(self, n_feature, n_hidden, n_output): 
    super(Net, self).__init__() 
    self.hidden = torch.nn.Linear(n_feature, n_hidden) 
    self.predict = torch.nn.Linear(n_hidden, n_output) 
 
  def forward(self, x): 
    x = F.relu(self.hidden(x)) 
    x = self.predict(x) 
    return x 
 
net1 = Net(2, 10, 2) 
print('方法1:\n', net1) 
 
# 方法2 通过torch.nn.Sequential快速建立神经网络结构 
net2 = torch.nn.Sequential( 
  torch.nn.Linear(2, 10), 
  torch.nn.ReLU(), 
  torch.nn.Linear(10, 2), 
  ) 
print('方法2:\n', net2) 
# 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 
 
''''' 
方法1: 
 Net ( 
 (hidden): Linear (2 -> 10) 
 (predict): Linear (10 -> 2) 
) 
方法2: 
 Sequential ( 
 (0): Linear (2 -> 10) 
 (1): ReLU () 
 (2): Linear (10 -> 2) 
) 
'''

先前学习了通过定义一个Net类来构建神经网络的方法,classNet中首先通过super函数继承torch.nn.Module模块的构造方法,再通过添加属性的方式搭建神经网络各层的结构信息,在forward方法中完善神经网络各层之间的连接信息,然后再通过定义Net类对象的方式完成对神经网络结构的构建。

构建神经网络的另一个方法,也可以说是快速构建方法,就是通过torch.nn.Sequential,直接完成对神经网络的建立。

两种方法构建得到的神经网络结构完全相同,都可以通过print函数来打印输出网络信息,不过打印结果会有些许不同。

二、PyTorch的神经网络保存和提取

在学习和研究深度学习的时候,当我们通过一定时间的训练,得到了一个比较好的模型的时候,我们当然希望将这个模型及模型参数保存下来,以备后用,所以神经网络的保存和模型参数提取重载是很有必要的。

首先,我们需要在需要保存网路结构及其模型参数的神经网络的定义、训练部分之后通过torch.save()实现对网络结构和模型参数的保存。有两种保存方式:一是保存年整个神经网络的的结构信息和模型参数信息,save的对象是网络net;二是只保存神经网络的训练模型参数,save的对象是net.state_dict(),保存结果都以.pkl文件形式存储。

对应上面两种保存方式,重载方式也有两种。对应第一种完整网络结构信息,重载的时候通过torch.load(‘.pkl')直接初始化新的神经网络对象即可。对应第二种只保存模型参数信息,需要首先搭建相同的神经网络结构,通过net.load_state_dict(torch.load('.pkl'))完成模型参数的重载。在网络比较大的时候,第一种方法会花费较多的时间。

代码实现:

import torch 
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) # 设定随机数种子 
 
# 创建数据 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) 
y = x.pow(2) + 0.2*torch.rand(x.size()) 
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) 
 
# 将待保存的神经网络定义在一个函数中 
def save(): 
  # 神经网络结构 
  net1 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
  optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) 
  loss_function = torch.nn.MSELoss() 
 
  # 训练部分 
  for i in range(300): 
    prediction = net1(x) 
    loss = loss_function(prediction, y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
  # 绘图部分 
  plt.figure(1, figsize=(10, 3)) 
  plt.subplot(131) 
  plt.title('net1') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
  # 保存神经网络 
  torch.save(net1, '7-net.pkl')           # 保存整个神经网络的结构和模型参数 
  torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神经网络的模型参数 
 
# 载入整个神经网络的结构及其模型参数 
def reload_net(): 
  net2 = torch.load('7-net.pkl') 
  prediction = net2(x) 
 
  plt.subplot(132) 
  plt.title('net2') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 只载入神经网络的模型参数,神经网络的结构需要与保存的神经网络相同的结构 
def reload_params(): 
  # 首先搭建相同的神经网络结构 
  net3 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
 
  # 载入神经网络的模型参数 
  net3.load_state_dict(torch.load('7-net_params.pkl')) 
  prediction = net3(x) 
 
  plt.subplot(133) 
  plt.title('net3') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 运行测试 
save() 
reload_net() 
reload_params()

实验结果:

PyTorch快速搭建神经网络及其保存提取方法详解

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详细讲解用Python发送SMTP邮件的教程
Apr 29 Python
全面解析Python的While循环语句的使用方法
Oct 13 Python
python3 对list中每个元素进行处理的方法
Jun 29 Python
Django中日期处理注意事项与自定义时间格式转换详解
Aug 06 Python
Python实现的逻辑回归算法示例【附测试csv文件下载】
Dec 28 Python
十行代码使用Python写一个USB病毒
Jun 21 Python
python基于property()函数定义属性
Jan 22 Python
Python实现自动签到脚本的示例代码
Aug 19 Python
Pycharm如何自动生成头文件注释
Nov 14 Python
关于Python 解决Python3.9 pandas.read_excel(‘xxx.xlsx‘)报错的问题
Nov 28 Python
Python3使用tesserocr识别字母数字验证码的实现
Jan 29 Python
Python 如何实现文件自动去重
Jun 02 Python
对Python中type打开文件的方式介绍
Apr 28 #Python
PyTorch上搭建简单神经网络实现回归和分类的示例
Apr 28 #Python
TensorFlow实现非线性支持向量机的实现方法
Apr 28 #Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
You might like
PHP面向对象继承用法详解(优化与减少代码重复)
2016/12/02 PHP
PHP实现接收二进制流转换成图片的方法
2017/01/10 PHP
laravel框架中路由设置,路由参数和路由命名实例分析
2019/11/23 PHP
js focus不起作用的解决方法(主要是因为dom元素是否加载完成)
2010/11/05 Javascript
jQuery Ajax方法调用 Asp.Net WebService 的详细实例代码
2011/04/27 Javascript
根据表格中的某一列进行排序的javascript代码
2013/11/29 Javascript
自己封装的javascript事件队列函数版
2014/06/12 Javascript
JsRender实用入门教程
2014/10/31 Javascript
nodejs中简单实现Javascript Promise机制的实例
2014/12/06 NodeJs
node.js中的fs.lchmodSync方法使用说明
2014/12/16 Javascript
jQuery动态星级评分效果实现方法
2015/08/06 Javascript
非常实用的js验证框架实现源码 附原理方法
2016/06/08 Javascript
VUEJS实战之利用laypage插件实现分页(3)
2016/06/13 Javascript
D3.js实现雷达图的方法详解
2016/09/22 Javascript
浅析js的模块化编写 require.js
2016/12/07 Javascript
JavaScript与JQUERY获取元素的宽、高和位置
2017/02/26 Javascript
JavaScript和JQuery获取DIV值的方法示例
2017/03/07 Javascript
javascript实现二叉树遍历的代码
2017/06/08 Javascript
Node.js学习之查询字符串解析querystring详解
2017/09/28 Javascript
js实现京东秒杀倒计时功能
2019/01/21 Javascript
通过实例了解Render Props回调地狱解决方案
2020/11/04 Javascript
Python开发常用的一些开源Package分享
2015/02/14 Python
Python字符串和文件操作常用函数分析
2015/04/08 Python
numpy中的delete删除数组整行和整列的实例
2018/05/09 Python
Python爬虫实现简单的爬取有道翻译功能示例
2018/07/13 Python
浅谈Python3实现两个矩形的交并比(IoU)
2020/01/18 Python
Python爬虫Scrapy框架CrawlSpider原理及使用案例
2020/11/20 Python
python基于openpyxl生成excel文件
2020/12/23 Python
HTML5 Canvas概述
2009/08/26 HTML / CSS
YSL Beauty加拿大官方商城:圣罗兰美妆加拿大
2017/05/15 全球购物
电子信息毕业生自荐信
2013/11/16 职场文书
精通CAD能手自荐书
2014/01/31 职场文书
2014年单位法制宣传日活动总结
2014/11/01 职场文书
《春酒》教学反思
2016/02/22 职场文书
Laravel中获取IP的真实地理位置
2021/04/01 PHP
php将xml转化对象的实例详解
2021/11/17 PHP