Pytorch 神经网络—自定义数据集上实现教程


Posted in Python onJanuary 07, 2020

第一步、导入需要的包

import os
import scipy.io as sio
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable
batchSize = 128 # batchsize的大小
niter = 10   # epoch的最大值

第二步、构建神经网络

Pytorch 神经网络—自定义数据集上实现教程

设神经网络为如上图所示,输入层4个神经元,两层隐含层各4个神经元,输出层一个神经。每一层网络所做的都是线性变换,即y=W×X+b;代码实现如下:

class Neuralnetwork(nn.Module):
  def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
    super(Neuralnetwork, self).__init__()
    self.layer1 = nn.Linear(in_dim, n_hidden_1)
    self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
    self.layer3 = nn.Linear(n_hidden_2, out_dim)
 
  def forward(self, x):
    x = x.view(x.size(0), -1)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    return x
 
model = Neuralnetwork(1*3, 4, 4, 1)
 
print(model) # net architecture
Neuralnetwork(
 (layer1): Linear(in_features=3, out_features=4, bias=True)
 (layer2): Linear(in_features=4, out_features=4, bias=True)
 (layer3): Linear(in_features=4, out_features=1, bias=True)
)

​​ 第三步、读取数据

自定义的数据为demo_SBPFea.mat,是MATLAB保存的数据格式,其存储的内容如下:包括fea(1000*3)和sbp(1000*1)两个数组;fea为特征向量,行为样本数,列为特征宽度;sbp为标签

Pytorch 神经网络—自定义数据集上实现教程

class SBPEstimateDataset(Dataset):
 
  def __init__(self, ext='demo'):
  
    data = sio.loadmat(ext+'_SBPFea.mat')
    self.fea = data['fea']
    self.sbp = data['sbp']
    
  def __len__(self):
    
    return len(self.sbp)
 
  def __getitem__(self, idx):
 
    fea = self.fea[idx]
    sbp = self.sbp[idx]
    """Convert ndarrays to Tensors."""
    return {'fea': torch.from_numpy(fea).float(),
        'sbp': torch.from_numpy(sbp).float()
        }
    
train_dataset = SBPEstimateDataset(ext='demo')
train_loader = DataLoader(train_dataset, batch_size=batchSize, # 分批次训练
             shuffle=True, num_workers=int(8))

整个数据样本为1000,以batchSize = 128划分,分为8份,前7份为104个样本,第8份则为104个样本。在网络训练过程中,是一份数据一份数据进行训练的

第四步、模型训练

# 优化器,Adam 
optimizer = optim.Adam(list(model.parameters()), lr=0.0001, betas=(0.9, 0.999),weight_decay=0.004) 
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997) 
criterion = nn.MSELoss() # loss function 
 
if torch.cuda.is_available(): # 有GPU,则用GPU计算
   model.cuda() 
   criterion.cuda() 
 
for epoch in range(niter): 
   losses = [] 
   ERROR_Train = [] 
   model.train() 
   for i, data in enumerate(train_loader, 0): 
     model.zero_grad()# 首先提取清零 
     real_cpu, label_cpu = data['fea'], data['sbp'] 
 
     if torch.cuda.is_available():# CUDA可用情况下,将Tensor 在GPU上运行 
       real_cpu = real_cpu.cuda() 
       label_cpu = label_cpu.cuda() 
 
 
       input=real_cpu 
       label=label_cpu 
 
       inputv = Variable(input) 
       labelv = Variable(label) 
 
       output = model(inputv) 
       err = criterion(output, labelv) 
       err.backward() 
       optimizer.step() 
 
       losses.append(err.data[0]) 
 
       error = output.data-label+ 1e-12 
       ERROR_Train.extend(error) 
 
   MAE = np.average(np.abs(np.array(ERROR_Train))) 
   ME = np.average(np.array(ERROR_Train)) 
   STD = np.std(np.array(ERROR_Train)) 
 
   print('[%d/%d] Loss: %.4f MAE: %.4f Mean Error: %.4f STD: %.4f' % ( 
   epoch, niter, np.average(losses), MAE, ME, STD))
​​
[0/10] Loss: 18384.6699 MAE: 135.3871 Mean Error: -135.3871 STD: 7.5580
[1/10] Loss: 17063.0215 MAE: 130.4145 Mean Error: -130.4145 STD: 7.8918
[2/10] Loss: 13689.1934 MAE: 116.6625 Mean Error: -116.6625 STD: 9.7946
[3/10] Loss: 8192.9053 MAE: 89.6611 Mean Error: -89.6611 STD: 12.9911
[4/10] Loss: 2979.1340 MAE: 52.5410 Mean Error: -52.5279 STD: 15.0930
[5/10] Loss: 599.7094 MAE: 22.2735 Mean Error: -19.9979 STD: 14.2069
[6/10] Loss: 207.2831 MAE: 11.2394 Mean Error: -4.8821 STD: 13.5528
[7/10] Loss: 189.8173 MAE: 9.8020 Mean Error: -1.2357 STD: 13.7095
[8/10] Loss: 188.3376 MAE: 9.6512 Mean Error: -0.6498 STD: 13.7075
[9/10] Loss: 186.8393 MAE: 9.6946 Mean Error: -1.0850 STD: 13.6332​

以上这篇Pytorch 神经网络—自定义数据集上实现教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pymssql数据库操作MSSQL2005实例分析
May 25 Python
python实现各进制转换的总结大全
Jun 18 Python
利用python批量给云主机配置安全组的方法教程
Jun 21 Python
python控制nao机器人身体动作实例详解
Apr 29 Python
基于Numpy.convolve使用Python实现滑动平均滤波的思路详解
May 16 Python
PyQt5 多窗口连接实例
Jun 19 Python
Python中用pyinstaller打包时的图标问题及解决方法
Feb 17 Python
深入理解Tensorflow中的masking和padding
Feb 24 Python
Python 序列化和反序列化库 MarshMallow 的用法实例代码
Feb 25 Python
Python获取excel内容及相关操作代码实例
Aug 10 Python
Python reversed反转序列并生成可迭代对象
Oct 22 Python
Django显示可视化图表的实践
May 10 Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
You might like
ThinkPHP中公共函数路径和配置项路径的映射分析
2014/11/22 PHP
PHP实现生成唯一会员卡号
2015/08/24 PHP
如何使用PHP给图片加水印
2016/10/12 PHP
php基于自定义函数记录log日志方法
2017/07/21 PHP
javascript 面向对象编程基础:封装
2009/08/21 Javascript
javascript中的undefined 与 null 的区别  补充篇
2010/03/17 Javascript
jquery实现控制表格行高亮实例
2013/06/05 Javascript
Jquery 1.9.1源码分析系列(十二)之筛选操作
2015/12/02 Javascript
深入浅析JSON.parse()、JSON.stringify()和eval()的作用详解
2016/04/03 Javascript
prototype.js常用函数详解
2016/06/18 Javascript
深入理解jQuery layui分页控件的使用
2016/08/17 Javascript
详解jQuery中基本的动画方法
2016/12/14 Javascript
jQuery 实时保存页面动态添加的数据的示例
2017/08/14 jQuery
详解使用angular的HttpClient搭配rxjs
2017/09/01 Javascript
jQuery EasyUI 折叠面板accordion的使用实例(分享)
2017/12/25 jQuery
vue自定义tap指令及tap事件的实现
2018/09/18 Javascript
angular2 组件之间通过service互相传递的实例
2018/09/30 Javascript
[43:24]VG vs Serenity 2018国际邀请赛小组赛BO2 第二场 8.17
2018/08/20 DOTA
paramiko模块安装和使用(远程登录服务器)
2014/01/27 Python
Python微信公众号开发平台
2018/01/25 Python
python消费kafka数据批量插入到es的方法
2018/12/27 Python
python rsync服务器之间文件夹同步脚本
2019/08/29 Python
Python3 socket即时通讯脚本实现代码实例(threading多线程)
2020/06/01 Python
用python实现一个简单的验证码
2020/12/09 Python
美国大码时尚女装购物网站:ELOQUII
2017/12/28 全球购物
工作表现评语
2014/01/19 职场文书
学子宴答谢词
2014/01/25 职场文书
拓展训练激励口号
2014/06/17 职场文书
国际残疾人日广播稿范文
2014/10/09 职场文书
2014年纠风工作总结
2014/12/08 职场文书
交通事故和解协议书
2015/01/27 职场文书
2016年暑假学生家长评语
2015/12/01 职场文书
mysql数据库入门第一步之创建表
2021/05/14 MySQL
压缩Redis里的字符串大对象操作
2021/06/23 Redis
图神经网络GNN算法
2022/05/11 Python
css3手动实现pc端横向滚动
2022/06/21 HTML / CSS