Pytorch 神经网络—自定义数据集上实现教程


Posted in Python onJanuary 07, 2020

第一步、导入需要的包

import os
import scipy.io as sio
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable
batchSize = 128 # batchsize的大小
niter = 10   # epoch的最大值

第二步、构建神经网络

Pytorch 神经网络—自定义数据集上实现教程

设神经网络为如上图所示,输入层4个神经元,两层隐含层各4个神经元,输出层一个神经。每一层网络所做的都是线性变换,即y=W×X+b;代码实现如下:

class Neuralnetwork(nn.Module):
  def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
    super(Neuralnetwork, self).__init__()
    self.layer1 = nn.Linear(in_dim, n_hidden_1)
    self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
    self.layer3 = nn.Linear(n_hidden_2, out_dim)
 
  def forward(self, x):
    x = x.view(x.size(0), -1)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    return x
 
model = Neuralnetwork(1*3, 4, 4, 1)
 
print(model) # net architecture
Neuralnetwork(
 (layer1): Linear(in_features=3, out_features=4, bias=True)
 (layer2): Linear(in_features=4, out_features=4, bias=True)
 (layer3): Linear(in_features=4, out_features=1, bias=True)
)

​​ 第三步、读取数据

自定义的数据为demo_SBPFea.mat,是MATLAB保存的数据格式,其存储的内容如下:包括fea(1000*3)和sbp(1000*1)两个数组;fea为特征向量,行为样本数,列为特征宽度;sbp为标签

Pytorch 神经网络—自定义数据集上实现教程

class SBPEstimateDataset(Dataset):
 
  def __init__(self, ext='demo'):
  
    data = sio.loadmat(ext+'_SBPFea.mat')
    self.fea = data['fea']
    self.sbp = data['sbp']
    
  def __len__(self):
    
    return len(self.sbp)
 
  def __getitem__(self, idx):
 
    fea = self.fea[idx]
    sbp = self.sbp[idx]
    """Convert ndarrays to Tensors."""
    return {'fea': torch.from_numpy(fea).float(),
        'sbp': torch.from_numpy(sbp).float()
        }
    
train_dataset = SBPEstimateDataset(ext='demo')
train_loader = DataLoader(train_dataset, batch_size=batchSize, # 分批次训练
             shuffle=True, num_workers=int(8))

整个数据样本为1000,以batchSize = 128划分,分为8份,前7份为104个样本,第8份则为104个样本。在网络训练过程中,是一份数据一份数据进行训练的

第四步、模型训练

# 优化器,Adam 
optimizer = optim.Adam(list(model.parameters()), lr=0.0001, betas=(0.9, 0.999),weight_decay=0.004) 
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997) 
criterion = nn.MSELoss() # loss function 
 
if torch.cuda.is_available(): # 有GPU,则用GPU计算
   model.cuda() 
   criterion.cuda() 
 
for epoch in range(niter): 
   losses = [] 
   ERROR_Train = [] 
   model.train() 
   for i, data in enumerate(train_loader, 0): 
     model.zero_grad()# 首先提取清零 
     real_cpu, label_cpu = data['fea'], data['sbp'] 
 
     if torch.cuda.is_available():# CUDA可用情况下,将Tensor 在GPU上运行 
       real_cpu = real_cpu.cuda() 
       label_cpu = label_cpu.cuda() 
 
 
       input=real_cpu 
       label=label_cpu 
 
       inputv = Variable(input) 
       labelv = Variable(label) 
 
       output = model(inputv) 
       err = criterion(output, labelv) 
       err.backward() 
       optimizer.step() 
 
       losses.append(err.data[0]) 
 
       error = output.data-label+ 1e-12 
       ERROR_Train.extend(error) 
 
   MAE = np.average(np.abs(np.array(ERROR_Train))) 
   ME = np.average(np.array(ERROR_Train)) 
   STD = np.std(np.array(ERROR_Train)) 
 
   print('[%d/%d] Loss: %.4f MAE: %.4f Mean Error: %.4f STD: %.4f' % ( 
   epoch, niter, np.average(losses), MAE, ME, STD))
​​
[0/10] Loss: 18384.6699 MAE: 135.3871 Mean Error: -135.3871 STD: 7.5580
[1/10] Loss: 17063.0215 MAE: 130.4145 Mean Error: -130.4145 STD: 7.8918
[2/10] Loss: 13689.1934 MAE: 116.6625 Mean Error: -116.6625 STD: 9.7946
[3/10] Loss: 8192.9053 MAE: 89.6611 Mean Error: -89.6611 STD: 12.9911
[4/10] Loss: 2979.1340 MAE: 52.5410 Mean Error: -52.5279 STD: 15.0930
[5/10] Loss: 599.7094 MAE: 22.2735 Mean Error: -19.9979 STD: 14.2069
[6/10] Loss: 207.2831 MAE: 11.2394 Mean Error: -4.8821 STD: 13.5528
[7/10] Loss: 189.8173 MAE: 9.8020 Mean Error: -1.2357 STD: 13.7095
[8/10] Loss: 188.3376 MAE: 9.6512 Mean Error: -0.6498 STD: 13.7075
[9/10] Loss: 186.8393 MAE: 9.6946 Mean Error: -1.0850 STD: 13.6332​

以上这篇Pytorch 神经网络—自定义数据集上实现教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 的 Socket 编程
Mar 24 Python
Python中的字符串替换操作示例
Jun 27 Python
Pandas DataFrame 取一行数据会得到Series的方法
Nov 10 Python
python处理两种分隔符的数据集方法
Dec 12 Python
python基于TCP实现的文件下载器功能案例
Dec 10 Python
Python unittest工作原理和使用过程解析
Feb 24 Python
Python 序列化和反序列化库 MarshMallow 的用法实例代码
Feb 25 Python
python3实现往mysql中插入datetime类型的数据
Mar 02 Python
Python 面向对象部分知识点小结
Mar 09 Python
python中get和post有什么区别
Jun 19 Python
python基于socket模拟实现ssh远程执行命令
Dec 05 Python
Python快速实现一键抠图功能的全过程
Jun 29 Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
You might like
php FPDF类库应用实现代码
2009/03/20 PHP
thinkPHP5框架分页样式类完整示例
2018/09/01 PHP
PDO::exec讲解
2019/01/28 PHP
(currentStyle)javascript为何有时用style得不到已设定的CSS的属性
2007/08/15 Javascript
JavaScript中的console.time()函数详细介绍
2014/12/29 Javascript
程序员必知35个jQuery 代码片段
2015/11/05 Javascript
JavaScript实现的微信二维码图片生成器的示例
2016/10/26 Javascript
javascript读取文本节点方法小结
2016/12/15 Javascript
jQuery动态产生select option下拉列表
2017/03/15 Javascript
javascript 数据存储的常用函数总结
2017/06/01 Javascript
简述jQuery Easyui一些用法
2017/08/01 jQuery
Bootstrap Paginator+PageHelper实现分页效果
2018/12/29 Javascript
JavaScript查看代码运行效率console.time()与console.timeEnd()用法
2019/01/18 Javascript
小程序如何构建骨架屏
2019/05/29 Javascript
tweenjs缓动算法的使用实例分析
2019/08/26 Javascript
python分析网页上所有超链接的方法
2015/05/08 Python
详解Python各大聊天系统的屏蔽脏话功能原理
2016/12/01 Python
python 寻找list中最大元素对应的索引方法
2018/06/28 Python
Python使用Selenium模块实现模拟浏览器抓取淘宝商品美食信息功能示例
2018/07/18 Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
2019/08/18 Python
Python reduce函数作用及实例解析
2020/05/08 Python
css3新增颜色表示方式分享
2014/04/15 HTML / CSS
解决html5中video标签无法播放mp4问题的办法
2017/05/07 HTML / CSS
经济实惠的豪华背包和行李袋:Packs Project
2018/10/17 全球购物
几个常见的消息中间件(MOM)
2014/01/08 面试题
如果重写了对象的equals()方法,需要考虑什么
2014/11/02 面试题
化学教师自荐信范文
2013/12/28 职场文书
预备党员党课思想汇报
2014/01/13 职场文书
服务标兵事迹材料
2014/05/04 职场文书
超市促销活动总结
2014/07/01 职场文书
社区个人对照检查材料(群众路线)
2014/09/26 职场文书
老公保证书
2015/01/17 职场文书
沈阳故宫导游词
2015/01/31 职场文书
独生子女证明范本
2015/06/19 职场文书
昆虫记读书笔记
2015/06/26 职场文书
HTML静态页面获取url参数和UserAgent的实现
2022/08/05 HTML / CSS