Pytorch 神经网络—自定义数据集上实现教程


Posted in Python onJanuary 07, 2020

第一步、导入需要的包

import os
import scipy.io as sio
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable
batchSize = 128 # batchsize的大小
niter = 10   # epoch的最大值

第二步、构建神经网络

Pytorch 神经网络—自定义数据集上实现教程

设神经网络为如上图所示,输入层4个神经元,两层隐含层各4个神经元,输出层一个神经。每一层网络所做的都是线性变换,即y=W×X+b;代码实现如下:

class Neuralnetwork(nn.Module):
  def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
    super(Neuralnetwork, self).__init__()
    self.layer1 = nn.Linear(in_dim, n_hidden_1)
    self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
    self.layer3 = nn.Linear(n_hidden_2, out_dim)
 
  def forward(self, x):
    x = x.view(x.size(0), -1)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    return x
 
model = Neuralnetwork(1*3, 4, 4, 1)
 
print(model) # net architecture
Neuralnetwork(
 (layer1): Linear(in_features=3, out_features=4, bias=True)
 (layer2): Linear(in_features=4, out_features=4, bias=True)
 (layer3): Linear(in_features=4, out_features=1, bias=True)
)

​​ 第三步、读取数据

自定义的数据为demo_SBPFea.mat,是MATLAB保存的数据格式,其存储的内容如下:包括fea(1000*3)和sbp(1000*1)两个数组;fea为特征向量,行为样本数,列为特征宽度;sbp为标签

Pytorch 神经网络—自定义数据集上实现教程

class SBPEstimateDataset(Dataset):
 
  def __init__(self, ext='demo'):
  
    data = sio.loadmat(ext+'_SBPFea.mat')
    self.fea = data['fea']
    self.sbp = data['sbp']
    
  def __len__(self):
    
    return len(self.sbp)
 
  def __getitem__(self, idx):
 
    fea = self.fea[idx]
    sbp = self.sbp[idx]
    """Convert ndarrays to Tensors."""
    return {'fea': torch.from_numpy(fea).float(),
        'sbp': torch.from_numpy(sbp).float()
        }
    
train_dataset = SBPEstimateDataset(ext='demo')
train_loader = DataLoader(train_dataset, batch_size=batchSize, # 分批次训练
             shuffle=True, num_workers=int(8))

整个数据样本为1000,以batchSize = 128划分,分为8份,前7份为104个样本,第8份则为104个样本。在网络训练过程中,是一份数据一份数据进行训练的

第四步、模型训练

# 优化器,Adam 
optimizer = optim.Adam(list(model.parameters()), lr=0.0001, betas=(0.9, 0.999),weight_decay=0.004) 
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997) 
criterion = nn.MSELoss() # loss function 
 
if torch.cuda.is_available(): # 有GPU,则用GPU计算
   model.cuda() 
   criterion.cuda() 
 
for epoch in range(niter): 
   losses = [] 
   ERROR_Train = [] 
   model.train() 
   for i, data in enumerate(train_loader, 0): 
     model.zero_grad()# 首先提取清零 
     real_cpu, label_cpu = data['fea'], data['sbp'] 
 
     if torch.cuda.is_available():# CUDA可用情况下,将Tensor 在GPU上运行 
       real_cpu = real_cpu.cuda() 
       label_cpu = label_cpu.cuda() 
 
 
       input=real_cpu 
       label=label_cpu 
 
       inputv = Variable(input) 
       labelv = Variable(label) 
 
       output = model(inputv) 
       err = criterion(output, labelv) 
       err.backward() 
       optimizer.step() 
 
       losses.append(err.data[0]) 
 
       error = output.data-label+ 1e-12 
       ERROR_Train.extend(error) 
 
   MAE = np.average(np.abs(np.array(ERROR_Train))) 
   ME = np.average(np.array(ERROR_Train)) 
   STD = np.std(np.array(ERROR_Train)) 
 
   print('[%d/%d] Loss: %.4f MAE: %.4f Mean Error: %.4f STD: %.4f' % ( 
   epoch, niter, np.average(losses), MAE, ME, STD))
​​
[0/10] Loss: 18384.6699 MAE: 135.3871 Mean Error: -135.3871 STD: 7.5580
[1/10] Loss: 17063.0215 MAE: 130.4145 Mean Error: -130.4145 STD: 7.8918
[2/10] Loss: 13689.1934 MAE: 116.6625 Mean Error: -116.6625 STD: 9.7946
[3/10] Loss: 8192.9053 MAE: 89.6611 Mean Error: -89.6611 STD: 12.9911
[4/10] Loss: 2979.1340 MAE: 52.5410 Mean Error: -52.5279 STD: 15.0930
[5/10] Loss: 599.7094 MAE: 22.2735 Mean Error: -19.9979 STD: 14.2069
[6/10] Loss: 207.2831 MAE: 11.2394 Mean Error: -4.8821 STD: 13.5528
[7/10] Loss: 189.8173 MAE: 9.8020 Mean Error: -1.2357 STD: 13.7095
[8/10] Loss: 188.3376 MAE: 9.6512 Mean Error: -0.6498 STD: 13.7075
[9/10] Loss: 186.8393 MAE: 9.6946 Mean Error: -1.0850 STD: 13.6332​

以上这篇Pytorch 神经网络—自定义数据集上实现教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python自动安装pip
Apr 24 Python
Python远程桌面协议RDPY安装使用介绍
Apr 15 Python
python中json格式数据输出的简单实现方法
Oct 31 Python
tensorflow实现对图片的读取的示例代码
Feb 12 Python
python机器学习之KNN分类算法
Aug 29 Python
Python Datetime模块和Calendar模块用法实例分析
Apr 15 Python
Python3从零开始搭建一个语音对话机器人的实现
Aug 23 Python
使用OpenCV获取图片连通域数量,并用不同颜色标记函
Jun 04 Python
了解一下python内建模块collections
Sep 07 Python
基于python实现坦克大战游戏
Oct 27 Python
python如何修改文件时间属性
Feb 05 Python
使用Python获取字典键对应值的方法
Apr 26 Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
You might like
PHP中使用gettext来支持多语言的方法
2011/05/02 PHP
Yii中使用PHPExcel导出Excel的方法
2014/12/26 PHP
php实现批量上传数据到数据库(.csv格式)的案例
2017/06/18 PHP
PHP正则匹配操作简单示例【preg_match_all应用】
2017/07/10 PHP
NodeJS框架Express的模板视图机制分析
2011/07/19 NodeJs
关于jquery css的使用介绍
2013/04/18 Javascript
仿当当网淘宝网等主流电子商务网站商品分类导航菜单
2013/09/25 Javascript
JS数组去重与取重的示例代码
2014/01/24 Javascript
详解JS正则replace的使用方法
2016/03/06 Javascript
酷! 不同风格页面布局幻灯片特效js实现
2021/02/19 Javascript
微信小程序 教程之注册页面
2016/10/17 Javascript
详解Vue 非父子组件通信方法(非Vuex)
2017/05/24 Javascript
JS获取子、父、兄节点方法小结
2017/08/14 Javascript
vue鼠标移入添加class样式,鼠标移出去除样式(active)实现方法
2018/08/22 Javascript
nvm、nrm、npm 安装和使用详解(小结)
2019/01/17 Javascript
vue 导航菜单刷新状态不消失,显示对应的路由界面操作
2020/08/06 Javascript
解决vue项目运行提示Warnings while compiling.警告的问题
2020/09/18 Javascript
Python中声明只包含一个元素的元组数据方法
2014/08/25 Python
Django框架下在URLconf中指定视图缓存的方法
2015/07/23 Python
对于Python中RawString的理解介绍
2016/07/07 Python
使用Python对Excel进行读写操作
2017/03/30 Python
TensorFlow实现简单卷积神经网络
2018/05/24 Python
Python3使用Matplotlib 绘制精美的数学函数图形
2019/04/11 Python
selenium2.0中常用的python函数汇总
2019/08/05 Python
HTML5 通过Vedio标签实现视频循环播放的示例代码
2020/08/05 HTML / CSS
英国知名的护肤彩妆与时尚配饰大型综合零售电商:Unineed
2016/11/21 全球购物
东芝官网商城:还原日式美学,打造美好生活
2018/12/27 全球购物
NFL加拿大官方网上商店:NHLShop.ca
2019/03/12 全球购物
美国在线购买内衣网站:HerRoom
2020/02/22 全球购物
应用化学专业职业生涯规划书
2013/12/31 职场文书
幼儿教师考核制度
2014/01/25 职场文书
临床医学专业求职信
2014/08/08 职场文书
培训班通知
2015/04/25 职场文书
《确定位置》教学反思
2016/02/18 职场文书
Django+Nginx+uWSGI 定时任务的实现方法
2022/01/22 Python
CSS的calc函数用法小结
2022/06/25 HTML / CSS