Pytorch 神经网络—自定义数据集上实现教程


Posted in Python onJanuary 07, 2020

第一步、导入需要的包

import os
import scipy.io as sio
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable
batchSize = 128 # batchsize的大小
niter = 10   # epoch的最大值

第二步、构建神经网络

Pytorch 神经网络—自定义数据集上实现教程

设神经网络为如上图所示,输入层4个神经元,两层隐含层各4个神经元,输出层一个神经。每一层网络所做的都是线性变换,即y=W×X+b;代码实现如下:

class Neuralnetwork(nn.Module):
  def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
    super(Neuralnetwork, self).__init__()
    self.layer1 = nn.Linear(in_dim, n_hidden_1)
    self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
    self.layer3 = nn.Linear(n_hidden_2, out_dim)
 
  def forward(self, x):
    x = x.view(x.size(0), -1)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    return x
 
model = Neuralnetwork(1*3, 4, 4, 1)
 
print(model) # net architecture
Neuralnetwork(
 (layer1): Linear(in_features=3, out_features=4, bias=True)
 (layer2): Linear(in_features=4, out_features=4, bias=True)
 (layer3): Linear(in_features=4, out_features=1, bias=True)
)

​​ 第三步、读取数据

自定义的数据为demo_SBPFea.mat,是MATLAB保存的数据格式,其存储的内容如下:包括fea(1000*3)和sbp(1000*1)两个数组;fea为特征向量,行为样本数,列为特征宽度;sbp为标签

Pytorch 神经网络—自定义数据集上实现教程

class SBPEstimateDataset(Dataset):
 
  def __init__(self, ext='demo'):
  
    data = sio.loadmat(ext+'_SBPFea.mat')
    self.fea = data['fea']
    self.sbp = data['sbp']
    
  def __len__(self):
    
    return len(self.sbp)
 
  def __getitem__(self, idx):
 
    fea = self.fea[idx]
    sbp = self.sbp[idx]
    """Convert ndarrays to Tensors."""
    return {'fea': torch.from_numpy(fea).float(),
        'sbp': torch.from_numpy(sbp).float()
        }
    
train_dataset = SBPEstimateDataset(ext='demo')
train_loader = DataLoader(train_dataset, batch_size=batchSize, # 分批次训练
             shuffle=True, num_workers=int(8))

整个数据样本为1000,以batchSize = 128划分,分为8份,前7份为104个样本,第8份则为104个样本。在网络训练过程中,是一份数据一份数据进行训练的

第四步、模型训练

# 优化器,Adam 
optimizer = optim.Adam(list(model.parameters()), lr=0.0001, betas=(0.9, 0.999),weight_decay=0.004) 
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997) 
criterion = nn.MSELoss() # loss function 
 
if torch.cuda.is_available(): # 有GPU,则用GPU计算
   model.cuda() 
   criterion.cuda() 
 
for epoch in range(niter): 
   losses = [] 
   ERROR_Train = [] 
   model.train() 
   for i, data in enumerate(train_loader, 0): 
     model.zero_grad()# 首先提取清零 
     real_cpu, label_cpu = data['fea'], data['sbp'] 
 
     if torch.cuda.is_available():# CUDA可用情况下,将Tensor 在GPU上运行 
       real_cpu = real_cpu.cuda() 
       label_cpu = label_cpu.cuda() 
 
 
       input=real_cpu 
       label=label_cpu 
 
       inputv = Variable(input) 
       labelv = Variable(label) 
 
       output = model(inputv) 
       err = criterion(output, labelv) 
       err.backward() 
       optimizer.step() 
 
       losses.append(err.data[0]) 
 
       error = output.data-label+ 1e-12 
       ERROR_Train.extend(error) 
 
   MAE = np.average(np.abs(np.array(ERROR_Train))) 
   ME = np.average(np.array(ERROR_Train)) 
   STD = np.std(np.array(ERROR_Train)) 
 
   print('[%d/%d] Loss: %.4f MAE: %.4f Mean Error: %.4f STD: %.4f' % ( 
   epoch, niter, np.average(losses), MAE, ME, STD))
​​
[0/10] Loss: 18384.6699 MAE: 135.3871 Mean Error: -135.3871 STD: 7.5580
[1/10] Loss: 17063.0215 MAE: 130.4145 Mean Error: -130.4145 STD: 7.8918
[2/10] Loss: 13689.1934 MAE: 116.6625 Mean Error: -116.6625 STD: 9.7946
[3/10] Loss: 8192.9053 MAE: 89.6611 Mean Error: -89.6611 STD: 12.9911
[4/10] Loss: 2979.1340 MAE: 52.5410 Mean Error: -52.5279 STD: 15.0930
[5/10] Loss: 599.7094 MAE: 22.2735 Mean Error: -19.9979 STD: 14.2069
[6/10] Loss: 207.2831 MAE: 11.2394 Mean Error: -4.8821 STD: 13.5528
[7/10] Loss: 189.8173 MAE: 9.8020 Mean Error: -1.2357 STD: 13.7095
[8/10] Loss: 188.3376 MAE: 9.6512 Mean Error: -0.6498 STD: 13.7075
[9/10] Loss: 186.8393 MAE: 9.6946 Mean Error: -1.0850 STD: 13.6332​

以上这篇Pytorch 神经网络—自定义数据集上实现教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python求解平方根的方法
Mar 11 Python
Python远程桌面协议RDPY安装使用介绍
Apr 15 Python
详解python中的装饰器
Jul 10 Python
python绘制散点图并标记序号的方法
Dec 11 Python
Python数据结构之栈、队列及二叉树定义与用法浅析
Dec 27 Python
pip安装py_zipkin时提示的SSL问题对应
Dec 29 Python
python+selenium实现自动化百度搜索关键词
Jun 03 Python
django 捕获异常和日志系统过程详解
Jul 18 Python
Python中typing模块与类型注解的使用方法
Aug 05 Python
python实现递归查找某个路径下所有文件中的中文字符
Aug 31 Python
pytorch实现用CNN和LSTM对文本进行分类方式
Jan 08 Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
You might like
PHP下对字符串的递增运算代码
2010/08/21 PHP
php中拷贝构造函数、赋值运算符重载
2012/07/25 PHP
thinkphp5 URL和路由的功能详解与实例
2017/12/26 PHP
JsDom 编程小结
2011/08/09 Javascript
document.getElementById介绍
2011/09/13 Javascript
js下拉菜单语言选项简单实现
2013/09/23 Javascript
JS 实现导航栏悬停效果(续)
2013/09/24 Javascript
parentElement,srcElement的使用小结
2014/01/13 Javascript
Node.js实现批量去除BOM文件头
2014/12/20 Javascript
jQuery实现简单的间隔向上滚动效果
2015/03/09 Javascript
在线引用最新jquery文件的实现方法
2016/08/26 Javascript
Vue.js开发环境搭建
2016/11/10 Javascript
详解在Vue中如何使用axios跨域访问数据
2017/07/07 Javascript
AngularJS模糊查询功能实现代码(过滤内容下拉菜单排序过滤敏感字符验证判断后添加表格信息)
2017/10/24 Javascript
jQuery实现基本淡入淡出效果的方法详解
2018/09/05 jQuery
Javascript实现动态时钟效果
2018/11/17 Javascript
关于NodeJS中的循环引用详解
2019/07/23 NodeJs
JS函数进阶之继承用法实例分析
2020/01/15 Javascript
python使用邻接矩阵构造图代码示例
2017/11/10 Python
简述:我为什么选择Python而不是Matlab和R语言
2017/11/14 Python
Python使用Matplotlib实现雨点图动画效果的方法
2017/12/23 Python
Python中GIL的使用详解
2018/10/03 Python
Python3.5迭代器与生成器用法实例分析
2019/04/30 Python
解决win7操作系统Python3.7.1安装后启动提示缺少.dll文件问题
2019/07/15 Python
通过Python编写一个简单登录功能过程解析
2019/09/04 Python
python模拟实现斗地主发牌
2020/01/07 Python
python isinstance函数用法详解
2020/02/13 Python
TensorFlow的环境配置与安装方法
2021/02/20 Python
使用HTML5做个画图板的方法介绍
2013/05/03 HTML / CSS
法国在线宠物店:zooplus.fr
2018/02/23 全球购物
时尚孕妇装:HATCH Collection
2019/09/24 全球购物
大学军训感言800字
2014/02/27 职场文书
机关干部纪律作风整顿心得体会
2016/01/23 职场文书
《风娃娃》教学反思
2016/02/18 职场文书
MySQL 常见存储引擎的优劣
2021/06/02 MySQL
Python获取江苏疫情实时数据及爬虫分析
2021/08/02 Python