pytorch 6 batch_train 批训练操作


Posted in Python onMay 28, 2021

看代码吧~

import torch
import torch.utils.data as Data
torch.manual_seed(1)    # reproducible
# BATCH_SIZE = 5  
BATCH_SIZE = 8      # 每次使用8个数据同时传入网路
x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=False,              # 设置不随机打乱数据 random shuffle for training
    num_workers=2,              # 使用两个进程提取数据,subprocesses for loading data
)
def show_batch():
    for epoch in range(3):   # 全部的数据使用3遍,train entire dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
    show_batch()

BATCH_SIZE = 8 , 所有数据利用三次

Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]

补充:pytorch批训练bug

问题描述:

在进行pytorch神经网络批训练的时候,有时会出现报错 

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>

解决办法:

第一步:

检查(重点!!!!!):

train_dataset = Data.TensorDataset(train_x, train_y)

train_x,和train_y格式,要求是tensor类,我第一次出错就是因为传入的是variable

可以这样将数据变为tensor类:

train_x = torch.FloatTensor(train_x)

第二步:

train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

实例化一个DataLoader对象

第三步:

for epoch in range(epochs):
        for step, (batch_x, batch_y) in enumerate(train_loader):
            batch_x, batch_y = Variable(batch_x), Variable(batch_y)

这样就可以批训练了

需要注意的是:train_loader输出的是tensor,在训练网络时,需要变成Variable

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python处理json字符串转化为字典的简单实现
Jul 07 Python
python生成1行四列全2矩阵的方法
Aug 04 Python
在Qt5和PyQt5中设置支持高分辨率屏幕自适应的方法
Jun 18 Python
python3用PIL把图片转换为RGB图片的实例
Jul 04 Python
Python中list循环遍历删除数据的正确方法
Sep 02 Python
Python 使用元类type创建类对象常见应用详解
Oct 17 Python
利用Tensorflow构建和训练自己的CNN来做简单的验证码识别方式
Jan 20 Python
配置python的编程环境之Anaconda + VSCode的教程
Mar 29 Python
Python进行特征提取的示例代码
Oct 15 Python
Python中Permission denied的解决方案
Apr 02 Python
Python打包为exe详细教程
May 18 Python
python用海龟绘图写贪吃蛇游戏
Jun 18 Python
pytorch 如何使用batch训练lstm网络
May 28 #Python
使用Pytorch训练two-head网络的操作
May 28 #Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
You might like
zen_cart实现支付前生成订单的方法
2016/05/06 PHP
jquery 简短几句代码实现给元素动态添加及获取提示信息
2011/09/01 Javascript
javascript动态加载实现方法一
2012/08/22 Javascript
使用jQuery实现的网页版的个人简历(可换肤)
2013/04/19 Javascript
javascript中如何处理引号编码&amp;#034;
2013/08/15 Javascript
jQuery 绑定事件到动态创建的元素上的方法实例
2013/08/18 Javascript
JavaScript中一个奇葩的IE浏览器判断方法
2014/04/16 Javascript
js控制网页背景音乐播放与停止的方法
2015/02/06 Javascript
JavaScript获取当前网页标题(title)的方法
2015/04/03 Javascript
Bootstrap编写一个兼容主流浏览器的受众门户式风格页面
2016/07/01 Javascript
AngularJS中关于ng-class指令的几种实现方式详解
2016/09/17 Javascript
JS克隆,属性,数组,对象,函数实例分析
2016/11/26 Javascript
js 递归和定时器的实例解析
2017/02/03 Javascript
JavaScript中重名的函数与对象示例详析
2017/09/28 Javascript
Vue2.0实现调用摄像头进行拍照功能 exif.js实现图片上传功能
2018/04/28 Javascript
vue 监听屏幕高度的实例
2018/09/05 Javascript
详解Vue.js在页面加载时执行某个方法
2018/11/20 Javascript
微信小程序实现基于三元运算验证手机号/姓名功能示例
2019/01/19 Javascript
vue项目打包上传github并制作预览链接(pages)
2019/04/19 Javascript
详解微信小程序胶囊按钮返回|首页自定义导航栏功能
2019/06/14 Javascript
window下eclipse安装python插件教程
2017/04/24 Python
python实现图片二值化及灰度处理方式
2019/12/07 Python
使用numpngw和matplotlib生成png动画的示例代码
2021/01/24 Python
床上用品全球在线购物:BeddingInn
2016/12/18 全球购物
国际领先的在线时尚服装和配饰店:DressLily
2019/03/03 全球购物
字符串str除首尾字符外的其他字符按升序排列
2013/03/08 面试题
毕业生动漫设计求职信
2013/10/11 职场文书
料理师求职信
2014/01/30 职场文书
理工类毕业自我鉴定
2014/02/20 职场文书
庆祝教师节标语
2014/10/09 职场文书
幼儿园小班见习报告
2014/10/31 职场文书
2014年银行个人工作总结
2014/12/05 职场文书
销售内勤岗位职责
2015/02/10 职场文书
沂蒙六姐妹观后感
2015/06/08 职场文书
叶问观后感
2015/06/15 职场文书
学校趣味运动会开幕词
2016/03/04 职场文书