pytorch 6 batch_train 批训练操作


Posted in Python onMay 28, 2021

看代码吧~

import torch
import torch.utils.data as Data
torch.manual_seed(1)    # reproducible
# BATCH_SIZE = 5  
BATCH_SIZE = 8      # 每次使用8个数据同时传入网路
x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=False,              # 设置不随机打乱数据 random shuffle for training
    num_workers=2,              # 使用两个进程提取数据,subprocesses for loading data
)
def show_batch():
    for epoch in range(3):   # 全部的数据使用3遍,train entire dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
    show_batch()

BATCH_SIZE = 8 , 所有数据利用三次

Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]

补充:pytorch批训练bug

问题描述:

在进行pytorch神经网络批训练的时候,有时会出现报错 

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>

解决办法:

第一步:

检查(重点!!!!!):

train_dataset = Data.TensorDataset(train_x, train_y)

train_x,和train_y格式,要求是tensor类,我第一次出错就是因为传入的是variable

可以这样将数据变为tensor类:

train_x = torch.FloatTensor(train_x)

第二步:

train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

实例化一个DataLoader对象

第三步:

for epoch in range(epochs):
        for step, (batch_x, batch_y) in enumerate(train_loader):
            batch_x, batch_y = Variable(batch_x), Variable(batch_y)

这样就可以批训练了

需要注意的是:train_loader输出的是tensor,在训练网络时,需要变成Variable

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python实现递归版汉诺塔示例(汉诺塔递归算法)
Apr 08 Python
python执行get提交的方法
Apr 29 Python
Java中重定向输出流实现用文件记录程序日志
Jun 12 Python
使用Python进行二进制文件读写的简单方法(推荐)
Sep 12 Python
Python基础练习之用户登录实现代码分享
Nov 08 Python
python+selenium识别验证码并登录的示例代码
Dec 21 Python
python表格存取的方法
Mar 07 Python
Flask框架配置与调试操作示例
Jul 23 Python
python生成九宫格图片
Nov 19 Python
python代码 输入数字使其反向输出的方法
Dec 22 Python
详解python itertools功能
Feb 07 Python
Python实现简繁体转换
Jun 07 Python
pytorch 如何使用batch训练lstm网络
May 28 #Python
使用Pytorch训练two-head网络的操作
May 28 #Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
You might like
php zlib压缩和解压缩swf文件的代码
2008/12/30 PHP
PHP 多维数组排序实现代码
2009/08/05 PHP
php获取bing每日壁纸示例分享
2014/02/25 PHP
排序算法之PHP版快速排序、冒泡排序
2014/04/09 PHP
php文件服务实现虚拟挂载其他目录示例
2014/04/17 PHP
PHP生成和获取XML格式数据的方法
2016/03/04 PHP
js实现漂浮回顶部按钮实例
2015/05/06 Javascript
解决WordPress使用CDN后博文无法评论的错误
2015/12/15 Javascript
JavaScript实战(原生range和自定义特效)简单实例
2016/08/21 Javascript
利用Jquery队列实现根据输入数量显示的动画
2016/09/01 Javascript
利用BootStrap的Carousel.js实现轮播图动画效果
2016/12/21 Javascript
微信小程序中实现一对多发消息详解及实例代码
2017/02/14 Javascript
js实现下拉菜单效果
2017/03/01 Javascript
微信小程序 中wx.chooseAddress(OBJECT)实例详解
2017/03/31 Javascript
JS验证全角与半角及相互转化的介绍
2017/05/18 Javascript
详解jquery选择器的原理
2017/08/01 jQuery
BootStrap入门学习第一篇
2017/08/28 Javascript
Vue.js实现表格渲染的方法
2018/09/07 Javascript
vue 引用自定义ttf、otf、在线字体的方法
2019/05/09 Javascript
详解使用WebPack搭建React开发环境
2019/08/06 Javascript
Vue el-autocomplete远程搜索下拉框并实现自动填充功能(推荐)
2019/10/25 Javascript
JS中FormData类实现文件上传
2020/03/27 Javascript
pytyon 带有重复的全排列
2013/08/13 Python
python中使用mysql数据库详细介绍
2015/03/27 Python
修改 CentOS 6.x 上默认Python的方法
2019/09/06 Python
python编程进阶之异常处理用法实例分析
2020/02/21 Python
Russell Stover巧克力官方网站:美国领先的精美巧克力制造商
2016/11/27 全球购物
Forever 21美国官网:美国标志性快时尚品牌
2017/02/20 全球购物
Timberland俄罗斯官方网上商店:全球领先的户外品牌
2020/03/15 全球购物
介绍一下内联、左联、右联
2013/12/31 面试题
什么时候需要进行强制类型转换
2016/09/03 面试题
食品安全标语
2014/06/07 职场文书
服装仓管员岗位职责
2014/06/17 职场文书
小学见习报告
2014/10/31 职场文书
2014年财务个人工作总结
2014/12/08 职场文书
OpenCV3.3+Python3.6实现图片高斯模糊
2021/05/18 Python