pytorch训练神经网络爆内存的解决方案


Posted in Python onMay 22, 2021

训练的时候内存一直在增加,最后内存爆满,被迫中断。

pytorch训练神经网络爆内存的解决方案

后来换了一个电脑发现还是这样,考虑是代码的问题。

检查才发现我的代码两次存了loss,只有一个地方写的是loss.item()。问题就在loss,因为loss是variable类型。

要写成loss_train = loss_train + loss.item(),不能直接写loss_train = loss_train + loss。否则就会发现随着epoch的增加,占的内存也在一点一点增加。

算是一个小坑吧,希望大家还是要仔细。

补充:pytorch神经网络解决回归问题(非常易懂)

对于pytorch的深度学习框架

在建立人工神经网络时整体的步骤主要有以下四步:

1、载入原始数据

2、构建具体神经网络

3、进行数据的训练

4、数据测试和验证

pytorch神经网络的数据载入,以MINIST书写字体的原始数据为例:

import torch
import matplotlib.pyplot as  plt
def plot_curve(data):
    fig=plt.figure()
    plt.plot(range(len(data)),data,color="blue")
    plt.legend(["value"],loc="upper right")
    plt.xlabel("step")
    plt.ylabel("value")
    plt.show()
 
def plot_image(img,label,name):
    fig=plt.figure()
    for i in range(6):
        plt.subplot(2,3,i+1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307,cmap="gray",interpolation="none")
        plt.title("{}:{}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()
def one_hot(label,depth=10):
    out=torch.zeros(label.size(0),depth)
    idx=torch.LongTensor(label).view(-1,1)
    out.scatter_(dim=1,index=idx,value=1)
    return out
 
batch_size=512
import torch
from torch import nn                         #完成神经网络的构建包
from torch.nn import functional as F         #包含常用的函数包
from torch import optim                      #优化工具包
import torchvision                           #视觉工具包
import  matplotlib.pyplot as plt
from utils import plot_curve,plot_image,one_hot
#step1 load dataset   加载数据包
train_loader=torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("minist_data",train=True,download=True,transform=torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
         ])),
    batch_size=batch_size,shuffle=True)
test_loader=torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("minist_data",train=True,download=False,transform=torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
         ])),
    batch_size=batch_size,shuffle=False)
x,y=next(iter(train_loader))
print(x.shape,y.shape)
plot_image(x,y,"image")
print(x)
print(y)

以构建一个简单的回归问题的神经网络为例,

其具体的实现代码如下所示:

import torch
import torch.nn.functional as F  # 激励函数都在这
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)
 
class Net(torch.nn.Module):  # 继承 torch 的 Module(固定)
    def __init__(self, n_feature, n_hidden, n_output):  # 定义层的信息,n_feature多少个输入, n_hidden每层神经元, n_output多少个输出
        super(Net, self).__init__()  # 继承 __init__ 功能(固定)
        # 定义每层用什么样的形式
        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # 定义隐藏层,线性输出
        self.predict = torch.nn.Linear(n_hidden, n_output)  # 定义输出层线性输出
 
    def forward(self, x):  # x是输入信息就是data,同时也是 Module 中的 forward 功能,定义神经网络前向传递的过程,把__init__中的层信息一个一个的组合起来
        # 正向传播输入值, 神经网络分析出输出值
        x = F.relu(self.hidden(x))  # 定义激励函数(隐藏层的线性值)
        x = self.predict(x)  # 输出层,输出值
        return x 
 
net = Net(n_feature=1, n_hidden=10, n_output=1) 
print(net)  # net 的结构
"""
Net (
  (hidden): Linear (1 -> 10)
  (predict): Linear (10 -> 1)
)
"""
# optimizer 是训练的工具
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  # 传入 net 的所有参数, 学习率
loss_func = torch.nn.MSELoss()  # 预测值和真实值的误差计算公式 (均方差)
 
for t in range(100):  # 训练的步数100步
    prediction = net(x)  # 喂给 net 训练数据 x, 每迭代一步,输出预测值
 
    loss = loss_func(prediction, y)  # 计算两者的误差
 
    # 优化步骤:
    optimizer.zero_grad()  # 清空上一步的残余更新参数值
    loss.backward()  # 误差反向传播, 计算参数更新值
    optimizer.step()  # 将参数更新值施加到 net 的 parameters 上
 
import matplotlib.pyplot as plt 
plt.ion()  # 实时画图something about plotting 
for t in range(200):
    prediction = net(x)  # input x and predict based on x 
    loss = loss_func(prediction, y)  # must be (1. nn output, 2. target) 
    optimizer.zero_grad()  # clear gradients for next train
    loss.backward()  # backpropagation, compute gradients
    optimizer.step()  # apply gradients
 
    if t % 5 == 0:  # 每五步绘一次图
        # plot and show learning process
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
        plt.pause(0.1)
 
plt.ioff()
plt.show()

pytorch训练神经网络爆内存的解决方案

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3 正在毁灭 Python的原因分析
Nov 28 Python
Python下的subprocess模块的入门指引
Apr 16 Python
使用python绘制常用的图表
Aug 27 Python
python数据类型_元组、字典常用操作方法(介绍)
May 30 Python
Python实现网站注册验证码生成类
Jun 08 Python
对python判断是否回文数的实例详解
Feb 08 Python
Python匿名函数及应用示例
Apr 09 Python
python中的TCP(传输控制协议)用法实例分析
Nov 15 Python
Python实现将元组中的元素作为参数传入函数的操作
Jun 05 Python
python实现学生信息管理系统(精简版)
Nov 27 Python
详解如何使用Pytest进行自动化测试
Jan 14 Python
Python中Matplotlib的点、线形状、颜色以及绘制散点图
Apr 07 Python
粗暴解决CUDA out of memory的问题
May 22 #Python
pytorch中的model.eval()和BN层的使用
May 22 #Python
解决Pytorch中关于model.eval的问题
Pytorch 中net.train 和 net.eval的使用说明
May 22 #Python
对PyTorch中inplace字段的全面理解
May 22 #Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 #Python
用python实现监控视频人数统计
You might like
php正则
2006/07/07 PHP
安装APACHE
2007/01/15 PHP
CentOS 6.2使用yum安装LAMP以及phpMyadmin详解
2013/06/17 PHP
memcache命令启动参数中文解释
2014/01/13 PHP
THINKPHP2.0到3.0有哪些改进之处
2015/01/04 PHP
php打包压缩文件之ZipArchive方法用法分析
2016/04/30 PHP
PHP安全之register_globals的on和off的区别
2020/07/23 PHP
jQuery筛选器children()案例详解(图文)
2013/02/17 Javascript
轻松创建nodejs服务器(3):代码模块化
2014/12/18 NodeJs
jquery 中ajax执行的优先级
2015/06/22 Javascript
Javascript生成全局唯一标识符(GUID,UUID)的方法
2016/02/27 Javascript
javascript如何实现360度全景照片问题汇总
2016/04/04 Javascript
使用jQuery实现Web页面换肤功能的要点解析
2016/05/12 Javascript
JS判断Android、iOS或浏览器的多种方法(四种方法)
2017/06/29 Javascript
ReactNative实现Toast的示例
2017/12/31 Javascript
详解基于vue-cli3.0如何构建功能完善的前端架子
2018/10/09 Javascript
vue 2.5.1 源码学习 之Vue.extend 和 data的合并策略
2019/06/04 Javascript
DatePickerDialog 自定义样式及使用全解
2019/07/09 Javascript
Python入门篇之编程习惯与特点
2014/10/17 Python
Python selenium 三种等待方式解读
2016/09/15 Python
Python实现通过文件路径获取文件hash值的方法
2017/04/29 Python
使用PyQtGraph绘制精美的股票行情K线图的示例代码
2019/03/14 Python
python实现在函数中修改变量值的方法
2019/07/16 Python
将python安装信息加入注册表的示例
2019/11/20 Python
关于Python Tkinter Button控件command传参问题的解决方式
2020/03/04 Python
Python decimal模块使用方法详解
2020/06/08 Python
浅谈keras 的抽象后端(from keras import backend as K)
2020/06/16 Python
零基础小白多久能学会python
2020/06/22 Python
Python下划线5种含义代码实例解析
2020/07/10 Python
Python+kivy BoxLayout布局示例代码详解
2020/12/28 Python
python 制作网站筛选工具(附源码)
2021/01/21 Python
音乐学院硕士生的自我评价分享
2013/11/01 职场文书
给校长的一封建议书
2014/03/12 职场文书
运动会加油口号
2014/06/07 职场文书
python中sys模块的介绍与实例
2021/04/17 Python
win10识别不了U盘怎么办 win10系统读取U盘失败的解决办法
2022/08/05 数码科技