pytorch 6 batch_train 批训练操作


Posted in Python onMay 28, 2021

看代码吧~

import torch
import torch.utils.data as Data
torch.manual_seed(1)    # reproducible
# BATCH_SIZE = 5  
BATCH_SIZE = 8      # 每次使用8个数据同时传入网路
x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=False,              # 设置不随机打乱数据 random shuffle for training
    num_workers=2,              # 使用两个进程提取数据,subprocesses for loading data
)
def show_batch():
    for epoch in range(3):   # 全部的数据使用3遍,train entire dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
    show_batch()

BATCH_SIZE = 8 , 所有数据利用三次

Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]

补充:pytorch批训练bug

问题描述:

在进行pytorch神经网络批训练的时候,有时会出现报错 

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>

解决办法:

第一步:

检查(重点!!!!!):

train_dataset = Data.TensorDataset(train_x, train_y)

train_x,和train_y格式,要求是tensor类,我第一次出错就是因为传入的是variable

可以这样将数据变为tensor类:

train_x = torch.FloatTensor(train_x)

第二步:

train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

实例化一个DataLoader对象

第三步:

for epoch in range(epochs):
        for step, (batch_x, batch_y) in enumerate(train_loader):
            batch_x, batch_y = Variable(batch_x), Variable(batch_y)

这样就可以批训练了

需要注意的是:train_loader输出的是tensor,在训练网络时,需要变成Variable

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的lambda匿名函数的简单介绍
Apr 25 Python
Python代理抓取并验证使用多线程实现
May 03 Python
Python tkinter模块中类继承的三种方式分析
Aug 08 Python
Python cookbook(数据结构与算法)让字典保持有序的方法
Feb 18 Python
python中partial()基础用法说明
Dec 30 Python
对python判断ip是否可达的实例详解
Jan 31 Python
对Python获取屏幕截图的4种方法详解
Aug 27 Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 Python
python多维数组分位数的求取方式
Mar 03 Python
python模块如何查看
Jun 16 Python
python更新数据库中某个字段的数据(方法详解)
Nov 18 Python
Python各协议下socket黏包问题原理
Apr 12 Python
pytorch 如何使用batch训练lstm网络
May 28 #Python
使用Pytorch训练two-head网络的操作
May 28 #Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
You might like
基于Zend的Captcha机制的应用
2013/05/02 PHP
php 中文字符串首字母的获取函数分享
2013/11/04 PHP
PHP中substr函数字符串截取用法分析
2016/01/07 PHP
php实现微信原生支付(扫码支付)功能
2018/05/30 PHP
js自定义事件及事件交互原理概述(一)
2013/02/01 Javascript
深入理解JSON数据源格式
2014/01/10 Javascript
常用的JavaScript WEB操作方法分享
2015/02/28 Javascript
javascript性能优化之事件委托实例详解
2015/12/12 Javascript
JavaScript实现页面无操作倒计时退出
2016/10/22 Javascript
微信小程序 网络API发起请求详解
2016/11/09 Javascript
前端 Vue.js 和 MVVM 详细介绍
2016/12/29 Javascript
AngularJS框架中的双向数据绑定机制详解【减少需要重复的开发代码量】
2017/01/19 Javascript
如何用JS/HTML将时间戳转换为“xx天前”的形式
2017/02/06 Javascript
JS触摸事件、手势事件详解
2017/05/04 Javascript
AngularJS封装$http.post()实例详解
2017/05/06 Javascript
JavaScript循环_动力节点Java学院整理
2017/06/28 Javascript
详解JavaScript中的六种错误类型
2017/09/21 Javascript
Angular6封装http请求的步骤详解
2018/08/13 Javascript
Vue 第三方字体图标引入 Font Awesome的方法
2018/09/28 Javascript
vue+Element实现搜索关键字高亮功能
2019/05/28 Javascript
Angular之jwt令牌身份验证的实现
2020/02/14 Javascript
[01:14]TI珍贵瞬间系列(六):冠军
2020/08/30 DOTA
推荐11个实用Python库
2015/01/23 Python
Python实现按特定格式对文件进行读写的方法示例
2017/11/30 Python
Pytorch中膨胀卷积的用法详解
2020/01/07 Python
基于SpringBoot构造器注入循环依赖及解决方式
2020/04/26 Python
Django 用户登陆访问限制实例 @login_required
2020/05/13 Python
python和go语言的区别是什么
2020/07/20 Python
医学院毕业生自荐信
2013/11/08 职场文书
领导班子党的群众路线教育实践活动对照检查材料
2014/09/25 职场文书
2014年小学生迎国庆65周年演讲稿
2014/09/27 职场文书
2014年帮扶工作总结
2014/11/26 职场文书
少年雷锋观后感
2015/06/10 职场文书
如何利用golang运用mysql数据库
2022/03/13 Golang
TV动画《间谍过家家》公开PV
2022/03/20 日漫
Mysql使用全文索引(FullText index)的实例代码
2022/04/03 MySQL