pytorch中的model=model.to(device)使用说明


Posted in Python onMay 24, 2021

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。

确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。

最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

补充:pytorch中model.to(device)和map_location=device的区别

一、简介

在已训练并保存在CPU上的GPU上加载模型时,加载模型时经常由于训练和保存模型时设备不同出现读取模型时出现错误,在对跨设备的模型读取时候涉及到两个参数的使用,分别是model.to(device)和map_location=devicel两个参数,简介一下两者的不同。

将map_location函数中的参数设置 torch.load()为 cuda:device_id。这会将模型加载到给定的GPU设备。

调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量,无论在cpu上训练还是gpu上训练,保存的模型参数都是参数张量不是cuda张量,因此,cpu设备上不需要使用torch.to(torch.device("cpu"))。

二、实例

了解了两者代表的意义,以下介绍两者的使用。

1、保存在GPU上,在CPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

解释:

在使用GPU训练的CPU上加载模型时,请传递 torch.device('cpu')给map_location函数中的 torch.load()参数,使用map_location参数将张量下面的存储器动态地重新映射到CPU设备 。

2、保存在GPU上,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在GPU上训练并保存在GPU上的模型时,只需将初始化model模型转换为CUDA优化模型即可model.to(torch.device('cuda'))。

此外,请务必.to(torch.device('cuda'))在所有模型输入上使用该 功能来准备模型的数据。

请注意,调用my_tensor.to(device) 返回my_tensorGPU上的新副本。

它不会覆盖 my_tensor。

因此,请记住手动覆盖张量: my_tensor = my_tensor.to(torch.device('cuda'))

3、保存在CPU,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在已训练并保存在CPU上的GPU上加载模型时,请将map_location函数中的参数设置 torch.load()为 cuda:device_id。

这会将模型加载到给定的GPU设备。

接下来,请务必调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量。

最后,确保.to(torch.device('cuda'))在所有模型输入上使用该 函数来为CUDA优化模型准备数据。

请注意,调用 my_tensor.to(device)返回my_tensorGPU上的新副本。

它不会覆盖my_tensor。

因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device('cuda'))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用cookie库操保存cookie详解
Mar 03 Python
Python urls.py的三种配置写法实例详解
Apr 28 Python
Python实现excel转sqlite的方法
Jul 17 Python
python生成tensorflow输入输出的图像格式的方法
Feb 12 Python
修改python plot折线图的坐标轴刻度方法
Dec 13 Python
通过shell+python实现企业微信预警
Mar 07 Python
python里运用私有属性和方法总结
Jul 08 Python
Python定时任务工具之APScheduler使用方式
Jul 24 Python
Ranorex通过Python将报告发送到邮箱的方法
Jan 12 Python
python 输入字符串生成所有有效的IP地址(LeetCode 93号题)
Oct 15 Python
Python基于argparse与ConfigParser库进行入参解析与ini parser
Feb 02 Python
Python OpenGL基本配置方式
May 20 Python
解决pytorch-gpu 安装失败的记录
May 24 #Python
如何解决.cuda()加载用时很长的问题
一劳永逸彻底解决pip install慢的办法
May 24 #Python
Django实现翻页的示例代码
May 24 #Python
pytorch--之halfTensor的使用详解
pandas DataFrame.shift()函数的具体使用
May 24 #Python
教你怎么用python实现字符串转日期
May 24 #Python
You might like
菜鸟修复电子管记
2021/03/02 无线电
Php中文件下载功能实现超详细流程分析
2012/06/13 PHP
简单谈谈PHP中的Reload操作
2016/12/12 PHP
PHP实现多级分类生成树的方法示例
2017/02/07 PHP
php学习笔记之mb_strstr的基本使用
2018/02/03 PHP
Laravel框架使用Seeder实现自动填充数据功能
2018/06/13 PHP
php与阿里云短信接口接入操作案例分析
2020/05/27 PHP
js+JQuery返回顶部功能如何实现
2012/12/03 Javascript
使用javascript为网页增加夜间模式
2014/01/26 Javascript
javascript实现禁止鼠标滚轮事件
2015/07/24 Javascript
jquery地址栏链接与a标签链接匹配之特效代码总结
2015/08/24 Javascript
JavaScript字符串常用的方法
2016/03/10 Javascript
基于Vue如何封装分页组件
2016/12/16 Javascript
javascript异步编程的六种方式总结
2019/05/17 Javascript
vue element-ui el-date-picker限制选择时间为当天之前的代码
2019/11/07 Javascript
详解ES6中class的实现原理
2020/10/03 Javascript
Nuxt的动态路由和参数校验操作
2020/11/09 Javascript
Python中的类学习笔记
2014/09/23 Python
在Linux系统上安装Python的Scrapy框架的教程
2015/06/11 Python
在arcgis使用python脚本进行字段计算时是如何解决中文问题的
2015/10/18 Python
python中安装模块包版本冲突问题的解决
2017/05/02 Python
Python实现绘制双柱状图并显示数值功能示例
2018/06/23 Python
Django框架安装方法图文详解
2019/11/04 Python
PyTorch中torch.tensor与torch.Tensor的区别详解
2020/05/18 Python
Python Opencv轮廓常用操作代码实例解析
2020/09/01 Python
python批量修改文件名的示例
2020/09/27 Python
为什么使用接口?
2014/08/13 面试题
老同学聚会感言
2014/02/23 职场文书
职业规划实施方案
2014/06/10 职场文书
幼师自荐信范文
2015/03/06 职场文书
2015年客房服务员工作总结
2015/05/15 职场文书
2015年公司行政后勤工作总结
2015/05/20 职场文书
婚庆公司开业主持词
2015/06/30 职场文书
学习党章心得体会2016
2016/01/15 职场文书
MySQL复制问题的三个参数分析
2021/04/07 MySQL
MySQL update set 和 and的区别
2021/05/08 MySQL