pytorch 6 batch_train 批训练操作


Posted in Python onMay 28, 2021

看代码吧~

import torch
import torch.utils.data as Data
torch.manual_seed(1)    # reproducible
# BATCH_SIZE = 5  
BATCH_SIZE = 8      # 每次使用8个数据同时传入网路
x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=False,              # 设置不随机打乱数据 random shuffle for training
    num_workers=2,              # 使用两个进程提取数据,subprocesses for loading data
)
def show_batch():
    for epoch in range(3):   # 全部的数据使用3遍,train entire dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
    show_batch()

BATCH_SIZE = 8 , 所有数据利用三次

Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]

补充:pytorch批训练bug

问题描述:

在进行pytorch神经网络批训练的时候,有时会出现报错 

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>

解决办法:

第一步:

检查(重点!!!!!):

train_dataset = Data.TensorDataset(train_x, train_y)

train_x,和train_y格式,要求是tensor类,我第一次出错就是因为传入的是variable

可以这样将数据变为tensor类:

train_x = torch.FloatTensor(train_x)

第二步:

train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

实例化一个DataLoader对象

第三步:

for epoch in range(epochs):
        for step, (batch_x, batch_y) in enumerate(train_loader):
            batch_x, batch_y = Variable(batch_x), Variable(batch_y)

这样就可以批训练了

需要注意的是:train_loader输出的是tensor,在训练网络时,需要变成Variable

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python contextlib模块使用示例
Feb 18 Python
Python实现SVN的目录周期性备份实例
Jul 17 Python
Django验证码的生成与使用示例
May 20 Python
python enumerate函数的使用方法总结
Nov 15 Python
Python实现App自动签到领取积分功能
Sep 29 Python
python matplotlib实现双Y轴的实例
Feb 12 Python
pandas读取CSV文件时查看修改各列的数据类型格式
Jul 07 Python
Python shelve模块实现解析
Aug 28 Python
python 使用opencv 把视频分割成图片示例
Dec 12 Python
python飞机大战pygame游戏背景设计详解
Dec 17 Python
python中pyqtgraph知识点总结
Jan 26 Python
python 使用openpyxl读取excel数据
Feb 18 Python
pytorch 如何使用batch训练lstm网络
May 28 #Python
使用Pytorch训练two-head网络的操作
May 28 #Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
You might like
GD输出汉字的函数的分析
2006/10/09 PHP
Windows下IIS6/Apache2.2.4+MySQL5.2+PHP5.2.1安装配置方法
2007/05/03 PHP
php使用glob函数快速查询指定目录文件的方法
2014/11/15 PHP
常用的php图片处理类(水印、等比缩放、固定高宽)分享
2015/06/19 PHP
PHP命名空间namespace的定义方法详解
2017/03/29 PHP
云网广告中的代码,提示出错,大家找找
2006/11/21 Javascript
使用jQuery fancybox插件打造一个实用的数据传输模态弹出窗体
2013/01/15 Javascript
jquery自定义函数的多种方法
2014/01/09 Javascript
详解JavaScript语法对{}处理的坑爹之处
2014/06/05 Javascript
javascript实现点击后变换按钮显示文字的方法
2015/05/13 Javascript
浅谈JavaScript中的Math.atan()方法的使用
2015/06/14 Javascript
Bootstrap学习笔记之css组件(3)
2016/06/07 Javascript
JQueryEasyUI框架下的combobox的取值和绑定的方法
2017/01/22 Javascript
node.js基于fs模块对系统文件及目录进行读写操作的方法详解
2017/11/10 Javascript
JS实现去除数组中重复json的方法示例
2017/12/21 Javascript
js实现删除json中指定的元素
2020/09/22 Javascript
跟老齐学Python之一个免费的实验室
2014/09/14 Python
Python操作列表之List.insert()方法的使用
2015/05/20 Python
Python中装饰器兼容加括号和不加括号的写法详解
2017/07/05 Python
flask使用session保存登录状态及拦截未登录请求代码
2018/01/19 Python
使用pyecharts生成Echarts网页的实例
2019/08/12 Python
python实现复制大量文件功能
2019/08/31 Python
Pycharm远程连接服务器并实现代码同步上传更新功能
2020/02/25 Python
Python基于httpx模块实现发送请求
2020/07/07 Python
Python 实现将某一列设置为str类型
2020/07/14 Python
详解使用scrapy进行模拟登陆三种方式
2021/02/21 Python
移动端适配 使px自动转换rem
2019/08/26 HTML / CSS
html5小程序飞入购物车(抛物线绘制运动轨迹点)
2020/10/19 HTML / CSS
加拿大购物频道:The Shopping Channel
2016/07/21 全球购物
Notino意大利:购买香水和化妆品
2018/11/14 全球购物
AOP的定义以及作用
2013/09/08 面试题
银行求职信
2014/05/31 职场文书
2014党支部对照检查材料思想汇报
2014/10/05 职场文书
乡镇干部个人整改措施思想汇报
2014/10/10 职场文书
技术转让协议书
2016/03/19 职场文书
Python实战实现爬取天气数据并完成可视化分析详解
2022/06/16 Python