Pytorch 神经网络—自定义数据集上实现教程


Posted in Python onJanuary 07, 2020

第一步、导入需要的包

import os
import scipy.io as sio
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable
batchSize = 128 # batchsize的大小
niter = 10   # epoch的最大值

第二步、构建神经网络

Pytorch 神经网络—自定义数据集上实现教程

设神经网络为如上图所示,输入层4个神经元,两层隐含层各4个神经元,输出层一个神经。每一层网络所做的都是线性变换,即y=W×X+b;代码实现如下:

class Neuralnetwork(nn.Module):
  def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
    super(Neuralnetwork, self).__init__()
    self.layer1 = nn.Linear(in_dim, n_hidden_1)
    self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
    self.layer3 = nn.Linear(n_hidden_2, out_dim)
 
  def forward(self, x):
    x = x.view(x.size(0), -1)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    return x
 
model = Neuralnetwork(1*3, 4, 4, 1)
 
print(model) # net architecture
Neuralnetwork(
 (layer1): Linear(in_features=3, out_features=4, bias=True)
 (layer2): Linear(in_features=4, out_features=4, bias=True)
 (layer3): Linear(in_features=4, out_features=1, bias=True)
)

​​ 第三步、读取数据

自定义的数据为demo_SBPFea.mat,是MATLAB保存的数据格式,其存储的内容如下:包括fea(1000*3)和sbp(1000*1)两个数组;fea为特征向量,行为样本数,列为特征宽度;sbp为标签

Pytorch 神经网络—自定义数据集上实现教程

class SBPEstimateDataset(Dataset):
 
  def __init__(self, ext='demo'):
  
    data = sio.loadmat(ext+'_SBPFea.mat')
    self.fea = data['fea']
    self.sbp = data['sbp']
    
  def __len__(self):
    
    return len(self.sbp)
 
  def __getitem__(self, idx):
 
    fea = self.fea[idx]
    sbp = self.sbp[idx]
    """Convert ndarrays to Tensors."""
    return {'fea': torch.from_numpy(fea).float(),
        'sbp': torch.from_numpy(sbp).float()
        }
    
train_dataset = SBPEstimateDataset(ext='demo')
train_loader = DataLoader(train_dataset, batch_size=batchSize, # 分批次训练
             shuffle=True, num_workers=int(8))

整个数据样本为1000,以batchSize = 128划分,分为8份,前7份为104个样本,第8份则为104个样本。在网络训练过程中,是一份数据一份数据进行训练的

第四步、模型训练

# 优化器,Adam 
optimizer = optim.Adam(list(model.parameters()), lr=0.0001, betas=(0.9, 0.999),weight_decay=0.004) 
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997) 
criterion = nn.MSELoss() # loss function 
 
if torch.cuda.is_available(): # 有GPU,则用GPU计算
   model.cuda() 
   criterion.cuda() 
 
for epoch in range(niter): 
   losses = [] 
   ERROR_Train = [] 
   model.train() 
   for i, data in enumerate(train_loader, 0): 
     model.zero_grad()# 首先提取清零 
     real_cpu, label_cpu = data['fea'], data['sbp'] 
 
     if torch.cuda.is_available():# CUDA可用情况下,将Tensor 在GPU上运行 
       real_cpu = real_cpu.cuda() 
       label_cpu = label_cpu.cuda() 
 
 
       input=real_cpu 
       label=label_cpu 
 
       inputv = Variable(input) 
       labelv = Variable(label) 
 
       output = model(inputv) 
       err = criterion(output, labelv) 
       err.backward() 
       optimizer.step() 
 
       losses.append(err.data[0]) 
 
       error = output.data-label+ 1e-12 
       ERROR_Train.extend(error) 
 
   MAE = np.average(np.abs(np.array(ERROR_Train))) 
   ME = np.average(np.array(ERROR_Train)) 
   STD = np.std(np.array(ERROR_Train)) 
 
   print('[%d/%d] Loss: %.4f MAE: %.4f Mean Error: %.4f STD: %.4f' % ( 
   epoch, niter, np.average(losses), MAE, ME, STD))
​​
[0/10] Loss: 18384.6699 MAE: 135.3871 Mean Error: -135.3871 STD: 7.5580
[1/10] Loss: 17063.0215 MAE: 130.4145 Mean Error: -130.4145 STD: 7.8918
[2/10] Loss: 13689.1934 MAE: 116.6625 Mean Error: -116.6625 STD: 9.7946
[3/10] Loss: 8192.9053 MAE: 89.6611 Mean Error: -89.6611 STD: 12.9911
[4/10] Loss: 2979.1340 MAE: 52.5410 Mean Error: -52.5279 STD: 15.0930
[5/10] Loss: 599.7094 MAE: 22.2735 Mean Error: -19.9979 STD: 14.2069
[6/10] Loss: 207.2831 MAE: 11.2394 Mean Error: -4.8821 STD: 13.5528
[7/10] Loss: 189.8173 MAE: 9.8020 Mean Error: -1.2357 STD: 13.7095
[8/10] Loss: 188.3376 MAE: 9.6512 Mean Error: -0.6498 STD: 13.7075
[9/10] Loss: 186.8393 MAE: 9.6946 Mean Error: -1.0850 STD: 13.6332​

以上这篇Pytorch 神经网络—自定义数据集上实现教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 用Redis简单实现分布式爬虫的方法
Nov 23 Python
Python语言描述KNN算法与Kd树
Dec 13 Python
python 读取视频,处理后,实时计算帧数fps的方法
Jul 10 Python
python实现学员管理系统
Feb 26 Python
Python发送邮件的实例代码讲解
Oct 16 Python
在Python中利用pickle保存变量的实例
Dec 30 Python
python集合删除多种方法详解
Feb 10 Python
使用keras和tensorflow保存为可部署的pb格式
May 25 Python
浅谈PyTorch中in-place operation的含义
Jun 27 Python
基于CentOS搭建Python Django环境过程解析
Aug 24 Python
pycharm不以pytest方式运行,想要切换回普通模式运行的操作
Sep 01 Python
Jupyter Notebook内使用argparse报错的解决方案
Jun 03 Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
You might like
PHP新手上路(十二)
2006/10/09 PHP
CI框架实现cookie登陆的方法详解
2016/05/18 PHP
JavaScript 全面解析各种浏览器网页中的JS 执行顺序
2009/02/17 Javascript
JavaScript iframe的相互操作浅析
2009/10/14 Javascript
原生js实现shift/ctrl/alt按键的获取
2013/04/08 Javascript
jQuery实现灰蓝风格标准二级下拉菜单效果代码
2015/08/31 Javascript
jQuery实现伪分页的方法分享
2016/02/17 Javascript
功能强大的Bootstrap效果展示(二)
2016/08/03 Javascript
引用jquery框架后出错的解决方法
2016/08/09 Javascript
JavaScript中的对象和原型(一)
2016/08/12 Javascript
JS实现的Unicode编码转换操作示例
2017/04/28 Javascript
解决webpack打包速度慢的解决办法汇总
2017/07/06 Javascript
详解使用Typescript开发node.js项目(简单的环境配置)
2017/10/09 Javascript
对vue里函数的调用顺序介绍
2018/03/17 Javascript
vue服务端渲染缓存应用详解
2018/09/12 Javascript
微信小程序修改数组长度的问题的解决
2019/12/17 Javascript
在Vue.js中使用TypeScript的方法
2020/03/19 Javascript
vant组件中 dialog的确认按钮的回调事件操作
2020/11/04 Javascript
Python获取脚本所在目录的正确方法
2014/04/15 Python
基于Python实现的扫雷游戏实例代码
2014/08/01 Python
浅析python协程相关概念
2018/01/20 Python
利用python numpy+matplotlib绘制股票k线图的方法
2019/06/26 Python
matlab 计算灰度图像的一阶矩,二阶矩,三阶矩实例
2020/04/22 Python
互斥锁解决 Python 中多线程共享全局变量的问题(推荐)
2020/09/28 Python
英国最大的婴儿监视器网上商店:Baby Monitors Direct
2018/04/24 全球购物
美国正宗设计师眼镜在线零售商:EYEZZ
2019/03/23 全球购物
阿迪达斯印尼官方网站:adidas印尼
2020/02/10 全球购物
自荐书模板
2013/12/15 职场文书
语文高效课堂实施方案
2014/05/03 职场文书
白血病募捐倡议书
2014/05/14 职场文书
普通话演讲稿
2014/09/03 职场文书
无房产证房屋转让协议书合同样本
2014/10/18 职场文书
部队2014年终工作总结
2014/11/27 职场文书
创业计划书之家教中心
2019/09/25 职场文书
vue-router中hash模式与history模式的区别
2021/06/23 Vue.js
Redis 异步机制
2022/05/15 Redis