PyTorch快速搭建神经网络及其保存提取方法详解


Posted in Python onApril 28, 2018

有时候我们训练了一个模型, 希望保存它下次直接使用,不需要下次再花时间去训练 ,本节我们来讲解一下PyTorch快速搭建神经网络及其保存提取方法详解

一、PyTorch快速搭建神经网络方法

先看实验代码:

import torch 
import torch.nn.functional as F 
 
# 方法1,通过定义一个Net类来建立神经网络 
class Net(torch.nn.Module): 
  def __init__(self, n_feature, n_hidden, n_output): 
    super(Net, self).__init__() 
    self.hidden = torch.nn.Linear(n_feature, n_hidden) 
    self.predict = torch.nn.Linear(n_hidden, n_output) 
 
  def forward(self, x): 
    x = F.relu(self.hidden(x)) 
    x = self.predict(x) 
    return x 
 
net1 = Net(2, 10, 2) 
print('方法1:\n', net1) 
 
# 方法2 通过torch.nn.Sequential快速建立神经网络结构 
net2 = torch.nn.Sequential( 
  torch.nn.Linear(2, 10), 
  torch.nn.ReLU(), 
  torch.nn.Linear(10, 2), 
  ) 
print('方法2:\n', net2) 
# 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 
 
''''' 
方法1: 
 Net ( 
 (hidden): Linear (2 -> 10) 
 (predict): Linear (10 -> 2) 
) 
方法2: 
 Sequential ( 
 (0): Linear (2 -> 10) 
 (1): ReLU () 
 (2): Linear (10 -> 2) 
) 
'''

先前学习了通过定义一个Net类来构建神经网络的方法,classNet中首先通过super函数继承torch.nn.Module模块的构造方法,再通过添加属性的方式搭建神经网络各层的结构信息,在forward方法中完善神经网络各层之间的连接信息,然后再通过定义Net类对象的方式完成对神经网络结构的构建。

构建神经网络的另一个方法,也可以说是快速构建方法,就是通过torch.nn.Sequential,直接完成对神经网络的建立。

两种方法构建得到的神经网络结构完全相同,都可以通过print函数来打印输出网络信息,不过打印结果会有些许不同。

二、PyTorch的神经网络保存和提取

在学习和研究深度学习的时候,当我们通过一定时间的训练,得到了一个比较好的模型的时候,我们当然希望将这个模型及模型参数保存下来,以备后用,所以神经网络的保存和模型参数提取重载是很有必要的。

首先,我们需要在需要保存网路结构及其模型参数的神经网络的定义、训练部分之后通过torch.save()实现对网络结构和模型参数的保存。有两种保存方式:一是保存年整个神经网络的的结构信息和模型参数信息,save的对象是网络net;二是只保存神经网络的训练模型参数,save的对象是net.state_dict(),保存结果都以.pkl文件形式存储。

对应上面两种保存方式,重载方式也有两种。对应第一种完整网络结构信息,重载的时候通过torch.load(‘.pkl')直接初始化新的神经网络对象即可。对应第二种只保存模型参数信息,需要首先搭建相同的神经网络结构,通过net.load_state_dict(torch.load('.pkl'))完成模型参数的重载。在网络比较大的时候,第一种方法会花费较多的时间。

代码实现:

import torch 
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) # 设定随机数种子 
 
# 创建数据 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) 
y = x.pow(2) + 0.2*torch.rand(x.size()) 
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) 
 
# 将待保存的神经网络定义在一个函数中 
def save(): 
  # 神经网络结构 
  net1 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
  optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) 
  loss_function = torch.nn.MSELoss() 
 
  # 训练部分 
  for i in range(300): 
    prediction = net1(x) 
    loss = loss_function(prediction, y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
  # 绘图部分 
  plt.figure(1, figsize=(10, 3)) 
  plt.subplot(131) 
  plt.title('net1') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
  # 保存神经网络 
  torch.save(net1, '7-net.pkl')           # 保存整个神经网络的结构和模型参数 
  torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神经网络的模型参数 
 
# 载入整个神经网络的结构及其模型参数 
def reload_net(): 
  net2 = torch.load('7-net.pkl') 
  prediction = net2(x) 
 
  plt.subplot(132) 
  plt.title('net2') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 只载入神经网络的模型参数,神经网络的结构需要与保存的神经网络相同的结构 
def reload_params(): 
  # 首先搭建相同的神经网络结构 
  net3 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
 
  # 载入神经网络的模型参数 
  net3.load_state_dict(torch.load('7-net_params.pkl')) 
  prediction = net3(x) 
 
  plt.subplot(133) 
  plt.title('net3') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 运行测试 
save() 
reload_net() 
reload_params()

实验结果:

PyTorch快速搭建神经网络及其保存提取方法详解

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中dir函数用法分析
Apr 17 Python
Python实现统计英文单词个数及字符串分割代码
May 28 Python
TensorFlow打印tensor值的实现方法
Jul 27 Python
对python3新增的byte类型详解
Dec 04 Python
利用Python半自动化生成Nessus报告的方法
Mar 19 Python
python实现登录密码重置简易操作代码
Aug 14 Python
Pandas操作CSV文件的读写实现方法
Nov 13 Python
python使用信号量动态更新配置文件的操作
Apr 01 Python
pip安装tensorflow的坑的解决
Apr 19 Python
python 一维二维插值实例
Apr 22 Python
Python timeit模块原理及使用方法
Oct 10 Python
python中Pexpect的工作流程实例讲解
Mar 02 Python
对Python中type打开文件的方式介绍
Apr 28 #Python
PyTorch上搭建简单神经网络实现回归和分类的示例
Apr 28 #Python
TensorFlow实现非线性支持向量机的实现方法
Apr 28 #Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
You might like
火影忍者:这才是千手柱间和扉间的真正死因,角都就比较搞笑了!
2020/03/10 日漫
PHP远程连接MYSQL数据库非常慢的解决方法
2008/07/05 PHP
PHP编程风格规范分享
2014/01/15 PHP
Smarty中的注释和截断功能介绍
2015/04/09 PHP
PHP的自定义模板引擎
2017/03/24 PHP
jquery跨域请求示例分享(jquery发送ajax请求)
2014/03/25 Javascript
jQuery中outerWidth()方法用法实例
2015/01/19 Javascript
node.js微信公众平台开发教程
2016/03/04 Javascript
JS中的二叉树遍历详解
2016/03/18 Javascript
JSP基于Bootstrap分页显示实例解析
2016/06/12 Javascript
省市选择的简单实现(基于zepto.js)
2016/06/21 Javascript
简单实现轮播图效果的实例
2016/07/15 Javascript
基于vue2.0+vuex+localStorage开发的本地记事本示例
2017/02/28 Javascript
Cropper.js 实现裁剪图片并上传(PC端)
2017/08/20 Javascript
浅谈Angular路由守卫
2017/08/26 Javascript
微信小程序中button组件的边框设置的实例详解
2017/09/27 Javascript
原生JS实现随机点名项目的实例代码
2019/04/30 Javascript
带你使用webpack快速构建web项目的方法
2020/11/12 Javascript
JavaScript实现浏览器网页自动滚动并点击的示例代码
2020/12/05 Javascript
Python对列表排序的方法实例分析
2015/05/16 Python
搭建Python的Django框架环境并建立和运行第一个App的教程
2016/07/02 Python
不知道这5种下划线的含义,你就不算真的会Python!
2018/10/09 Python
python异常处理、自定义异常、断言原理与用法分析
2020/03/23 Python
VSCode基础使用与VSCode调试python程序入门的图文教程
2020/03/30 Python
scrapy框架携带cookie访问淘宝购物车功能的实现代码
2020/07/07 Python
使用CSS3编写灰阶滤镜来制作黑白照片效果的方法
2016/05/09 HTML / CSS
移动通信行业实习自我鉴定
2013/09/28 职场文书
师范毕业生自我鉴定
2014/01/15 职场文书
高中生期末评语大全
2014/01/28 职场文书
工作会议主持词
2014/03/17 职场文书
租赁意向书范本
2014/04/01 职场文书
2014年小班保育员工作总结
2014/12/23 职场文书
精神文明建设先进个人事迹材料
2014/12/24 职场文书
企业开业庆典答谢词
2015/01/20 职场文书
创卫工作总结2015
2015/04/22 职场文书
八年级作文之友谊
2019/12/02 职场文书