pytorch中的model=model.to(device)使用说明


Posted in Python onMay 24, 2021

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。

确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。

最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

补充:pytorch中model.to(device)和map_location=device的区别

一、简介

在已训练并保存在CPU上的GPU上加载模型时,加载模型时经常由于训练和保存模型时设备不同出现读取模型时出现错误,在对跨设备的模型读取时候涉及到两个参数的使用,分别是model.to(device)和map_location=devicel两个参数,简介一下两者的不同。

将map_location函数中的参数设置 torch.load()为 cuda:device_id。这会将模型加载到给定的GPU设备。

调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量,无论在cpu上训练还是gpu上训练,保存的模型参数都是参数张量不是cuda张量,因此,cpu设备上不需要使用torch.to(torch.device("cpu"))。

二、实例

了解了两者代表的意义,以下介绍两者的使用。

1、保存在GPU上,在CPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

解释:

在使用GPU训练的CPU上加载模型时,请传递 torch.device('cpu')给map_location函数中的 torch.load()参数,使用map_location参数将张量下面的存储器动态地重新映射到CPU设备 。

2、保存在GPU上,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在GPU上训练并保存在GPU上的模型时,只需将初始化model模型转换为CUDA优化模型即可model.to(torch.device('cuda'))。

此外,请务必.to(torch.device('cuda'))在所有模型输入上使用该 功能来准备模型的数据。

请注意,调用my_tensor.to(device) 返回my_tensorGPU上的新副本。

它不会覆盖 my_tensor。

因此,请记住手动覆盖张量: my_tensor = my_tensor.to(torch.device('cuda'))

3、保存在CPU,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在已训练并保存在CPU上的GPU上加载模型时,请将map_location函数中的参数设置 torch.load()为 cuda:device_id。

这会将模型加载到给定的GPU设备。

接下来,请务必调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量。

最后,确保.to(torch.device('cuda'))在所有模型输入上使用该 函数来为CUDA优化模型准备数据。

请注意,调用 my_tensor.to(device)返回my_tensorGPU上的新副本。

它不会覆盖my_tensor。

因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device('cuda'))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python不带重复的全排列代码
Aug 13 Python
在GitHub Pages上使用Pelican搭建博客的教程
Apr 25 Python
在python的类中动态添加属性与生成对象
Sep 17 Python
Python使用Matplotlib模块时坐标轴标题中文及各种特殊符号显示方法
May 04 Python
Python读取本地文件并解析网页元素的方法
May 21 Python
教你利用Python玩转histogram直方图的五种方法
Jul 30 Python
python 产生token及token验证的方法
Dec 26 Python
python中tkinter的应用:修改字体的实例讲解
Jul 17 Python
Python下应用opencv 实现人脸检测功能
Oct 24 Python
Ubuntu18.04安装 PyCharm并使用 Anaconda 管理的Python环境
Apr 08 Python
python 定义函数 返回值只取其中一个的实现
May 21 Python
Django+Celery实现定时任务的示例
Jun 23 Python
解决pytorch-gpu 安装失败的记录
May 24 #Python
如何解决.cuda()加载用时很长的问题
一劳永逸彻底解决pip install慢的办法
May 24 #Python
Django实现翻页的示例代码
May 24 #Python
pytorch--之halfTensor的使用详解
pandas DataFrame.shift()函数的具体使用
May 24 #Python
教你怎么用python实现字符串转日期
May 24 #Python
You might like
ajax取消挂起请求的处理方法
2013/03/18 PHP
PHP仿qq空间或朋友圈发布动态、评论动态、回复评论、删除动态或评论的功能(上)
2017/05/26 PHP
JavaScript 高效运行代码分析
2010/03/18 Javascript
JavaScript Date对象 日期获取函数
2010/12/19 Javascript
基于pthread_create,readlink,getpid等函数的学习与总结
2013/07/17 Javascript
浅析JavaScript声明变量
2015/12/21 Javascript
JavaScript表单焦点自动切换代码
2016/07/24 Javascript
微信小程序开发探究
2016/12/27 Javascript
vue绑定设置属性的多种方式(5)
2017/08/16 Javascript
webpack打包js文件及部署的实现方法
2017/12/18 Javascript
Vue slot用法(小结)
2018/10/22 Javascript
javascript的this关键字详解
2019/05/20 Javascript
微信小程序文章详情页跳转案例详解
2019/07/09 Javascript
微信头像地址失效踩坑记附带解决方案
2019/09/23 Javascript
简单了解vue中的v-if和v-show的区别
2019/10/08 Javascript
vue项目打包后请求地址错误/打包后跨域操作
2020/11/04 Javascript
[01:57]2016完美“圣”典风云人物:国士无双专访
2016/12/04 DOTA
Python 的描述符 descriptor详解
2016/02/27 Python
Python极简代码实现杨辉三角示例代码
2016/11/15 Python
python实现日常记账本小程序
2018/03/10 Python
基于python log取对数详解
2018/06/08 Python
Python2.7.10以上pip更新及其他包的安装教程
2018/06/12 Python
使用Python更换外网IP的方法
2018/07/09 Python
Django页面数据的缓存与使用的具体方法
2019/04/23 Python
Python进程,多进程,获取进程id,给子进程传递参数操作示例
2019/10/11 Python
Django 实现Admin自动填充当前用户的示例代码
2019/11/18 Python
2014两会优秀的心得体会范文
2014/03/17 职场文书
感恩父母的演讲稿
2014/05/06 职场文书
企业文化标语大全
2014/06/10 职场文书
毕业生找工作自荐书
2014/06/30 职场文书
2015年采购工作总结
2015/04/10 职场文书
企业财务人员岗位职责
2015/04/14 职场文书
施工安全保证书
2015/05/09 职场文书
python之django路由和视图案例教程
2021/07/26 Python
SQL之各种join小结详细讲解
2021/08/04 MySQL
详解Vue中$props、$attrs和$listeners的使用方法
2022/02/18 Vue.js