pytorch训练神经网络爆内存的解决方案


Posted in Python onMay 22, 2021

训练的时候内存一直在增加,最后内存爆满,被迫中断。

pytorch训练神经网络爆内存的解决方案

后来换了一个电脑发现还是这样,考虑是代码的问题。

检查才发现我的代码两次存了loss,只有一个地方写的是loss.item()。问题就在loss,因为loss是variable类型。

要写成loss_train = loss_train + loss.item(),不能直接写loss_train = loss_train + loss。否则就会发现随着epoch的增加,占的内存也在一点一点增加。

算是一个小坑吧,希望大家还是要仔细。

补充:pytorch神经网络解决回归问题(非常易懂)

对于pytorch的深度学习框架

在建立人工神经网络时整体的步骤主要有以下四步:

1、载入原始数据

2、构建具体神经网络

3、进行数据的训练

4、数据测试和验证

pytorch神经网络的数据载入,以MINIST书写字体的原始数据为例:

import torch
import matplotlib.pyplot as  plt
def plot_curve(data):
    fig=plt.figure()
    plt.plot(range(len(data)),data,color="blue")
    plt.legend(["value"],loc="upper right")
    plt.xlabel("step")
    plt.ylabel("value")
    plt.show()
 
def plot_image(img,label,name):
    fig=plt.figure()
    for i in range(6):
        plt.subplot(2,3,i+1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307,cmap="gray",interpolation="none")
        plt.title("{}:{}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()
def one_hot(label,depth=10):
    out=torch.zeros(label.size(0),depth)
    idx=torch.LongTensor(label).view(-1,1)
    out.scatter_(dim=1,index=idx,value=1)
    return out
 
batch_size=512
import torch
from torch import nn                         #完成神经网络的构建包
from torch.nn import functional as F         #包含常用的函数包
from torch import optim                      #优化工具包
import torchvision                           #视觉工具包
import  matplotlib.pyplot as plt
from utils import plot_curve,plot_image,one_hot
#step1 load dataset   加载数据包
train_loader=torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("minist_data",train=True,download=True,transform=torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
         ])),
    batch_size=batch_size,shuffle=True)
test_loader=torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("minist_data",train=True,download=False,transform=torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
         ])),
    batch_size=batch_size,shuffle=False)
x,y=next(iter(train_loader))
print(x.shape,y.shape)
plot_image(x,y,"image")
print(x)
print(y)

以构建一个简单的回归问题的神经网络为例,

其具体的实现代码如下所示:

import torch
import torch.nn.functional as F  # 激励函数都在这
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)
 
class Net(torch.nn.Module):  # 继承 torch 的 Module(固定)
    def __init__(self, n_feature, n_hidden, n_output):  # 定义层的信息,n_feature多少个输入, n_hidden每层神经元, n_output多少个输出
        super(Net, self).__init__()  # 继承 __init__ 功能(固定)
        # 定义每层用什么样的形式
        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # 定义隐藏层,线性输出
        self.predict = torch.nn.Linear(n_hidden, n_output)  # 定义输出层线性输出
 
    def forward(self, x):  # x是输入信息就是data,同时也是 Module 中的 forward 功能,定义神经网络前向传递的过程,把__init__中的层信息一个一个的组合起来
        # 正向传播输入值, 神经网络分析出输出值
        x = F.relu(self.hidden(x))  # 定义激励函数(隐藏层的线性值)
        x = self.predict(x)  # 输出层,输出值
        return x 
 
net = Net(n_feature=1, n_hidden=10, n_output=1) 
print(net)  # net 的结构
"""
Net (
  (hidden): Linear (1 -> 10)
  (predict): Linear (10 -> 1)
)
"""
# optimizer 是训练的工具
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  # 传入 net 的所有参数, 学习率
loss_func = torch.nn.MSELoss()  # 预测值和真实值的误差计算公式 (均方差)
 
for t in range(100):  # 训练的步数100步
    prediction = net(x)  # 喂给 net 训练数据 x, 每迭代一步,输出预测值
 
    loss = loss_func(prediction, y)  # 计算两者的误差
 
    # 优化步骤:
    optimizer.zero_grad()  # 清空上一步的残余更新参数值
    loss.backward()  # 误差反向传播, 计算参数更新值
    optimizer.step()  # 将参数更新值施加到 net 的 parameters 上
 
import matplotlib.pyplot as plt 
plt.ion()  # 实时画图something about plotting 
for t in range(200):
    prediction = net(x)  # input x and predict based on x 
    loss = loss_func(prediction, y)  # must be (1. nn output, 2. target) 
    optimizer.zero_grad()  # clear gradients for next train
    loss.backward()  # backpropagation, compute gradients
    optimizer.step()  # apply gradients
 
    if t % 5 == 0:  # 每五步绘一次图
        # plot and show learning process
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
        plt.pause(0.1)
 
plt.ioff()
plt.show()

pytorch训练神经网络爆内存的解决方案

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现通过代理服务器访问远程url的方法
Apr 29 Python
简单掌握Python中glob模块查找文件路径的用法
Jul 05 Python
python opencv之SURF算法示例
Feb 24 Python
python中利用h5py模块读取h5文件中的主键方法
Jun 05 Python
django Serializer序列化使用方法详解
Oct 16 Python
Python Matplotlib库安装与基本作图示例
Jan 09 Python
python自动发送测试报告邮件功能的实现
Jan 22 Python
Python语言进阶知识点总结
May 28 Python
Tensorflow 多线程设置方式
Feb 06 Python
Python3搭建http服务器的实现代码
Feb 11 Python
python递归调用中的坑:打印有值, 返回却None
Mar 16 Python
Pytorch数据拼接与拆分操作实现图解
Apr 30 Python
粗暴解决CUDA out of memory的问题
May 22 #Python
pytorch中的model.eval()和BN层的使用
May 22 #Python
解决Pytorch中关于model.eval的问题
Pytorch 中net.train 和 net.eval的使用说明
May 22 #Python
对PyTorch中inplace字段的全面理解
May 22 #Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 #Python
用python实现监控视频人数统计
You might like
php面向对象全攻略 (十一)__toString()用法 克隆对象 __call处理调用错误
2009/09/30 PHP
PHP实现的mysql主从数据库状态检测功能示例
2017/07/20 PHP
javascript得到XML某节点的子节点个数的脚本
2008/10/11 Javascript
获得Javascript对象属性个数的示例代码
2013/11/21 Javascript
基于jQuery的图片不完全按比例自动缩小
2014/07/11 Javascript
JavaScript实现的GBK、UTF8字符串实际长度计算函数
2014/08/27 Javascript
JS+CSS实现的经典圆角下拉菜单效果代码
2015/10/21 Javascript
php利用curl获取远程图片实现方法
2015/10/26 Javascript
jquery中cookie用法实例详解(获取,存储,删除等)
2016/01/04 Javascript
详解angularJs指令的3种绑定策略
2017/04/13 Javascript
详谈jQuery中使用attr(), prop(), val()获取value的异同
2017/04/25 jQuery
详解Vue中使用v-for语句抛出错误的解决方案
2017/05/04 Javascript
nodejs mysql 实现分页的方法
2017/06/06 NodeJs
jQuery制作input提示内容(兼容IE8以上)
2017/07/05 jQuery
vue axios用法教程详解
2017/07/23 Javascript
微信小程序页面间传值与页面取值操作实例分析
2019/04/30 Javascript
JS使用new操作符创建对象的方法分析
2019/05/30 Javascript
详解vue-cli项目开发/生产环境代理实现跨域请求
2019/07/23 Javascript
vue router 跳转时打开新页面的示例方法
2019/07/28 Javascript
解决layui数据表格排序图标被超出的表头挤出去的问题
2019/09/19 Javascript
vue 解决computed修改data数据的问题
2019/11/06 Javascript
JS实现省市县三级下拉联动
2020/04/10 Javascript
基于vue实现微博三方登录流程解析
2020/11/04 Javascript
Windows下搭建python开发环境详细步骤
2020/07/20 Python
Python的collections模块中namedtuple结构使用示例
2016/07/07 Python
Python拼接字符串的7种方法总结
2018/11/01 Python
简单了解django文件下载方式
2020/02/10 Python
Python 解析pymysql模块操作数据库的方法
2020/02/18 Python
保时捷设计:Porsche Design
2019/03/30 全球购物
英国门销售网站:Green Tree Doors
2020/01/07 全球购物
运动会800米加油稿
2014/02/22 职场文书
学生检讨书范文
2014/10/30 职场文书
介绍信样本
2015/01/31 职场文书
毕业晚宴祝酒词
2015/08/11 职场文书
一文搞懂如何实现Go 超时控制
2021/03/30 Python
有趣的二维码:使用MyQR和qrcode来制作二维码
2021/05/10 Python