pytorch 6 batch_train 批训练操作


Posted in Python onMay 28, 2021

看代码吧~

import torch
import torch.utils.data as Data
torch.manual_seed(1)    # reproducible
# BATCH_SIZE = 5  
BATCH_SIZE = 8      # 每次使用8个数据同时传入网路
x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=False,              # 设置不随机打乱数据 random shuffle for training
    num_workers=2,              # 使用两个进程提取数据,subprocesses for loading data
)
def show_batch():
    for epoch in range(3):   # 全部的数据使用3遍,train entire dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
    show_batch()

BATCH_SIZE = 8 , 所有数据利用三次

Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]

补充:pytorch批训练bug

问题描述:

在进行pytorch神经网络批训练的时候,有时会出现报错 

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>

解决办法:

第一步:

检查(重点!!!!!):

train_dataset = Data.TensorDataset(train_x, train_y)

train_x,和train_y格式,要求是tensor类,我第一次出错就是因为传入的是variable

可以这样将数据变为tensor类:

train_x = torch.FloatTensor(train_x)

第二步:

train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

实例化一个DataLoader对象

第三步:

for epoch in range(epochs):
        for step, (batch_x, batch_y) in enumerate(train_loader):
            batch_x, batch_y = Variable(batch_x), Variable(batch_y)

这样就可以批训练了

需要注意的是:train_loader输出的是tensor,在训练网络时,需要变成Variable

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用稀疏矩阵节省内存实例
Jun 27 Python
python实现批量图片格式转换
Jun 16 Python
使用Numpy读取CSV文件,并进行行列删除的操作方法
Jul 04 Python
对Python 两大环境管理神器 pyenv 和 virtualenv详解
Dec 31 Python
Python魔法方法功能与用法简介
Apr 04 Python
python ChainMap的使用和说明详解
Jun 11 Python
python如何基于redis实现ip代理池
Jan 17 Python
在django中使用apscheduler 执行计划任务的实现方法
Feb 11 Python
python实现QQ邮箱发送邮件
Mar 06 Python
django ListView的使用 ListView中获取url中的参数值方式
Mar 27 Python
Django如何重置migration的几种情景
Feb 24 Python
Python 处理表格进行成绩排序的操作代码
Jul 26 Python
pytorch 如何使用batch训练lstm网络
May 28 #Python
使用Pytorch训练two-head网络的操作
May 28 #Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
You might like
从一个不错的留言本弄的mysql数据库操作类
2007/09/02 PHP
WindowsXP中快速配置Apache+PHP5+Mysql
2008/06/05 PHP
PHP开发中常见的安全问题详解和解决方法(如Sql注入、CSRF、Xss、CC等)
2014/04/21 PHP
PHP编写简单的App接口
2016/08/28 PHP
php数值转换时间及时间转换数值用法示例
2017/05/18 PHP
PHP实现的mysql操作类【MySQL与MySQLi方式】
2017/10/07 PHP
PHP中的Iterator迭代对象属性详解
2019/04/12 PHP
jquery.validate使用攻略 第五步 正则验证
2010/07/01 Javascript
基于JQuery的6个Tab选项卡插件
2010/09/03 Javascript
Jquery优化效率 提升性能解决方案
2010/09/06 Javascript
JQuery获取当前屏幕的高度宽度的实现代码
2011/07/12 Javascript
基于dom编程中 动态创建与删除元素的使用
2013/04/17 Javascript
浅析JavaScript中的同名标识符优先级
2013/12/06 Javascript
Bootstrap编写一个在当前网页弹出可关闭的对话框 非弹窗
2016/06/30 Javascript
JS去除重复并统计数量的实现方法
2016/12/15 Javascript
Angular.js中window.onload(),$(document).ready()的写法浅析
2017/09/28 Javascript
详解vue+css3做交互特效的方法
2017/11/20 Javascript
vue+webpack 打包文件 404 页面空白的解决方法
2018/02/28 Javascript
Koa项目搭建过程详细记录
2018/04/12 Javascript
JS伪继承prototype实现方法示例
2018/06/20 Javascript
js实现input密码框显示/隐藏功能
2020/09/10 Javascript
vue props 一次传多个值实例
2020/07/22 Javascript
将Django使用的数据库从MySQL迁移到PostgreSQL的教程
2015/04/11 Python
Selenium控制浏览器常见操作示例
2018/08/13 Python
Python项目跨域问题解决方案
2020/06/22 Python
HTML5 Canvas如何实现纹理填充与描边(Fill And Stroke)
2013/07/15 HTML / CSS
Janie and Jack美国官网:GAP旗下的高档童装品牌
2019/09/09 全球购物
美国翻新电子产品商店:The Store
2019/10/08 全球购物
杭州-DOTNET笔试题集
2013/09/25 面试题
医学院四年学习生活的自我评价
2013/11/06 职场文书
给市场的环保建议书
2014/05/14 职场文书
小学学校门卫岗位职责
2014/08/03 职场文书
颂军魂爱军营演讲稿
2014/09/13 职场文书
Ajax 的初步实现(使用vscode+node.js+express框架)
2021/06/18 Javascript
如何解决php-fpm启动不了问题
2021/11/17 PHP
MySQL中的全表扫描和索引树扫描
2022/05/15 MySQL