pytorch训练神经网络爆内存的解决方案


Posted in Python onMay 22, 2021

训练的时候内存一直在增加,最后内存爆满,被迫中断。

pytorch训练神经网络爆内存的解决方案

后来换了一个电脑发现还是这样,考虑是代码的问题。

检查才发现我的代码两次存了loss,只有一个地方写的是loss.item()。问题就在loss,因为loss是variable类型。

要写成loss_train = loss_train + loss.item(),不能直接写loss_train = loss_train + loss。否则就会发现随着epoch的增加,占的内存也在一点一点增加。

算是一个小坑吧,希望大家还是要仔细。

补充:pytorch神经网络解决回归问题(非常易懂)

对于pytorch的深度学习框架

在建立人工神经网络时整体的步骤主要有以下四步:

1、载入原始数据

2、构建具体神经网络

3、进行数据的训练

4、数据测试和验证

pytorch神经网络的数据载入,以MINIST书写字体的原始数据为例:

import torch
import matplotlib.pyplot as  plt
def plot_curve(data):
    fig=plt.figure()
    plt.plot(range(len(data)),data,color="blue")
    plt.legend(["value"],loc="upper right")
    plt.xlabel("step")
    plt.ylabel("value")
    plt.show()
 
def plot_image(img,label,name):
    fig=plt.figure()
    for i in range(6):
        plt.subplot(2,3,i+1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307,cmap="gray",interpolation="none")
        plt.title("{}:{}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()
def one_hot(label,depth=10):
    out=torch.zeros(label.size(0),depth)
    idx=torch.LongTensor(label).view(-1,1)
    out.scatter_(dim=1,index=idx,value=1)
    return out
 
batch_size=512
import torch
from torch import nn                         #完成神经网络的构建包
from torch.nn import functional as F         #包含常用的函数包
from torch import optim                      #优化工具包
import torchvision                           #视觉工具包
import  matplotlib.pyplot as plt
from utils import plot_curve,plot_image,one_hot
#step1 load dataset   加载数据包
train_loader=torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("minist_data",train=True,download=True,transform=torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
         ])),
    batch_size=batch_size,shuffle=True)
test_loader=torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("minist_data",train=True,download=False,transform=torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
         ])),
    batch_size=batch_size,shuffle=False)
x,y=next(iter(train_loader))
print(x.shape,y.shape)
plot_image(x,y,"image")
print(x)
print(y)

以构建一个简单的回归问题的神经网络为例,

其具体的实现代码如下所示:

import torch
import torch.nn.functional as F  # 激励函数都在这
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)
 
class Net(torch.nn.Module):  # 继承 torch 的 Module(固定)
    def __init__(self, n_feature, n_hidden, n_output):  # 定义层的信息,n_feature多少个输入, n_hidden每层神经元, n_output多少个输出
        super(Net, self).__init__()  # 继承 __init__ 功能(固定)
        # 定义每层用什么样的形式
        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # 定义隐藏层,线性输出
        self.predict = torch.nn.Linear(n_hidden, n_output)  # 定义输出层线性输出
 
    def forward(self, x):  # x是输入信息就是data,同时也是 Module 中的 forward 功能,定义神经网络前向传递的过程,把__init__中的层信息一个一个的组合起来
        # 正向传播输入值, 神经网络分析出输出值
        x = F.relu(self.hidden(x))  # 定义激励函数(隐藏层的线性值)
        x = self.predict(x)  # 输出层,输出值
        return x 
 
net = Net(n_feature=1, n_hidden=10, n_output=1) 
print(net)  # net 的结构
"""
Net (
  (hidden): Linear (1 -> 10)
  (predict): Linear (10 -> 1)
)
"""
# optimizer 是训练的工具
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  # 传入 net 的所有参数, 学习率
loss_func = torch.nn.MSELoss()  # 预测值和真实值的误差计算公式 (均方差)
 
for t in range(100):  # 训练的步数100步
    prediction = net(x)  # 喂给 net 训练数据 x, 每迭代一步,输出预测值
 
    loss = loss_func(prediction, y)  # 计算两者的误差
 
    # 优化步骤:
    optimizer.zero_grad()  # 清空上一步的残余更新参数值
    loss.backward()  # 误差反向传播, 计算参数更新值
    optimizer.step()  # 将参数更新值施加到 net 的 parameters 上
 
import matplotlib.pyplot as plt 
plt.ion()  # 实时画图something about plotting 
for t in range(200):
    prediction = net(x)  # input x and predict based on x 
    loss = loss_func(prediction, y)  # must be (1. nn output, 2. target) 
    optimizer.zero_grad()  # clear gradients for next train
    loss.backward()  # backpropagation, compute gradients
    optimizer.step()  # apply gradients
 
    if t % 5 == 0:  # 每五步绘一次图
        # plot and show learning process
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
        plt.pause(0.1)
 
plt.ioff()
plt.show()

pytorch训练神经网络爆内存的解决方案

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于Python3 逗号代码 和 字符图网格(详谈)
Jun 22 Python
Python用csv写入文件_消除空余行的方法
Jul 06 Python
NLTK 3.2.4 环境搭建教程
Sep 19 Python
详解Django项目中模板标签及模板的继承与引用(网站中快速布置广告)
Mar 27 Python
python实现动态数组的示例代码
Jul 15 Python
python写程序统计词频的方法
Jul 29 Python
python requests证书问题解决
Sep 05 Python
python tkinter canvas使用实例
Nov 04 Python
Python实现FLV视频拼接功能
Jan 21 Python
django实现将修改好的新模型写入数据库
Mar 31 Python
Anaconda配置pytorch-gpu虚拟环境的图文教程
Apr 16 Python
keras得到每层的系数方式
Jun 15 Python
粗暴解决CUDA out of memory的问题
May 22 #Python
pytorch中的model.eval()和BN层的使用
May 22 #Python
解决Pytorch中关于model.eval的问题
Pytorch 中net.train 和 net.eval的使用说明
May 22 #Python
对PyTorch中inplace字段的全面理解
May 22 #Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 #Python
用python实现监控视频人数统计
You might like
php中自定义函数dump查看数组信息类似var_dump
2014/01/27 PHP
php模仿asp Application对象在线人数统计实现方法
2015/01/04 PHP
PHP实现多图上传和单图上传功能
2018/05/17 PHP
超越Jquery_01_isPlainObject分析与重构
2010/10/20 Javascript
jquery miniui 教程 表格控件 合并单元格应用
2012/11/25 Javascript
JavaScript设计模式之策略模式实例
2014/10/10 Javascript
JQuery标签页效果的两个实例讲解(4)
2015/09/17 Javascript
JS实现DIV容器赋值的方法
2015/12/14 Javascript
Angular.js实现注册系统的实例详解
2016/12/18 Javascript
基于JavaScript实现验证码功能
2017/04/01 Javascript
改变vue请求过来的数据中的某一项值的方法(详解)
2018/03/08 Javascript
AngularJS与BootStrap模仿百度分页的示例代码
2018/05/23 Javascript
Vue.js 中的 v-show 指令及用法详解
2018/11/19 Javascript
新手快速入门JavaScript装饰者模式与AOP
2019/06/24 Javascript
JS如何实现动态添加的元素绑定事件
2019/11/12 Javascript
vue 使用vant插件做tabs切换和无限加载功能的实现
2020/11/04 Javascript
python实现划词翻译
2020/04/23 Python
Python中的异常处理相关语句基础学习笔记
2016/07/11 Python
python 打印直角三角形,等边三角形,菱形,正方形的代码
2017/11/21 Python
解决Python安装后pip不能用的问题
2018/06/12 Python
pyQt5实时刷新界面的示例
2019/06/25 Python
在Django model中设置多个字段联合唯一约束的实例
2019/07/17 Python
python中的django是做什么的
2020/07/31 Python
提高python代码运行效率的一些建议
2020/09/29 Python
详解移动端HTML5页面端去掉input输入框的白色背景和边框(兼容Android和ios)
2016/12/15 HTML / CSS
初三学生个人自我评定
2014/04/06 职场文书
上课不认真检讨书
2014/09/17 职场文书
教师自我剖析材料(四风问题)
2014/09/30 职场文书
意向协议书
2015/01/27 职场文书
检讨书格式范文
2015/05/07 职场文书
医院病假条范文
2015/08/17 职场文书
2015年乡镇食品安全工作总结
2015/10/22 职场文书
2016年主题党日活动总结
2016/04/05 职场文书
2016年企业安全生产月活动总结
2016/04/06 职场文书
golang中的空接口使用详解
2021/03/30 Python
学会Python数据可视化必须尝试这7个库
2021/06/16 Python