Pytorch 神经网络—自定义数据集上实现教程


Posted in Python onJanuary 07, 2020

第一步、导入需要的包

import os
import scipy.io as sio
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable
batchSize = 128 # batchsize的大小
niter = 10   # epoch的最大值

第二步、构建神经网络

Pytorch 神经网络—自定义数据集上实现教程

设神经网络为如上图所示,输入层4个神经元,两层隐含层各4个神经元,输出层一个神经。每一层网络所做的都是线性变换,即y=W×X+b;代码实现如下:

class Neuralnetwork(nn.Module):
  def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
    super(Neuralnetwork, self).__init__()
    self.layer1 = nn.Linear(in_dim, n_hidden_1)
    self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
    self.layer3 = nn.Linear(n_hidden_2, out_dim)
 
  def forward(self, x):
    x = x.view(x.size(0), -1)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    return x
 
model = Neuralnetwork(1*3, 4, 4, 1)
 
print(model) # net architecture
Neuralnetwork(
 (layer1): Linear(in_features=3, out_features=4, bias=True)
 (layer2): Linear(in_features=4, out_features=4, bias=True)
 (layer3): Linear(in_features=4, out_features=1, bias=True)
)

​​ 第三步、读取数据

自定义的数据为demo_SBPFea.mat,是MATLAB保存的数据格式,其存储的内容如下:包括fea(1000*3)和sbp(1000*1)两个数组;fea为特征向量,行为样本数,列为特征宽度;sbp为标签

Pytorch 神经网络—自定义数据集上实现教程

class SBPEstimateDataset(Dataset):
 
  def __init__(self, ext='demo'):
  
    data = sio.loadmat(ext+'_SBPFea.mat')
    self.fea = data['fea']
    self.sbp = data['sbp']
    
  def __len__(self):
    
    return len(self.sbp)
 
  def __getitem__(self, idx):
 
    fea = self.fea[idx]
    sbp = self.sbp[idx]
    """Convert ndarrays to Tensors."""
    return {'fea': torch.from_numpy(fea).float(),
        'sbp': torch.from_numpy(sbp).float()
        }
    
train_dataset = SBPEstimateDataset(ext='demo')
train_loader = DataLoader(train_dataset, batch_size=batchSize, # 分批次训练
             shuffle=True, num_workers=int(8))

整个数据样本为1000,以batchSize = 128划分,分为8份,前7份为104个样本,第8份则为104个样本。在网络训练过程中,是一份数据一份数据进行训练的

第四步、模型训练

# 优化器,Adam 
optimizer = optim.Adam(list(model.parameters()), lr=0.0001, betas=(0.9, 0.999),weight_decay=0.004) 
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997) 
criterion = nn.MSELoss() # loss function 
 
if torch.cuda.is_available(): # 有GPU,则用GPU计算
   model.cuda() 
   criterion.cuda() 
 
for epoch in range(niter): 
   losses = [] 
   ERROR_Train = [] 
   model.train() 
   for i, data in enumerate(train_loader, 0): 
     model.zero_grad()# 首先提取清零 
     real_cpu, label_cpu = data['fea'], data['sbp'] 
 
     if torch.cuda.is_available():# CUDA可用情况下,将Tensor 在GPU上运行 
       real_cpu = real_cpu.cuda() 
       label_cpu = label_cpu.cuda() 
 
 
       input=real_cpu 
       label=label_cpu 
 
       inputv = Variable(input) 
       labelv = Variable(label) 
 
       output = model(inputv) 
       err = criterion(output, labelv) 
       err.backward() 
       optimizer.step() 
 
       losses.append(err.data[0]) 
 
       error = output.data-label+ 1e-12 
       ERROR_Train.extend(error) 
 
   MAE = np.average(np.abs(np.array(ERROR_Train))) 
   ME = np.average(np.array(ERROR_Train)) 
   STD = np.std(np.array(ERROR_Train)) 
 
   print('[%d/%d] Loss: %.4f MAE: %.4f Mean Error: %.4f STD: %.4f' % ( 
   epoch, niter, np.average(losses), MAE, ME, STD))
​​
[0/10] Loss: 18384.6699 MAE: 135.3871 Mean Error: -135.3871 STD: 7.5580
[1/10] Loss: 17063.0215 MAE: 130.4145 Mean Error: -130.4145 STD: 7.8918
[2/10] Loss: 13689.1934 MAE: 116.6625 Mean Error: -116.6625 STD: 9.7946
[3/10] Loss: 8192.9053 MAE: 89.6611 Mean Error: -89.6611 STD: 12.9911
[4/10] Loss: 2979.1340 MAE: 52.5410 Mean Error: -52.5279 STD: 15.0930
[5/10] Loss: 599.7094 MAE: 22.2735 Mean Error: -19.9979 STD: 14.2069
[6/10] Loss: 207.2831 MAE: 11.2394 Mean Error: -4.8821 STD: 13.5528
[7/10] Loss: 189.8173 MAE: 9.8020 Mean Error: -1.2357 STD: 13.7095
[8/10] Loss: 188.3376 MAE: 9.6512 Mean Error: -0.6498 STD: 13.7075
[9/10] Loss: 186.8393 MAE: 9.6946 Mean Error: -1.0850 STD: 13.6332​

以上这篇Pytorch 神经网络—自定义数据集上实现教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现挑选出来100以内的质数
Mar 24 Python
python中随机函数random用法实例
Apr 30 Python
理解Python中的With语句
Mar 18 Python
python中的格式化输出用法总结
Jul 28 Python
python 与GO中操作slice,list的方式实例代码
Mar 20 Python
python matplotlib中文显示参数设置解析
Dec 15 Python
Python/Django后端使用PIL Image生成头像缩略图
Apr 30 Python
python async with和async for的使用
Jun 20 Python
Python库安装速度过慢解决方案
Jul 14 Python
详解python算法常用技巧与内置库
Oct 17 Python
用Python监控你的朋友都在浏览哪些网站?
May 27 Python
Pygame Draw绘图函数的具体使用
Nov 17 Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
You might like
PHP发明人谈MVC和网站设计架构 貌似他不支持php用mvc
2011/06/04 PHP
PHP中func_get_args(),func_get_arg(),func_num_args()的区别
2013/09/30 PHP
微信公众平台网页授权获取用户基本信息中授权回调域名设置的变动
2014/10/21 PHP
php的sso单点登录实现方法
2015/01/08 PHP
详解PHP数据压缩、加解密(pack, unpack)
2016/12/17 PHP
关于php支持的协议与封装协议总结(推荐)
2017/11/17 PHP
php实现将数组或对象写入到文件的方法小结【三种方法】
2020/04/22 PHP
Javascript 汉字字节判断
2009/08/01 Javascript
说明你的Javascript技术很烂的五个原因
2011/04/26 Javascript
70+漂亮且极具亲和力的导航菜单设计国外网站推荐
2011/09/20 Javascript
Prototype源码浅析 String部分(一)之有关indexOf优化
2012/01/15 Javascript
jquery获取一组checkbox的值(实例代码)
2013/11/04 Javascript
JavaScript实现检查页面上的广告是否被AdBlock屏蔽了的方法
2014/11/03 Javascript
Javascript 拖拽的一些高级的应用(逐行分析代码,让你轻松了拖拽的原理)
2015/01/23 Javascript
基于jquery实现图片相关操作(重绘、获取尺寸、调整大小、缩放)
2015/12/25 Javascript
jQuery+ThinkPHP+Ajax实现即时消息提醒功能实例代码
2017/03/21 jQuery
js使用i18n实现页面国际化的方法
2017/05/09 Javascript
使用 Node.js 对文本内容分词和关键词抽取
2017/05/27 Javascript
Angular ng-animate和ng-cookies用法详解
2018/04/18 Javascript
nodejs aes 加解密实例
2018/10/10 NodeJs
element的el-table中记录滚动条位置的示例代码
2019/11/06 Javascript
python实现扫描日志关键字的示例
2018/04/28 Python
Flask框架通过Flask_login实现用户登录功能示例
2018/07/17 Python
Python 实现两个列表里元素对应相乘的方法
2018/11/14 Python
如何爬取通过ajax加载数据的网站
2019/08/15 Python
解决Python3.8运行tornado项目报NotImplementedError错误
2020/09/02 Python
10分钟理解CSS3 Grid布局
2018/12/20 HTML / CSS
外企C语言笔试题
2013/11/10 面试题
优秀学生干部推荐材料
2014/02/03 职场文书
2014国庆节演讲稿:祖国在我心中(400字)
2014/09/25 职场文书
民事辩护词范文
2015/05/21 职场文书
鸦片战争观后感
2015/06/09 职场文书
纪检干部学习心得体会
2016/01/23 职场文书
教师纪律作风整顿心得体会
2016/01/23 职场文书
深入讲解数据库中Decimal类型的使用以及实现方法
2022/02/15 MySQL
Python尝试实现蒙特卡罗模拟期权定价
2022/04/21 Python