pytorch中的model=model.to(device)使用说明


Posted in Python onMay 24, 2021

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。

确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。

最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

补充:pytorch中model.to(device)和map_location=device的区别

一、简介

在已训练并保存在CPU上的GPU上加载模型时,加载模型时经常由于训练和保存模型时设备不同出现读取模型时出现错误,在对跨设备的模型读取时候涉及到两个参数的使用,分别是model.to(device)和map_location=devicel两个参数,简介一下两者的不同。

将map_location函数中的参数设置 torch.load()为 cuda:device_id。这会将模型加载到给定的GPU设备。

调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量,无论在cpu上训练还是gpu上训练,保存的模型参数都是参数张量不是cuda张量,因此,cpu设备上不需要使用torch.to(torch.device("cpu"))。

二、实例

了解了两者代表的意义,以下介绍两者的使用。

1、保存在GPU上,在CPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

解释:

在使用GPU训练的CPU上加载模型时,请传递 torch.device('cpu')给map_location函数中的 torch.load()参数,使用map_location参数将张量下面的存储器动态地重新映射到CPU设备 。

2、保存在GPU上,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在GPU上训练并保存在GPU上的模型时,只需将初始化model模型转换为CUDA优化模型即可model.to(torch.device('cuda'))。

此外,请务必.to(torch.device('cuda'))在所有模型输入上使用该 功能来准备模型的数据。

请注意,调用my_tensor.to(device) 返回my_tensorGPU上的新副本。

它不会覆盖 my_tensor。

因此,请记住手动覆盖张量: my_tensor = my_tensor.to(torch.device('cuda'))

3、保存在CPU,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在已训练并保存在CPU上的GPU上加载模型时,请将map_location函数中的参数设置 torch.load()为 cuda:device_id。

这会将模型加载到给定的GPU设备。

接下来,请务必调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量。

最后,确保.to(torch.device('cuda'))在所有模型输入上使用该 函数来为CUDA优化模型准备数据。

请注意,调用 my_tensor.to(device)返回my_tensorGPU上的新副本。

它不会覆盖my_tensor。

因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device('cuda'))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现单词拼写检查
Apr 25 Python
Python实现统计文本文件字数的方法
May 05 Python
python: line=f.readlines()消除line中\n的方法
Mar 19 Python
Anaconda2下实现Python2.7和Python3.5的共存方法
Jun 11 Python
深入浅析Python的类
Jun 22 Python
对python读写文件去重、RE、set的使用详解
Dec 11 Python
解决pyinstaller打包发布后的exe文件打开控制台闪退的问题
Jun 21 Python
django使用django-apscheduler 实现定时任务的例子
Jul 20 Python
用 Python 制作地球仪的方法
Apr 24 Python
Python sublime安装及配置过程详解
Jun 29 Python
Python通过yagmail实现发送邮件代码解析
Oct 27 Python
python爬取企查查企业信息之selenium自动模拟登录企查查
Apr 08 Python
解决pytorch-gpu 安装失败的记录
May 24 #Python
如何解决.cuda()加载用时很长的问题
一劳永逸彻底解决pip install慢的办法
May 24 #Python
Django实现翻页的示例代码
May 24 #Python
pytorch--之halfTensor的使用详解
pandas DataFrame.shift()函数的具体使用
May 24 #Python
教你怎么用python实现字符串转日期
May 24 #Python
You might like
PHP 编程请选择正确的文本编辑软件
2006/12/21 PHP
PHP中的integer类型使用分析
2010/07/27 PHP
超强的IE背景图片闪烁(抖动)的解决办法
2007/09/09 Javascript
SyntaxHighlighter代码加色使用方法
2008/09/07 Javascript
一步一步制作jquery插件Tabs实现过程
2010/07/06 Javascript
js播放wav文件(源码)
2013/04/22 Javascript
封装的jquery翻页滚动(示例代码)
2013/11/18 Javascript
jQuery实现仿Google首页拖动效果的方法
2015/05/04 Javascript
js实现点击按钮后给Div图层设置随机背景颜色的方法
2015/05/06 Javascript
JQuery中层次选择器用法实例详解
2015/05/18 Javascript
基于JS实现数字+字母+中文的混合排序方法
2016/06/06 Javascript
详解Jquery的事件操作和文档操作
2016/12/19 Javascript
Vue常用指令V-model用法
2017/03/08 Javascript
详解Vue.js之视图和数据的双向绑定(v-model)
2017/06/23 Javascript
详解vue数据渲染出现闪烁问题
2017/06/29 Javascript
jQuery实现用户信息表格的添加和删除功能
2017/09/12 jQuery
微信小程序分享功能之按钮button 边框隐藏和点击隐藏
2018/06/14 Javascript
在react中使用vuex的示例代码
2018/07/30 Javascript
微信小程序环境下将文件上传到OSS的方法步骤
2019/05/31 Javascript
如何基于jQuery实现五角星评分
2020/09/02 jQuery
详解vue中使用transition和animation的实例代码
2020/12/12 Vue.js
[02:16]DOTA2超级联赛专访Burning 逆袭需要抓住机会
2013/06/24 DOTA
[01:18:31]DOTA2-DPC中国联赛定级赛 LBZS vs Magma BO3第一场 1月10日
2021/03/11 DOTA
Python中使用第三方库xlrd来写入Excel文件示例
2015/04/05 Python
Python内存管理方式和垃圾回收算法解析
2017/11/11 Python
python中pylint使用方法(pylint代码检查)
2018/04/06 Python
python使用matplotlib画柱状图、散点图
2019/03/18 Python
python实现多线程端口扫描
2019/08/31 Python
python实现滑雪者小游戏
2020/02/22 Python
Python多线程threading join和守护线程setDeamon原理详解
2020/03/18 Python
Python 判断时间是否在时间区间内的实例
2020/05/16 Python
HTML5无刷新改变当前url的代码
2017/03/15 HTML / CSS
水产养殖学应届生求职信
2013/09/29 职场文书
项目经理岗位职责
2015/01/31 职场文书
辞职报告(范文三篇)
2019/08/27 职场文书
idea搭建可运行Servlet的Web项目
2021/06/26 Java/Android