Pytorch 神经网络—自定义数据集上实现教程


Posted in Python onJanuary 07, 2020

第一步、导入需要的包

import os
import scipy.io as sio
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable
batchSize = 128 # batchsize的大小
niter = 10   # epoch的最大值

第二步、构建神经网络

Pytorch 神经网络—自定义数据集上实现教程

设神经网络为如上图所示,输入层4个神经元,两层隐含层各4个神经元,输出层一个神经。每一层网络所做的都是线性变换,即y=W×X+b;代码实现如下:

class Neuralnetwork(nn.Module):
  def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
    super(Neuralnetwork, self).__init__()
    self.layer1 = nn.Linear(in_dim, n_hidden_1)
    self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
    self.layer3 = nn.Linear(n_hidden_2, out_dim)
 
  def forward(self, x):
    x = x.view(x.size(0), -1)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    return x
 
model = Neuralnetwork(1*3, 4, 4, 1)
 
print(model) # net architecture
Neuralnetwork(
 (layer1): Linear(in_features=3, out_features=4, bias=True)
 (layer2): Linear(in_features=4, out_features=4, bias=True)
 (layer3): Linear(in_features=4, out_features=1, bias=True)
)

​​ 第三步、读取数据

自定义的数据为demo_SBPFea.mat,是MATLAB保存的数据格式,其存储的内容如下:包括fea(1000*3)和sbp(1000*1)两个数组;fea为特征向量,行为样本数,列为特征宽度;sbp为标签

Pytorch 神经网络—自定义数据集上实现教程

class SBPEstimateDataset(Dataset):
 
  def __init__(self, ext='demo'):
  
    data = sio.loadmat(ext+'_SBPFea.mat')
    self.fea = data['fea']
    self.sbp = data['sbp']
    
  def __len__(self):
    
    return len(self.sbp)
 
  def __getitem__(self, idx):
 
    fea = self.fea[idx]
    sbp = self.sbp[idx]
    """Convert ndarrays to Tensors."""
    return {'fea': torch.from_numpy(fea).float(),
        'sbp': torch.from_numpy(sbp).float()
        }
    
train_dataset = SBPEstimateDataset(ext='demo')
train_loader = DataLoader(train_dataset, batch_size=batchSize, # 分批次训练
             shuffle=True, num_workers=int(8))

整个数据样本为1000,以batchSize = 128划分,分为8份,前7份为104个样本,第8份则为104个样本。在网络训练过程中,是一份数据一份数据进行训练的

第四步、模型训练

# 优化器,Adam 
optimizer = optim.Adam(list(model.parameters()), lr=0.0001, betas=(0.9, 0.999),weight_decay=0.004) 
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997) 
criterion = nn.MSELoss() # loss function 
 
if torch.cuda.is_available(): # 有GPU,则用GPU计算
   model.cuda() 
   criterion.cuda() 
 
for epoch in range(niter): 
   losses = [] 
   ERROR_Train = [] 
   model.train() 
   for i, data in enumerate(train_loader, 0): 
     model.zero_grad()# 首先提取清零 
     real_cpu, label_cpu = data['fea'], data['sbp'] 
 
     if torch.cuda.is_available():# CUDA可用情况下,将Tensor 在GPU上运行 
       real_cpu = real_cpu.cuda() 
       label_cpu = label_cpu.cuda() 
 
 
       input=real_cpu 
       label=label_cpu 
 
       inputv = Variable(input) 
       labelv = Variable(label) 
 
       output = model(inputv) 
       err = criterion(output, labelv) 
       err.backward() 
       optimizer.step() 
 
       losses.append(err.data[0]) 
 
       error = output.data-label+ 1e-12 
       ERROR_Train.extend(error) 
 
   MAE = np.average(np.abs(np.array(ERROR_Train))) 
   ME = np.average(np.array(ERROR_Train)) 
   STD = np.std(np.array(ERROR_Train)) 
 
   print('[%d/%d] Loss: %.4f MAE: %.4f Mean Error: %.4f STD: %.4f' % ( 
   epoch, niter, np.average(losses), MAE, ME, STD))
​​
[0/10] Loss: 18384.6699 MAE: 135.3871 Mean Error: -135.3871 STD: 7.5580
[1/10] Loss: 17063.0215 MAE: 130.4145 Mean Error: -130.4145 STD: 7.8918
[2/10] Loss: 13689.1934 MAE: 116.6625 Mean Error: -116.6625 STD: 9.7946
[3/10] Loss: 8192.9053 MAE: 89.6611 Mean Error: -89.6611 STD: 12.9911
[4/10] Loss: 2979.1340 MAE: 52.5410 Mean Error: -52.5279 STD: 15.0930
[5/10] Loss: 599.7094 MAE: 22.2735 Mean Error: -19.9979 STD: 14.2069
[6/10] Loss: 207.2831 MAE: 11.2394 Mean Error: -4.8821 STD: 13.5528
[7/10] Loss: 189.8173 MAE: 9.8020 Mean Error: -1.2357 STD: 13.7095
[8/10] Loss: 188.3376 MAE: 9.6512 Mean Error: -0.6498 STD: 13.7075
[9/10] Loss: 186.8393 MAE: 9.6946 Mean Error: -1.0850 STD: 13.6332​

以上这篇Pytorch 神经网络—自定义数据集上实现教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python下的Mysql模块MySQLdb安装详解
Apr 09 Python
Python中字典的基本知识初步介绍
May 21 Python
Django基于ORM操作数据库的方法详解
Mar 27 Python
Python使用re模块实现信息筛选的方法
Apr 29 Python
python斐波那契数列的计算方法
Sep 27 Python
Python字符串的常见操作实例小结
Apr 08 Python
如何使用django的MTV开发模式返回一个网页
Jul 22 Python
Django实现简单网页弹出警告代码
Nov 15 Python
Python基于Dlib的人脸识别系统的实现
Feb 26 Python
使用python求斐波那契数列中第n个数的值示例代码
Jul 26 Python
Python3 用什么IDE开发工具比较好
Nov 28 Python
pandas抽取行列数据的几种方法
Dec 13 Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
You might like
深入理解PHP原理之异常机制
2010/08/21 PHP
php中一个完整表单处理实现代码
2011/11/10 PHP
PHP中模拟处理HTTP PUT请求的例子
2014/07/22 PHP
PHP Hash算法:Times33算法代码实例
2015/05/13 PHP
利用php实现一周之内自动登录存储机制(cookie、session、localStorage)
2016/10/31 PHP
利用javascript移动div层-javascript 拖动层
2009/03/22 Javascript
JS判定是否原生方法
2013/07/22 Javascript
原生JS实现加入收藏夹的代码
2013/10/24 Javascript
js使用正则实现ReplaceAll全部替换的方法
2014/07/18 Javascript
jQuery实现默认是闭合的FAQ展开效果菜单
2015/09/14 Javascript
js获取及修改网页背景色和字体色的方法
2015/12/29 Javascript
深入学习jQuery Validate表单验证(二)
2016/01/18 Javascript
基于JavaScript实现文字超出部分隐藏
2016/02/29 Javascript
Vue.js每天必学之组件与组件间的通信
2016/09/08 Javascript
什么是JavaScript注入攻击?
2016/09/14 Javascript
ES6新特征数字、数组、字符串
2016/10/01 Javascript
JavaScrpt中如何使用 cookie 设置查看与删除功能
2017/07/09 Javascript
jQuery 开发之EasyUI 添加数据的实例
2017/09/26 jQuery
浅谈webpack4 图片处理汇总
2018/09/12 Javascript
Vue组件内部实现一个双向数据绑定的实例代码
2019/04/04 Javascript
如何用vue-cli3脚手架搭建一个基于ts的基础脚手架的方法
2019/12/12 Javascript
python中查找excel某一列的重复数据 剔除之后打印
2013/02/10 Python
详解在Python程序中使用Cookie的教程
2015/04/30 Python
详解Python编程中time模块的使用
2015/11/20 Python
python实现多线程的方式及多条命令并发执行
2016/06/07 Python
Python基于socket实现简单的即时通讯功能示例
2018/01/16 Python
python机器学习理论与实战(六)支持向量机
2018/01/19 Python
PyTorch的深度学习入门之PyTorch安装和配置
2019/06/27 Python
python 偷懒技巧——使用 keyboard 录制键盘事件
2020/09/21 Python
canvas版人体时钟的实现示例
2021/01/29 HTML / CSS
工业设计专业推荐信
2013/10/29 职场文书
冰淇淋店创业计划书范文
2013/12/27 职场文书
党员作风建设自查报告
2014/10/23 职场文书
大二学年个人总结
2015/03/03 职场文书
新郎接新娘保证书
2015/05/08 职场文书
2015年六年级班主任工作总结
2015/10/15 职场文书