pytorch中的model=model.to(device)使用说明


Posted in Python onMay 24, 2021

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。

确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。

最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

补充:pytorch中model.to(device)和map_location=device的区别

一、简介

在已训练并保存在CPU上的GPU上加载模型时,加载模型时经常由于训练和保存模型时设备不同出现读取模型时出现错误,在对跨设备的模型读取时候涉及到两个参数的使用,分别是model.to(device)和map_location=devicel两个参数,简介一下两者的不同。

将map_location函数中的参数设置 torch.load()为 cuda:device_id。这会将模型加载到给定的GPU设备。

调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量,无论在cpu上训练还是gpu上训练,保存的模型参数都是参数张量不是cuda张量,因此,cpu设备上不需要使用torch.to(torch.device("cpu"))。

二、实例

了解了两者代表的意义,以下介绍两者的使用。

1、保存在GPU上,在CPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

解释:

在使用GPU训练的CPU上加载模型时,请传递 torch.device('cpu')给map_location函数中的 torch.load()参数,使用map_location参数将张量下面的存储器动态地重新映射到CPU设备 。

2、保存在GPU上,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在GPU上训练并保存在GPU上的模型时,只需将初始化model模型转换为CUDA优化模型即可model.to(torch.device('cuda'))。

此外,请务必.to(torch.device('cuda'))在所有模型输入上使用该 功能来准备模型的数据。

请注意,调用my_tensor.to(device) 返回my_tensorGPU上的新副本。

它不会覆盖 my_tensor。

因此,请记住手动覆盖张量: my_tensor = my_tensor.to(torch.device('cuda'))

3、保存在CPU,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在已训练并保存在CPU上的GPU上加载模型时,请将map_location函数中的参数设置 torch.load()为 cuda:device_id。

这会将模型加载到给定的GPU设备。

接下来,请务必调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量。

最后,确保.to(torch.device('cuda'))在所有模型输入上使用该 函数来为CUDA优化模型准备数据。

请注意,调用 my_tensor.to(device)返回my_tensorGPU上的新副本。

它不会覆盖my_tensor。

因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device('cuda'))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python发送伪造的arp请求
Jan 09 Python
Python列表生成器的循环技巧分享
Mar 06 Python
在Python中处理日期和时间的基本知识点整理汇总
May 22 Python
Python使用dis模块把Python反编译为字节码的用法详解
Jun 14 Python
使用NumPy和pandas对CSV文件进行写操作的实例
Jun 14 Python
Python 按字典dict的键排序,并取出相应的键值放于list中的实例
Feb 12 Python
python 将日期戳(五位数时间)转换为标准时间
Jul 11 Python
python数值基础知识浅析
Nov 19 Python
Django框架models使用group by详解
Mar 11 Python
python 实现压缩和解压缩的示例
Sep 22 Python
Numpy中np.random.rand()和np.random.randn() 用法和区别详解
Oct 23 Python
PyQt5 显示超清高分辨率图片的方法
Apr 11 Python
解决pytorch-gpu 安装失败的记录
May 24 #Python
如何解决.cuda()加载用时很长的问题
一劳永逸彻底解决pip install慢的办法
May 24 #Python
Django实现翻页的示例代码
May 24 #Python
pytorch--之halfTensor的使用详解
pandas DataFrame.shift()函数的具体使用
May 24 #Python
教你怎么用python实现字符串转日期
May 24 #Python
You might like
php利用gd库为图片添加水印
2016/11/09 PHP
php实现与python进行socket通信的方法示例
2017/08/30 PHP
各种效果的jquery ui(接口)介绍
2008/09/17 Javascript
js 与或运算符 || && 妙用
2009/12/09 Javascript
jQuery学习总结之元素的相对定位和选择器(持续更新)
2011/04/26 Javascript
一个Action如何调用两个不同的方法
2014/05/22 Javascript
Javascript冒泡排序算法详解
2014/12/03 Javascript
AngularJS入门教程之Hello World!
2014/12/06 Javascript
jQueryUI中的datepicker使用方法详解
2016/05/25 Javascript
手机图片预览插件photoswipe.js使用总结
2016/08/25 Javascript
Angularjs实现mvvm式的选项卡示例代码
2016/09/08 Javascript
SelecT下拉框选中和取值的解决方法
2016/11/22 Javascript
Angularjs 依赖压缩及自定义过滤器写法
2017/02/04 Javascript
Vue-Router实现页面正在加载特效方法示例
2017/02/12 Javascript
浅谈JS中的常用选择器及属性、方法的调用
2017/07/28 Javascript
详解利用eventemitter2实现Vue组件通信
2019/11/04 Javascript
Vue 实现分页与输入框关键字筛选功能
2020/01/02 Javascript
基于js实现逐步显示文字输出代码实例
2020/04/02 Javascript
在Vue 中获取下拉框的文本及选项值操作
2020/08/13 Javascript
linux系统使用python监测系统负载脚本分享
2014/01/15 Python
利用Python的装饰器解决Bottle框架中用户验证问题
2015/04/24 Python
使用Python编写提取日志中的中文的脚本的方法
2015/04/30 Python
Python运维之获取系统CPU信息的实现方法
2018/06/11 Python
python如何实现一个刷网页小程序
2018/11/27 Python
Python安装及Pycharm安装使用教程图解
2019/09/20 Python
python修改文件内容的3种方法详解
2019/11/15 Python
Python3的socket使用方法详解
2020/02/18 Python
pytorch ImageFolder的覆写实例
2020/02/20 Python
Python 操作SQLite数据库的示例
2020/10/16 Python
俄罗斯和世界各地的酒店预订:Hotels.com俄罗斯
2016/08/19 全球购物
新闻专业推荐信范文
2013/11/20 职场文书
应用化学专业职业生涯规划书
2013/12/31 职场文书
影视动画专业个人的自我评价
2013/12/31 职场文书
慈善晚会策划方案
2014/05/14 职场文书
Python实现拼音转换
2021/06/07 Python
Mysql中where与on的区别及何时使用详析
2021/08/04 MySQL