pytorch 6 batch_train 批训练操作


Posted in Python onMay 28, 2021

看代码吧~

import torch
import torch.utils.data as Data
torch.manual_seed(1)    # reproducible
# BATCH_SIZE = 5  
BATCH_SIZE = 8      # 每次使用8个数据同时传入网路
x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=False,              # 设置不随机打乱数据 random shuffle for training
    num_workers=2,              # 使用两个进程提取数据,subprocesses for loading data
)
def show_batch():
    for epoch in range(3):   # 全部的数据使用3遍,train entire dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
    show_batch()

BATCH_SIZE = 8 , 所有数据利用三次

Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]

补充:pytorch批训练bug

问题描述:

在进行pytorch神经网络批训练的时候,有时会出现报错 

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>

解决办法:

第一步:

检查(重点!!!!!):

train_dataset = Data.TensorDataset(train_x, train_y)

train_x,和train_y格式,要求是tensor类,我第一次出错就是因为传入的是variable

可以这样将数据变为tensor类:

train_x = torch.FloatTensor(train_x)

第二步:

train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

实例化一个DataLoader对象

第三步:

for epoch in range(epochs):
        for step, (batch_x, batch_y) in enumerate(train_loader):
            batch_x, batch_y = Variable(batch_x), Variable(batch_y)

这样就可以批训练了

需要注意的是:train_loader输出的是tensor,在训练网络时,需要变成Variable

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python获取电脑硬件信息及状态的实现方法
Aug 29 Python
Python for Informatics 第11章之正则表达式(二)
Apr 21 Python
对Python中list的倒序索引和切片实例讲解
Nov 15 Python
Django 外键的使用方法详解
Jul 19 Python
python 导入数据及作图的实现
Dec 03 Python
pytorch实现mnist分类的示例讲解
Jan 10 Python
python encrypt 实现AES加密的实例详解
Feb 20 Python
Django 返回json数据的实现示例
Mar 05 Python
django中嵌套的try-except实例
May 21 Python
python 实现汉诺塔游戏
Nov 28 Python
Python自动化测试PO模型封装过程详解
Jun 22 Python
Python极值整数的边界探讨分析
Sep 15 Python
pytorch 如何使用batch训练lstm网络
May 28 #Python
使用Pytorch训练two-head网络的操作
May 28 #Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
You might like
php array_merge下进行数组合并的代码
2008/07/22 PHP
PHP与MYSQL中UTF8 中文排序示例代码
2014/10/23 PHP
php实现图片添加描边字和马赛克的方法
2014/12/10 PHP
用js实现键盘方向键翻页功能的代码
2007/06/03 Javascript
zepto.js中tap事件阻止冒泡的实现方法
2015/02/12 Javascript
基于JS实现无缝滚动思路及代码分享
2016/06/07 Javascript
js实现股票实时刷新数据案例
2017/05/14 Javascript
iscroll-probe实现下拉刷新和下拉加载效果
2017/06/28 Javascript
Angular 组件之间的交互的示例代码
2018/03/24 Javascript
详解mpvue scroll-view自动回弹bug解决方案
2018/10/01 Javascript
使用 Vue cli 3.0 构建自定义组件库的方法
2019/04/30 Javascript
javascript简单实现深浅拷贝过程详解
2019/10/08 Javascript
原生js实现碰撞检测
2020/03/12 Javascript
浅谈javascript如何获取文件后缀名
2020/08/07 Javascript
nginx配置域名后的二级目录访问不同项目的配置操作
2020/11/06 Javascript
[13:21]DOTA2国际邀请赛采访专栏:RSnake战队国士无双,Fnatic.Fly
2013/08/06 DOTA
[03:12]完美世界DOTA2联赛PWL DAY6集锦
2020/11/05 DOTA
Python实现将文本生成二维码的方法示例
2017/07/18 Python
Django框架实现逆向解析url的方法
2018/07/04 Python
基于python实现高速视频传输程序
2019/05/05 Python
Python从list类型、range()序列简单认识类(class)【可迭代】
2019/05/31 Python
python调用摄像头拍摄数据集
2019/06/01 Python
python tornado修改log输出方式
2019/11/18 Python
python实现局域网内实时通信代码
2019/12/22 Python
Python中os模块功能与用法详解
2020/02/26 Python
HTML5 和小程序实现拍照图片旋转、压缩和上传功能
2018/10/08 HTML / CSS
美国购买和销售礼品卡平台:Raise
2017/01/13 全球购物
波兰品牌内衣及泳装网上商店:Astratex.pl
2017/02/03 全球购物
哈利波特商店:Harry Potter Shop
2018/11/30 全球购物
厨房工作人员岗位职责
2013/11/15 职场文书
历史专业个人求职信范文
2013/12/07 职场文书
入党积极分子学习两会心得体会范文
2014/03/17 职场文书
党的群众路线教育实践活动对照检查剖析材料
2014/10/09 职场文书
领导班子整改方案和个人整改措施
2014/10/25 职场文书
2016年教师政治思想表现评语
2015/12/02 职场文书
Python实现socket库网络通信套接字
2021/06/04 Python