pytorch 6 batch_train 批训练操作


Posted in Python onMay 28, 2021

看代码吧~

import torch
import torch.utils.data as Data
torch.manual_seed(1)    # reproducible
# BATCH_SIZE = 5  
BATCH_SIZE = 8      # 每次使用8个数据同时传入网路
x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=False,              # 设置不随机打乱数据 random shuffle for training
    num_workers=2,              # 使用两个进程提取数据,subprocesses for loading data
)
def show_batch():
    for epoch in range(3):   # 全部的数据使用3遍,train entire dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
    show_batch()

BATCH_SIZE = 8 , 所有数据利用三次

Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]

补充:pytorch批训练bug

问题描述:

在进行pytorch神经网络批训练的时候,有时会出现报错 

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>

解决办法:

第一步:

检查(重点!!!!!):

train_dataset = Data.TensorDataset(train_x, train_y)

train_x,和train_y格式,要求是tensor类,我第一次出错就是因为传入的是variable

可以这样将数据变为tensor类:

train_x = torch.FloatTensor(train_x)

第二步:

train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

实例化一个DataLoader对象

第三步:

for epoch in range(epochs):
        for step, (batch_x, batch_y) in enumerate(train_loader):
            batch_x, batch_y = Variable(batch_x), Variable(batch_y)

这样就可以批训练了

需要注意的是:train_loader输出的是tensor,在训练网络时,需要变成Variable

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3.5仿微软计算器程序
Mar 30 Python
python实现上传下载文件功能
Nov 19 Python
Windows 7下Python Web环境搭建图文教程
Mar 20 Python
基于pandas将类别属性转化为数值属性的方法
Jul 25 Python
12个步骤教你理解Python装饰器
Jul 01 Python
python实现文本进度条 程序进度条 加载进度条 单行刷新功能
Jul 03 Python
导入tensorflow时报错:cannot import name 'abs'的解决
Oct 10 Python
appium+python自动化配置(adk、jdk、node.js)
Nov 17 Python
matplotlib运行时配置(Runtime Configuration,rc)参数rcParams解析
Jan 05 Python
Python爬取科目四考试题库的方法实现
Mar 30 Python
详解Python类和对象内容
Jun 22 Python
Python采集壁纸并实现炫轮播
Apr 30 Python
pytorch 如何使用batch训练lstm网络
May 28 #Python
使用Pytorch训练two-head网络的操作
May 28 #Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
You might like
星际争霸任务指南——虫族
2020/03/04 星际争霸
合作指挥官:孟斯克
2020/03/16 星际争霸
PHP写入WRITE编码为UTF8的文件的实现代码
2008/07/07 PHP
php实现文章评论系统
2019/02/18 PHP
JQuery Tips(2) 关于$()包装集你不知道的
2009/12/14 Javascript
jQuery EasyUI API 中文文档 DateTimeBox日期时间框
2011/10/16 Javascript
js同比例缩放图片的小例子
2013/10/30 Javascript
用JavaScript实现类似于ListBox功能示例代码
2014/03/09 Javascript
jQuery中常用的遍历函数用法实例总结
2015/09/01 Javascript
js中new一个对象的过程
2017/02/20 Javascript
jQuery+ajax实现局部刷新的两种方法
2017/06/08 jQuery
Vue CLI3搭建的项目中路径相关问题的解决
2018/09/17 Javascript
JavaScript使用表单元素验证表单的示例代码
2019/08/20 Javascript
javascript实现动态时钟的启动和停止
2020/07/29 Javascript
layUI的验证码功能及校验实例
2019/10/25 Javascript
实用的 vue tags 创建缓存导航的过程实现
2020/12/03 Vue.js
[04:48]DOTA2上海特锦赛小组赛第三日 TOP10精彩集锦
2016/02/28 DOTA
[01:56]无止竞 再出发——中国军团出征2017年DOTA2国际邀请赛
2017/07/05 DOTA
Python实现字典依据value排序
2016/02/24 Python
PyCharm安装第三方库如Requests的图文教程
2018/05/18 Python
Python实现的读写json文件功能示例
2018/06/05 Python
用python打印1~20的整数实例讲解
2019/07/01 Python
Django crontab定时任务模块操作方法解析
2020/09/10 Python
为什么要使用servlet
2016/01/17 面试题
大一期末自我鉴定
2013/12/13 职场文书
校园活动策划书范文
2014/01/10 职场文书
初一地理教学反思
2014/01/16 职场文书
个人先进材料范文
2014/12/30 职场文书
小班上学期个人总结
2015/02/12 职场文书
2015年大班保育员工作总结
2015/05/18 职场文书
幼儿园六一儿童节主持词
2015/06/30 职场文书
2015年十月一日放假通知
2015/08/18 职场文书
2016大学生就业指导课心得体会
2016/01/15 职场文书
六年级作文之预言作文
2019/10/25 职场文书
MySQL如何使用使用Xtrabackup进行备份和恢复
2021/06/21 MySQL
如何打开Win11系统注册表编辑器?Win11注册表编辑器打开修复方法
2022/04/05 数码科技