pytorch中的model=model.to(device)使用说明


Posted in Python onMay 24, 2021

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。

确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。

最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

补充:pytorch中model.to(device)和map_location=device的区别

一、简介

在已训练并保存在CPU上的GPU上加载模型时,加载模型时经常由于训练和保存模型时设备不同出现读取模型时出现错误,在对跨设备的模型读取时候涉及到两个参数的使用,分别是model.to(device)和map_location=devicel两个参数,简介一下两者的不同。

将map_location函数中的参数设置 torch.load()为 cuda:device_id。这会将模型加载到给定的GPU设备。

调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量,无论在cpu上训练还是gpu上训练,保存的模型参数都是参数张量不是cuda张量,因此,cpu设备上不需要使用torch.to(torch.device("cpu"))。

二、实例

了解了两者代表的意义,以下介绍两者的使用。

1、保存在GPU上,在CPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

解释:

在使用GPU训练的CPU上加载模型时,请传递 torch.device('cpu')给map_location函数中的 torch.load()参数,使用map_location参数将张量下面的存储器动态地重新映射到CPU设备 。

2、保存在GPU上,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在GPU上训练并保存在GPU上的模型时,只需将初始化model模型转换为CUDA优化模型即可model.to(torch.device('cuda'))。

此外,请务必.to(torch.device('cuda'))在所有模型输入上使用该 功能来准备模型的数据。

请注意,调用my_tensor.to(device) 返回my_tensorGPU上的新副本。

它不会覆盖 my_tensor。

因此,请记住手动覆盖张量: my_tensor = my_tensor.to(torch.device('cuda'))

3、保存在CPU,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在已训练并保存在CPU上的GPU上加载模型时,请将map_location函数中的参数设置 torch.load()为 cuda:device_id。

这会将模型加载到给定的GPU设备。

接下来,请务必调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量。

最后,确保.to(torch.device('cuda'))在所有模型输入上使用该 函数来为CUDA优化模型准备数据。

请注意,调用 my_tensor.to(device)返回my_tensorGPU上的新副本。

它不会覆盖my_tensor。

因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device('cuda'))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python创建日历实例
Aug 21 Python
以一段代码为实例快速入门Python2.7
Mar 31 Python
在Python的Flask框架中实现全文搜索功能
Apr 20 Python
Hadoop中的Python框架的使用指南
Apr 22 Python
PyMongo安装使用笔记
Apr 27 Python
Python缩进和冒号详解
Jun 01 Python
详解Python的Flask框架中的signals信号机制
Jun 13 Python
对numpy中布尔型数组的处理方法详解
Apr 17 Python
python中pytest收集用例规则与运行指定用例详解
Jun 27 Python
基于Django实现日志记录报错信息
Dec 17 Python
Python的Django框架实现数据库查询(不返回QuerySet的方法)
May 19 Python
Python实现简繁体转换
Jun 07 Python
解决pytorch-gpu 安装失败的记录
May 24 #Python
如何解决.cuda()加载用时很长的问题
一劳永逸彻底解决pip install慢的办法
May 24 #Python
Django实现翻页的示例代码
May 24 #Python
pytorch--之halfTensor的使用详解
pandas DataFrame.shift()函数的具体使用
May 24 #Python
教你怎么用python实现字符串转日期
May 24 #Python
You might like
如何正确理解PHP的错误信息
2006/10/09 PHP
PHP学习之PHP变量
2006/10/09 PHP
php实现删除空目录的方法
2015/03/16 PHP
thinkPHP实现将excel导入到数据库中的方法
2016/04/22 PHP
动态表单验证的操作方法和TP框架里面的ajax表单验证
2017/07/19 PHP
Valerio 发布了 Mootools
2006/09/23 Javascript
js获取单元格自定义属性值的代码(IE/Firefox)
2010/04/05 Javascript
jquery简单瀑布流实现原理及ie8下测试代码
2013/01/23 Javascript
js将long日期格式转换为标准日期格式实现思路
2013/04/07 Javascript
js使用for循环及if语句判断多个一样的name
2014/09/09 Javascript
Google Maps API地图应用示例分享
2014/10/23 Javascript
js实现鼠标经过表格行变色的方法
2015/05/12 Javascript
jQuery实现查找链接文字替换属性的方法
2016/06/27 Javascript
详解BootStrap中Affix控件的使用及保持布局的美观的方法
2016/07/08 Javascript
jQuery的图片轮播插件PgwSlideshow使用详解
2016/08/11 Javascript
JavaScript常用正则函数用法示例
2017/01/23 Javascript
jQuery实现在新增加的元素上添加事件方法案例分析
2017/02/09 Javascript
vue中axios的二次封装实例讲解
2019/10/14 Javascript
详解JS预解析原理
2020/06/16 Javascript
[42:20]Winstrike vs VGJ.S 2018国际邀请赛淘汰赛BO3 第二场 8.23
2018/08/24 DOTA
Python实现多线程抓取妹子图
2015/08/08 Python
Python列表切片用法示例
2017/04/19 Python
利用Python实现kNN算法的代码
2019/08/16 Python
python实现最大优先队列
2019/08/29 Python
python3 Scrapy爬虫框架ip代理配置的方法
2020/01/17 Python
Python selenium爬取微信公众号文章代码详解
2020/08/12 Python
解决Python3.7.0 SSL低版本导致Pip无法使用问题
2020/09/03 Python
html5指南-1.html5全局属性(html5 global attributes)深入理解
2013/01/07 HTML / CSS
澳大利亚吉他在线:Artist Guitars
2017/03/30 全球购物
金士达面试非笔试
2012/03/14 面试题
党的群众路线教育实践活动总结报告
2014/04/28 职场文书
未婚证明书模板
2014/10/08 职场文书
python神经网络编程之手写数字识别
2021/05/08 Python
教你做个可爱的css滑动导航条
2021/06/15 HTML / CSS
剖析后OpLog订阅MongoDB的数据变更就没那么难了
2022/02/24 MongoDB
MySQL优化常用的19种有效方法(推荐!)
2022/03/17 MySQL