Pytorch 神经网络—自定义数据集上实现教程


Posted in Python onJanuary 07, 2020

第一步、导入需要的包

import os
import scipy.io as sio
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable
batchSize = 128 # batchsize的大小
niter = 10   # epoch的最大值

第二步、构建神经网络

Pytorch 神经网络—自定义数据集上实现教程

设神经网络为如上图所示,输入层4个神经元,两层隐含层各4个神经元,输出层一个神经。每一层网络所做的都是线性变换,即y=W×X+b;代码实现如下:

class Neuralnetwork(nn.Module):
  def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
    super(Neuralnetwork, self).__init__()
    self.layer1 = nn.Linear(in_dim, n_hidden_1)
    self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
    self.layer3 = nn.Linear(n_hidden_2, out_dim)
 
  def forward(self, x):
    x = x.view(x.size(0), -1)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    return x
 
model = Neuralnetwork(1*3, 4, 4, 1)
 
print(model) # net architecture
Neuralnetwork(
 (layer1): Linear(in_features=3, out_features=4, bias=True)
 (layer2): Linear(in_features=4, out_features=4, bias=True)
 (layer3): Linear(in_features=4, out_features=1, bias=True)
)

​​ 第三步、读取数据

自定义的数据为demo_SBPFea.mat,是MATLAB保存的数据格式,其存储的内容如下:包括fea(1000*3)和sbp(1000*1)两个数组;fea为特征向量,行为样本数,列为特征宽度;sbp为标签

Pytorch 神经网络—自定义数据集上实现教程

class SBPEstimateDataset(Dataset):
 
  def __init__(self, ext='demo'):
  
    data = sio.loadmat(ext+'_SBPFea.mat')
    self.fea = data['fea']
    self.sbp = data['sbp']
    
  def __len__(self):
    
    return len(self.sbp)
 
  def __getitem__(self, idx):
 
    fea = self.fea[idx]
    sbp = self.sbp[idx]
    """Convert ndarrays to Tensors."""
    return {'fea': torch.from_numpy(fea).float(),
        'sbp': torch.from_numpy(sbp).float()
        }
    
train_dataset = SBPEstimateDataset(ext='demo')
train_loader = DataLoader(train_dataset, batch_size=batchSize, # 分批次训练
             shuffle=True, num_workers=int(8))

整个数据样本为1000,以batchSize = 128划分,分为8份,前7份为104个样本,第8份则为104个样本。在网络训练过程中,是一份数据一份数据进行训练的

第四步、模型训练

# 优化器,Adam 
optimizer = optim.Adam(list(model.parameters()), lr=0.0001, betas=(0.9, 0.999),weight_decay=0.004) 
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997) 
criterion = nn.MSELoss() # loss function 
 
if torch.cuda.is_available(): # 有GPU,则用GPU计算
   model.cuda() 
   criterion.cuda() 
 
for epoch in range(niter): 
   losses = [] 
   ERROR_Train = [] 
   model.train() 
   for i, data in enumerate(train_loader, 0): 
     model.zero_grad()# 首先提取清零 
     real_cpu, label_cpu = data['fea'], data['sbp'] 
 
     if torch.cuda.is_available():# CUDA可用情况下,将Tensor 在GPU上运行 
       real_cpu = real_cpu.cuda() 
       label_cpu = label_cpu.cuda() 
 
 
       input=real_cpu 
       label=label_cpu 
 
       inputv = Variable(input) 
       labelv = Variable(label) 
 
       output = model(inputv) 
       err = criterion(output, labelv) 
       err.backward() 
       optimizer.step() 
 
       losses.append(err.data[0]) 
 
       error = output.data-label+ 1e-12 
       ERROR_Train.extend(error) 
 
   MAE = np.average(np.abs(np.array(ERROR_Train))) 
   ME = np.average(np.array(ERROR_Train)) 
   STD = np.std(np.array(ERROR_Train)) 
 
   print('[%d/%d] Loss: %.4f MAE: %.4f Mean Error: %.4f STD: %.4f' % ( 
   epoch, niter, np.average(losses), MAE, ME, STD))
​​
[0/10] Loss: 18384.6699 MAE: 135.3871 Mean Error: -135.3871 STD: 7.5580
[1/10] Loss: 17063.0215 MAE: 130.4145 Mean Error: -130.4145 STD: 7.8918
[2/10] Loss: 13689.1934 MAE: 116.6625 Mean Error: -116.6625 STD: 9.7946
[3/10] Loss: 8192.9053 MAE: 89.6611 Mean Error: -89.6611 STD: 12.9911
[4/10] Loss: 2979.1340 MAE: 52.5410 Mean Error: -52.5279 STD: 15.0930
[5/10] Loss: 599.7094 MAE: 22.2735 Mean Error: -19.9979 STD: 14.2069
[6/10] Loss: 207.2831 MAE: 11.2394 Mean Error: -4.8821 STD: 13.5528
[7/10] Loss: 189.8173 MAE: 9.8020 Mean Error: -1.2357 STD: 13.7095
[8/10] Loss: 188.3376 MAE: 9.6512 Mean Error: -0.6498 STD: 13.7075
[9/10] Loss: 186.8393 MAE: 9.6946 Mean Error: -1.0850 STD: 13.6332​

以上这篇Pytorch 神经网络—自定义数据集上实现教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的数据结构与算法之快速排序详解
Apr 22 Python
python从入门到精通(DAY 2)
Dec 20 Python
Python实现的概率分布运算操作示例
Aug 14 Python
Atom的python插件和常用插件说明
Jul 08 Python
python 异或加密字符串的实例
Oct 14 Python
Python中修改字符串的四种方法
Nov 02 Python
对Python正则匹配IP、Url、Mail的方法详解
Dec 25 Python
分析经典Python开发工程师面试题
Apr 08 Python
python 比较2张图片的相似度的方法示例
Dec 18 Python
Python PyQt5模块实现窗口GUI界面代码实例
May 12 Python
深入了解Python 变量作用域
Jul 24 Python
Python时间操作之pytz模块使用详解
Jun 14 Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
You might like
php fsockopen中多线程问题的解决办法[翻译]
2011/11/09 PHP
PHP可逆加密/解密函数分享
2012/09/25 PHP
javascript静态的url如何传递
2007/05/03 Javascript
用javascript实现在小方框中浏览大图的代码
2007/08/14 Javascript
基于JQuery的模拟苹果桌面Dock效果(稳定版)
2012/10/15 Javascript
javascript 实现键盘上下左右功能的小例子
2013/09/15 Javascript
jQuery实现仿Google首页拖动效果的方法
2015/05/04 Javascript
用JavaScript实现对话框的教程
2015/06/04 Javascript
js电话号码验证方法
2015/09/28 Javascript
JS中innerHTML和pasteHTML的区别实例分析
2016/06/22 Javascript
jQuery搜索框效果实现代码(百度关键词联想)
2021/02/25 Javascript
深入浅析JavaScript中的3DES
2016/08/24 Javascript
微信小程序 教程之wxapp 视图容器 view
2016/10/19 Javascript
JS类的定义与使用方法深入探索
2016/11/26 Javascript
Vuejs入门教程之Vue生命周期,数据,手动挂载,指令,过滤器
2017/04/19 Javascript
极简主义法编写JavaScript类
2017/11/02 Javascript
vue在使用ECharts时的异步更新和数据加载详解
2017/11/22 Javascript
一步快速解决微信小程序中textarea层级太高遮挡其他组件
2019/03/04 Javascript
Vue proxyTable配置多个接口地址,解决跨域的问题
2020/09/11 Javascript
Python2.7下安装Scrapy框架步骤教程
2017/12/22 Python
Python递归函数实例讲解
2019/02/27 Python
利用python控制Autocad:pyautocad方式
2020/06/01 Python
python爬虫爬取淘宝商品比价(附淘宝反爬虫机制解决小办法)
2020/12/03 Python
HTML5 通过Vedio标签实现视频循环播放的示例代码
2020/08/05 HTML / CSS
Helly Hansen工作服美国官方网上商店:为最恶劣的环境
2019/09/04 全球购物
建筑自我鉴定
2013/10/19 职场文书
学校招生宣传广告词
2014/03/19 职场文书
小学生操行评语大全
2014/04/22 职场文书
中学生教师节演讲稿
2014/09/03 职场文书
高中生第一学年自我鉴定2015
2014/09/28 职场文书
领导班子对照检查剖析材料
2014/10/13 职场文书
小学五一劳动节活动总结
2015/02/09 职场文书
幼儿园国培研修日志
2015/11/13 职场文书
《地震中的父与子》教学反思
2016/02/16 职场文书
不会写演讲稿,快来看看这篇文章!
2019/08/06 职场文书
pytorch 如何使用float64训练
2021/05/24 Python