pytorch训练神经网络爆内存的解决方案


Posted in Python onMay 22, 2021

训练的时候内存一直在增加,最后内存爆满,被迫中断。

pytorch训练神经网络爆内存的解决方案

后来换了一个电脑发现还是这样,考虑是代码的问题。

检查才发现我的代码两次存了loss,只有一个地方写的是loss.item()。问题就在loss,因为loss是variable类型。

要写成loss_train = loss_train + loss.item(),不能直接写loss_train = loss_train + loss。否则就会发现随着epoch的增加,占的内存也在一点一点增加。

算是一个小坑吧,希望大家还是要仔细。

补充:pytorch神经网络解决回归问题(非常易懂)

对于pytorch的深度学习框架

在建立人工神经网络时整体的步骤主要有以下四步:

1、载入原始数据

2、构建具体神经网络

3、进行数据的训练

4、数据测试和验证

pytorch神经网络的数据载入,以MINIST书写字体的原始数据为例:

import torch
import matplotlib.pyplot as  plt
def plot_curve(data):
    fig=plt.figure()
    plt.plot(range(len(data)),data,color="blue")
    plt.legend(["value"],loc="upper right")
    plt.xlabel("step")
    plt.ylabel("value")
    plt.show()
 
def plot_image(img,label,name):
    fig=plt.figure()
    for i in range(6):
        plt.subplot(2,3,i+1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307,cmap="gray",interpolation="none")
        plt.title("{}:{}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()
def one_hot(label,depth=10):
    out=torch.zeros(label.size(0),depth)
    idx=torch.LongTensor(label).view(-1,1)
    out.scatter_(dim=1,index=idx,value=1)
    return out
 
batch_size=512
import torch
from torch import nn                         #完成神经网络的构建包
from torch.nn import functional as F         #包含常用的函数包
from torch import optim                      #优化工具包
import torchvision                           #视觉工具包
import  matplotlib.pyplot as plt
from utils import plot_curve,plot_image,one_hot
#step1 load dataset   加载数据包
train_loader=torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("minist_data",train=True,download=True,transform=torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
         ])),
    batch_size=batch_size,shuffle=True)
test_loader=torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("minist_data",train=True,download=False,transform=torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
         ])),
    batch_size=batch_size,shuffle=False)
x,y=next(iter(train_loader))
print(x.shape,y.shape)
plot_image(x,y,"image")
print(x)
print(y)

以构建一个简单的回归问题的神经网络为例,

其具体的实现代码如下所示:

import torch
import torch.nn.functional as F  # 激励函数都在这
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)
 
class Net(torch.nn.Module):  # 继承 torch 的 Module(固定)
    def __init__(self, n_feature, n_hidden, n_output):  # 定义层的信息,n_feature多少个输入, n_hidden每层神经元, n_output多少个输出
        super(Net, self).__init__()  # 继承 __init__ 功能(固定)
        # 定义每层用什么样的形式
        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # 定义隐藏层,线性输出
        self.predict = torch.nn.Linear(n_hidden, n_output)  # 定义输出层线性输出
 
    def forward(self, x):  # x是输入信息就是data,同时也是 Module 中的 forward 功能,定义神经网络前向传递的过程,把__init__中的层信息一个一个的组合起来
        # 正向传播输入值, 神经网络分析出输出值
        x = F.relu(self.hidden(x))  # 定义激励函数(隐藏层的线性值)
        x = self.predict(x)  # 输出层,输出值
        return x 
 
net = Net(n_feature=1, n_hidden=10, n_output=1) 
print(net)  # net 的结构
"""
Net (
  (hidden): Linear (1 -> 10)
  (predict): Linear (10 -> 1)
)
"""
# optimizer 是训练的工具
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  # 传入 net 的所有参数, 学习率
loss_func = torch.nn.MSELoss()  # 预测值和真实值的误差计算公式 (均方差)
 
for t in range(100):  # 训练的步数100步
    prediction = net(x)  # 喂给 net 训练数据 x, 每迭代一步,输出预测值
 
    loss = loss_func(prediction, y)  # 计算两者的误差
 
    # 优化步骤:
    optimizer.zero_grad()  # 清空上一步的残余更新参数值
    loss.backward()  # 误差反向传播, 计算参数更新值
    optimizer.step()  # 将参数更新值施加到 net 的 parameters 上
 
import matplotlib.pyplot as plt 
plt.ion()  # 实时画图something about plotting 
for t in range(200):
    prediction = net(x)  # input x and predict based on x 
    loss = loss_func(prediction, y)  # must be (1. nn output, 2. target) 
    optimizer.zero_grad()  # clear gradients for next train
    loss.backward()  # backpropagation, compute gradients
    optimizer.step()  # apply gradients
 
    if t % 5 == 0:  # 每五步绘一次图
        # plot and show learning process
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
        plt.pause(0.1)
 
plt.ioff()
plt.show()

pytorch训练神经网络爆内存的解决方案

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取一组汉字拼音首字母的方法
Jul 01 Python
python通过cookie模拟已登录状态的初步研究
Nov 09 Python
Python实现遍历目录的方法【测试可用】
Mar 22 Python
关于Python正则表达式 findall函数问题详解
Mar 22 Python
virtualenv 指定 python 解释器的版本方法
Oct 25 Python
Python数据可视化:箱线图多种库画法
Nov 06 Python
python中的数组赋值与拷贝的区别详解
Nov 26 Python
python函数不定长参数使用方法解析
Dec 14 Python
基于python读取.mat文件并取出信息
Dec 16 Python
pytorch模型预测结果与ndarray互转方式
Jan 15 Python
执行Python程序时模块报错问题
Mar 26 Python
浅谈matplotlib默认字体设置探索
Feb 03 Python
粗暴解决CUDA out of memory的问题
May 22 #Python
pytorch中的model.eval()和BN层的使用
May 22 #Python
解决Pytorch中关于model.eval的问题
Pytorch 中net.train 和 net.eval的使用说明
May 22 #Python
对PyTorch中inplace字段的全面理解
May 22 #Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 #Python
用python实现监控视频人数统计
You might like
GD输出汉字的函数的分析
2006/10/09 PHP
php的zip解压缩类pclzip使用示例
2014/03/14 PHP
PHP5中实现多态的两种方法实例分享
2014/04/21 PHP
PHP正则验证Email的方法
2015/06/15 PHP
PHP二进制与字符串之间的相互转换教程
2016/10/14 PHP
Yii2配置Nginx伪静态的方法
2017/05/05 PHP
Javascript调用XML制作连动下拉列表框
2006/06/25 Javascript
用js判断浏览器是否是IE的比较好的办法
2007/05/08 Javascript
jQuery学习4 浏览器的事件模型
2010/02/07 Javascript
jquery Moblie入门—hello world的示例代码学习
2013/01/08 Javascript
js原生appendChild的bug解决心得分享
2013/07/01 Javascript
js页面跳转的常用方法整理
2013/10/18 Javascript
利用JS解决ie6不支持max-width,max-height问题的方法
2014/01/02 Javascript
JavaScript数值转换的三种方式总结
2014/07/31 Javascript
使用jquery清空、复位整个输入域
2015/04/02 Javascript
jQuery实现向下滑出的平滑下拉菜单效果
2015/08/21 Javascript
Bootstrap每天必学之日期控制
2016/03/07 Javascript
浅谈jQuery中Ajax事件beforesend及各参数含义
2016/12/03 Javascript
js判断一个字符串是以某个字符串开头的简单实例
2016/12/27 Javascript
vue首次赋值不触发watch的解决方法
2018/09/11 Javascript
微信小程序页面间传值与页面取值操作实例分析
2019/04/30 Javascript
详解json串反转义(消除反斜杠)
2019/08/12 Javascript
js实现星星打分效果
2020/07/05 Javascript
python创建只读属性对象的方法(ReadOnlyObject)
2013/02/10 Python
Python序列之list和tuple常用方法以及注意事项
2015/01/09 Python
Python基本语法经典教程
2016/03/11 Python
python 自动化将markdown文件转成html文件的方法
2016/09/23 Python
Python方法的延迟加载的示例代码
2017/12/18 Python
利用Python实现原创工具的Logo与Help
2018/12/03 Python
pyqt5实现按钮添加背景图片以及背景图片的切换方法
2019/06/13 Python
Python HTMLTestRunner测试报告view按钮失效解决方案
2020/05/25 Python
scrapy结合selenium解析动态页面的实现
2020/09/28 Python
Steiff台湾官网:德国金耳釦泰迪熊
2019/12/26 全球购物
财务工作个人总结
2015/02/27 职场文书
圆明园纪录片观后感
2015/06/03 职场文书
导游词之泰山玉皇顶
2019/12/23 职场文书