如何使用Pytorch搭建模型


Posted in Python onOctober 26, 2020

1  模型定义

和TF很像,Pytorch也通过继承父类来搭建模型,同样也是实现两个方法。在TF中是__init__()和call(),在Pytorch中则是__init__()和forward()。功能类似,都分别是初始化模型内部结构和进行推理。其它功能比如计算loss和训练函数,你也可以继承在里面,当然这是可选的。下面搭建一个判别MNIST手写字的Demo,首先给出模型代码:

import numpy as np
import matplotlib.pyplot as plt 
import torch 
from torch import nn,optim 
from torchsummary import summary 
from keras.datasets import mnist
from keras.utils import to_categorical
device = torch.device('cuda') #——————1——————
 
class ModelTest(nn.Module):
 def __init__(self,device):
  super().__init__() 
  self.layer1 = nn.Sequential(nn.Flatten(),nn.Linear(28*28,512),nn.ReLU())#——————2——————
  self.layer2 = nn.Sequential(nn.Linear(512,512),nn.ReLU()) 
  self.layer3 = nn.Sequential(nn.Linear(512,512),nn.ReLU())
  self.layer4 = nn.Sequential(nn.Linear(512,10),nn.Softmax()) 

  self.to(device) #——————3——————
  self.opt = optim.SGD(self.parameters(),lr=0.01)#——————4——————
 def forward(self,inputs): #——————5——————
  x = self.layer1(inputs)
  x = self.layer2(x)
  x = self.layer3(x)
  x = self.layer4(x)
  return x 
 def get_loss(self,true_labels,predicts): 
  loss = -true_labels * torch.log(predicts) #——————6——————
  loss = torch.mean(loss)
  return loss
 def train(self,imgs,labels): 
  predicts = model(imgs) 
  loss = self.get_loss(labels,predicts)
  self.opt.zero_grad()#——————7——————
  loss.backward()#——————8——————
  self.opt.step()#——————9——————
model = ModelTest(device)
summary(model,(1,28,28),3,device='cuda') #——————10——————

#1:获取设备,以方便后面的模型与变量进行内存迁移,设备名只有两种:'cuda'和'cpu'。通常是在你有GPU的情况下需要这样显式进行设备的设置,从而在需要时,你可以将变量从主存迁移到显存中。如果没有GPU,不获取也没事,pytorch会默认将参数都保存在主存中。

#2:模型中层的定义,可以使用Sequential将想要统一管理的层集中表示为一层。

#3:在初始化中将模型参数迁移到GPU显存中,加速运算,当然你也可以在需要时在外部执行model.to(device)进行迁移。

#4:定义模型的优化器,和TF不同,pytorch需要在定义时就将需要梯度下降的参数传入,也就是其中的self.parameters(),表示当前模型的所有参数。实际上你不用担心定义优化器和模型参数的顺序问题,因为self.parameters()的输出并不是模型参数的实例,而是整个模型参数对象的指针,所以即使你在定义优化器之后又定义了一个层,它依然能优化到。当然优化器你也可以在外部定义,传入model.parameters()即可。这里定义了一个随机梯度下降。

#5:模型的前向传播,和TF的call()类似,定义好model()所执行的就是这个函数。

#6:我将获取loss的函数集成在了模型中,这里计算的是真实标签和预测标签之间的交叉熵。

#7/8/9:在TF中,参数梯度是保存在梯度带中的,而在pytorch中,参数梯度是各自集成在对应的参数中的,可以使用tensor.grad来查看。每次对loss执行backward(),pytorch都会将参与loss计算的所有可训练参数关于loss的梯度叠加进去(直接相加)。所以如果我们没有叠加梯度的意愿的话,那就要在backward()之前先把之前的梯度删除。又因为我们前面已经把待训练的参数都传入了优化器,所以,对优化器使用zero_grad(),就能把所有待训练参数中已存在的梯度都清零。那么梯度叠加什么时候用到呢?比如批量梯度下降,当内存不够直接计算整个批量的梯度时,我们只能将批量分成一部分一部分来计算,每算一个部分得到loss就backward()一次,从而得到整个批量的梯度。梯度计算好后,再执行优化器的step(),优化器根据可训练参数的梯度对其执行一步优化。

#10:使用torchsummary函数显示模型结构。奇怪为什么不把这个继承在torch里面,要重新安装一个torchsummary库。

2  训练及可视化

接下来使用模型进行训练,因为pytorch自带的MNIST数据集并不好用,所以我使用的是Keras自带的,定义了一个获取数据的生成器。下面是完整的训练及绘图代码(50次迭代记录一次准确率):

import numpy as np
import matplotlib.pyplot as plt 
import torch 
from torch import nn,optim 
from torchsummary import summary 
from keras.datasets import mnist
from keras.utils import to_categorical
device = torch.device('cuda') #——————1——————
 
class ModelTest(nn.Module):
 def __init__(self,device):
  super().__init__() 
  self.layer1 = nn.Sequential(nn.Flatten(),nn.Linear(28*28,512),nn.ReLU())#——————2——————
  self.layer2 = nn.Sequential(nn.Linear(512,512),nn.ReLU()) 
  self.layer3 = nn.Sequential(nn.Linear(512,512),nn.ReLU())
  self.layer4 = nn.Sequential(nn.Linear(512,10),nn.Softmax()) 

  self.to(device) #——————3——————
  self.opt = optim.SGD(self.parameters(),lr=0.01)#——————4——————
 def forward(self,inputs): #——————5——————
  x = self.layer1(inputs)
  x = self.layer2(x)
  x = self.layer3(x)
  x = self.layer4(x)
  return x 
 def get_loss(self,true_labels,predicts): 
  loss = -true_labels * torch.log(predicts) #——————6——————
  loss = torch.mean(loss)
  return loss
 def train(self,imgs,labels): 
  predicts = model(imgs) 
  loss = self.get_loss(labels,predicts)
  self.opt.zero_grad()#——————7——————
  loss.backward()#——————8——————
  self.opt.step()#——————9——————
def get_data(device,is_train = True, batch = 1024, num = 10000):
 train_data,test_data = mnist.load_data()
 if is_train:
  imgs,labels = train_data
 else:
  imgs,labels = test_data 
 imgs = (imgs/255*2-1)[:,np.newaxis,...]
 labels = to_categorical(labels,10) 
 imgs = torch.tensor(imgs,dtype=torch.float32).to(device)
 labels = torch.tensor(labels,dtype=torch.float32).to(device)
 i = 0
 while(True):
  i += batch
  if i > num:
   i = batch 
  yield imgs[i-batch:i],labels[i-batch:i] 
train_dg = get_data(device, True,batch=4096,num=60000) 
test_dg = get_data(device, False,batch=5000,num=10000) 

model = ModelTest(device) 
summary(model,(1,28,28),11,device='cuda') 
ACCs = []
import time
start = time.time()
for j in range(20000):
 #训练
 imgs,labels = next(train_dg)
 model.train(imgs,labels)

 #验证
 img,label = next(test_dg)
 predicts = model(img) 
 acc = 1 - torch.count_nonzero(torch.argmax(predicts,axis=1) - torch.argmax(label,axis=1))/label.shape[0]
 if j % 50 == 0:
  t = time.time() - start
  start = time.time()
  ACCs.append(acc.cpu().numpy())
  print(j,t,'ACC: ',acc)
#绘图
x = np.linspace(0,len(ACCs),len(ACCs))
plt.plot(x,ACCs)

准确率变化图如下:

如何使用Pytorch搭建模型

3   注意事项

需要注意的是,pytorch的tensor基于numpy的array,它们是共享内存的。也就是说,如果你把tensor直接插入一个列表,当你修改这个tensor时,列表中的这个tensor也会被修改;更容易被忽略的是,即使你用tensor.detach.numpy(),先将tensor转换为array类型,再插入列表,当你修改原本的tensor时,列表中的这个array也依然会被修改。所以如果我们只是想保存tensor的值而不是整个对象,就要使用np.array(tensor)将tensor的值复制出来。

以上就是如何使用Pytorch搭建模型的详细内容,更多关于Pytorch搭建模型的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python基础教程之popen函数操作其它程序的输入和输出示例
Feb 10 Python
python实现多线程的方式及多条命令并发执行
Jun 07 Python
Python中矩阵库Numpy基本操作详解
Nov 21 Python
解决vscode python print 输出窗口中文乱码的问题
Dec 03 Python
详解用Python练习画个美队盾牌
Mar 23 Python
Pandas之Fillna填充缺失数据的方法
Jun 25 Python
Python:Numpy 求平均向量的实例
Jun 29 Python
基于python实现自动化办公学习笔记(CSV、word、Excel、PPT)
Aug 06 Python
Python操作远程服务器 paramiko模块详细介绍
Aug 07 Python
python递归下载文件夹下所有文件
Aug 31 Python
Python3爬虫中关于Ajax分析方法的总结
Jul 10 Python
Python根据字符串调用函数过程解析
Nov 05 Python
使用python-cv2实现视频的分解与合成的示例代码
Oct 26 #Python
python递归函数用法详解
Oct 26 #Python
Python实现LR1文法的完整实例代码
Oct 25 #Python
Python操作word文档插入图片和表格的实例演示
Oct 25 #Python
python时间time模块处理大全
Oct 25 #Python
使用AJAX和Django获取数据的方法实例
Oct 25 #Python
Python Tkinter实例——模拟掷骰子
Oct 24 #Python
You might like
一个程序下载的管理程序(四)
2006/10/09 PHP
PHP中foreach循环中使用引用要注意的地方
2011/01/02 PHP
PHP使用静态方法的几个注意事项
2014/09/16 PHP
zend framework重定向方法小结
2016/05/28 PHP
php设计模式之代理模式分析【星际争霸游戏案例】
2020/03/23 PHP
利用onresize使得div可以随着屏幕大小而自适应的代码
2010/01/15 Javascript
js中window.open打开一个新的页面
2014/08/10 Javascript
实例讲解jQuery EasyUI tree中state属性慎用
2016/04/01 Javascript
jQuery中on绑定事件后引发的事件冒泡问题如何解决
2016/05/25 Javascript
jquery插件uploadify多图上传功能实现代码
2016/08/12 Javascript
jQuery弹出下拉列表插件(实现kindeditor的@功能)
2016/08/16 Javascript
Radio 单选JS动态添加的选项onchange事件无效的解决方法
2016/12/12 Javascript
原生JS+Canvas实现五子棋游戏实例
2017/06/19 Javascript
20行JS代码实现网页刮刮乐效果
2017/06/23 Javascript
vue-cli 如何打包上线的方法示例
2018/05/08 Javascript
vue实现同一个页面可以有多个router-view的方法
2018/09/20 Javascript
Vue自定义指令上报Google Analytics事件统计的方法
2019/02/25 Javascript
vue移动端实现手机左右滑动入场动画
2020/06/17 Javascript
JavaScript命令模式原理与用法实例详解
2020/03/10 Javascript
在react中使用vue的状态管理的方法示例
2020/05/02 Javascript
JavaScript实现表单验证功能
2020/12/09 Javascript
[01:38]【DOTA2亚洲邀请赛】Sumail——梦开始的地方
2017/03/03 DOTA
[46:32]Fnatic vs OG 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
python通过ElementTree操作XML获取结点读取属性美化XML
2013/12/02 Python
python中元类用法实例
2014/10/10 Python
Python编程实战之Oracle数据库操作示例
2017/06/21 Python
Python参数类型以及常见的坑详解
2019/07/08 Python
Python实现Selenium自动化Page模式
2019/07/14 Python
使用CSS3和Checkbox实现JQuery的一些效果
2015/08/03 HTML / CSS
HTML5+css3:3D旋转木马效果相册
2017/01/03 HTML / CSS
教师节促销活动方案
2014/02/14 职场文书
有限责任公司股东合作协议书范本
2014/10/30 职场文书
旅游项目合作意向书
2015/05/08 职场文书
2015年民兵整组工作总结
2015/07/24 职场文书
Java tomcat手动配置servlet详解
2021/11/27 Java/Android
MySQ InnoDB和MyISAM存储引擎介绍
2022/04/26 MySQL