Pytorch 神经网络—自定义数据集上实现教程


Posted in Python onJanuary 07, 2020

第一步、导入需要的包

import os
import scipy.io as sio
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable
batchSize = 128 # batchsize的大小
niter = 10   # epoch的最大值

第二步、构建神经网络

Pytorch 神经网络—自定义数据集上实现教程

设神经网络为如上图所示,输入层4个神经元,两层隐含层各4个神经元,输出层一个神经。每一层网络所做的都是线性变换,即y=W×X+b;代码实现如下:

class Neuralnetwork(nn.Module):
  def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
    super(Neuralnetwork, self).__init__()
    self.layer1 = nn.Linear(in_dim, n_hidden_1)
    self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
    self.layer3 = nn.Linear(n_hidden_2, out_dim)
 
  def forward(self, x):
    x = x.view(x.size(0), -1)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    return x
 
model = Neuralnetwork(1*3, 4, 4, 1)
 
print(model) # net architecture
Neuralnetwork(
 (layer1): Linear(in_features=3, out_features=4, bias=True)
 (layer2): Linear(in_features=4, out_features=4, bias=True)
 (layer3): Linear(in_features=4, out_features=1, bias=True)
)

​​ 第三步、读取数据

自定义的数据为demo_SBPFea.mat,是MATLAB保存的数据格式,其存储的内容如下:包括fea(1000*3)和sbp(1000*1)两个数组;fea为特征向量,行为样本数,列为特征宽度;sbp为标签

Pytorch 神经网络—自定义数据集上实现教程

class SBPEstimateDataset(Dataset):
 
  def __init__(self, ext='demo'):
  
    data = sio.loadmat(ext+'_SBPFea.mat')
    self.fea = data['fea']
    self.sbp = data['sbp']
    
  def __len__(self):
    
    return len(self.sbp)
 
  def __getitem__(self, idx):
 
    fea = self.fea[idx]
    sbp = self.sbp[idx]
    """Convert ndarrays to Tensors."""
    return {'fea': torch.from_numpy(fea).float(),
        'sbp': torch.from_numpy(sbp).float()
        }
    
train_dataset = SBPEstimateDataset(ext='demo')
train_loader = DataLoader(train_dataset, batch_size=batchSize, # 分批次训练
             shuffle=True, num_workers=int(8))

整个数据样本为1000,以batchSize = 128划分,分为8份,前7份为104个样本,第8份则为104个样本。在网络训练过程中,是一份数据一份数据进行训练的

第四步、模型训练

# 优化器,Adam 
optimizer = optim.Adam(list(model.parameters()), lr=0.0001, betas=(0.9, 0.999),weight_decay=0.004) 
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997) 
criterion = nn.MSELoss() # loss function 
 
if torch.cuda.is_available(): # 有GPU,则用GPU计算
   model.cuda() 
   criterion.cuda() 
 
for epoch in range(niter): 
   losses = [] 
   ERROR_Train = [] 
   model.train() 
   for i, data in enumerate(train_loader, 0): 
     model.zero_grad()# 首先提取清零 
     real_cpu, label_cpu = data['fea'], data['sbp'] 
 
     if torch.cuda.is_available():# CUDA可用情况下,将Tensor 在GPU上运行 
       real_cpu = real_cpu.cuda() 
       label_cpu = label_cpu.cuda() 
 
 
       input=real_cpu 
       label=label_cpu 
 
       inputv = Variable(input) 
       labelv = Variable(label) 
 
       output = model(inputv) 
       err = criterion(output, labelv) 
       err.backward() 
       optimizer.step() 
 
       losses.append(err.data[0]) 
 
       error = output.data-label+ 1e-12 
       ERROR_Train.extend(error) 
 
   MAE = np.average(np.abs(np.array(ERROR_Train))) 
   ME = np.average(np.array(ERROR_Train)) 
   STD = np.std(np.array(ERROR_Train)) 
 
   print('[%d/%d] Loss: %.4f MAE: %.4f Mean Error: %.4f STD: %.4f' % ( 
   epoch, niter, np.average(losses), MAE, ME, STD))
​​
[0/10] Loss: 18384.6699 MAE: 135.3871 Mean Error: -135.3871 STD: 7.5580
[1/10] Loss: 17063.0215 MAE: 130.4145 Mean Error: -130.4145 STD: 7.8918
[2/10] Loss: 13689.1934 MAE: 116.6625 Mean Error: -116.6625 STD: 9.7946
[3/10] Loss: 8192.9053 MAE: 89.6611 Mean Error: -89.6611 STD: 12.9911
[4/10] Loss: 2979.1340 MAE: 52.5410 Mean Error: -52.5279 STD: 15.0930
[5/10] Loss: 599.7094 MAE: 22.2735 Mean Error: -19.9979 STD: 14.2069
[6/10] Loss: 207.2831 MAE: 11.2394 Mean Error: -4.8821 STD: 13.5528
[7/10] Loss: 189.8173 MAE: 9.8020 Mean Error: -1.2357 STD: 13.7095
[8/10] Loss: 188.3376 MAE: 9.6512 Mean Error: -0.6498 STD: 13.7075
[9/10] Loss: 186.8393 MAE: 9.6946 Mean Error: -1.0850 STD: 13.6332​

以上这篇Pytorch 神经网络—自定义数据集上实现教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
写了个监控nginx进程的Python脚本
May 10 Python
在Windows8上的搭建Python和Django环境
Jul 03 Python
python模块之StringIO使用示例
Apr 08 Python
深入解析Python中的__builtins__内建对象
Jun 21 Python
Python 如何访问外围作用域中的变量
Sep 11 Python
python如何生成各种随机分布图
Aug 27 Python
Python3删除排序数组中重复项的方法分析
Jan 31 Python
python操作日志的封装方法(两种方法)
May 23 Python
Python使用sklearn实现的各种回归算法示例
Jul 04 Python
python 动态渲染 mysql 配置文件的示例
Nov 20 Python
Pytorch反向传播中的细节-计算梯度时的默认累加操作
Jun 05 Python
python字符串的一些常见实用操作
Apr 06 Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
You might like
怎样辨别一杯好咖啡
2021/03/03 新手入门
php empty,isset,is_null判断比较(差异与异同)
2010/10/19 PHP
phpstrom使用xdebug配置方法
2013/12/17 PHP
PHP将URL转换成短网址的算法分享
2016/09/13 PHP
CI框架封装的常用图像处理方法(缩略图,水印,旋转,上传等)
2016/11/22 PHP
PHP数组对象与Json转换操作实例分析
2019/10/22 PHP
thinkphp框架类库扩展操作示例
2019/11/26 PHP
基于JQuery的asp.net树实现代码
2010/11/30 Javascript
善用事件代理,警惕闭包的性能陷阱。
2011/01/20 Javascript
一个简单的网站访问JS计数器 刷新1次加1次访问
2012/09/20 Javascript
document节点对象的获取方式示例介绍
2013/12/24 Javascript
js实现的map方法示例代码
2014/01/13 Javascript
JavaScript编程的10个实用小技巧
2014/04/18 Javascript
cookie的secure属性详解
2015/04/08 Javascript
微信小程序 wxapp内容组件 progress详细介绍
2016/10/31 Javascript
Bootstrap 中data-[*] 属性的整理
2018/03/13 Javascript
三分钟教你用Node做一个微信哄女友(基友)神器(面向小白)
2019/06/21 Javascript
vue 实现Web端的定位功能 获取经纬度
2019/08/08 Javascript
jQuery 筛选器简单操作示例
2019/10/02 jQuery
Layui实现主窗口和Iframe层参数传递
2019/11/14 Javascript
Python中尝试多线程编程的一个简明例子
2015/04/07 Python
Python中list初始化方法示例
2016/09/18 Python
pandas groupby 分组取每组的前几行记录方法
2018/04/20 Python
celery4+django2定时任务的实现代码
2018/12/23 Python
python,Django实现的淘宝客登录功能示例
2019/06/12 Python
Django实现文件上传下载功能
2019/10/06 Python
Python模拟登录requests.Session应用详解
2020/11/17 Python
深入解析HTML5使用SVG图像时的viewBox属性用法
2015/09/02 HTML / CSS
写一个方法1000的阶乘
2012/11/21 面试题
计算机本科生自荐信
2013/10/15 职场文书
化学教育专业自荐信
2014/07/04 职场文书
介绍信模板
2015/01/31 职场文书
售后前台接待岗位职责
2015/04/03 职场文书
2016年主题党日活动总结
2016/04/05 职场文书
一条 SQL 语句执行过程
2022/03/17 MySQL
《吸血鬼幸存者》新内容发布 追加多个全新模式
2022/04/07 其他游戏