pytorch中的model=model.to(device)使用说明


Posted in Python onMay 24, 2021

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。

确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。

最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

补充:pytorch中model.to(device)和map_location=device的区别

一、简介

在已训练并保存在CPU上的GPU上加载模型时,加载模型时经常由于训练和保存模型时设备不同出现读取模型时出现错误,在对跨设备的模型读取时候涉及到两个参数的使用,分别是model.to(device)和map_location=devicel两个参数,简介一下两者的不同。

将map_location函数中的参数设置 torch.load()为 cuda:device_id。这会将模型加载到给定的GPU设备。

调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量,无论在cpu上训练还是gpu上训练,保存的模型参数都是参数张量不是cuda张量,因此,cpu设备上不需要使用torch.to(torch.device("cpu"))。

二、实例

了解了两者代表的意义,以下介绍两者的使用。

1、保存在GPU上,在CPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

解释:

在使用GPU训练的CPU上加载模型时,请传递 torch.device('cpu')给map_location函数中的 torch.load()参数,使用map_location参数将张量下面的存储器动态地重新映射到CPU设备 。

2、保存在GPU上,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在GPU上训练并保存在GPU上的模型时,只需将初始化model模型转换为CUDA优化模型即可model.to(torch.device('cuda'))。

此外,请务必.to(torch.device('cuda'))在所有模型输入上使用该 功能来准备模型的数据。

请注意,调用my_tensor.to(device) 返回my_tensorGPU上的新副本。

它不会覆盖 my_tensor。

因此,请记住手动覆盖张量: my_tensor = my_tensor.to(torch.device('cuda'))

3、保存在CPU,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在已训练并保存在CPU上的GPU上加载模型时,请将map_location函数中的参数设置 torch.load()为 cuda:device_id。

这会将模型加载到给定的GPU设备。

接下来,请务必调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量。

最后,确保.to(torch.device('cuda'))在所有模型输入上使用该 函数来为CUDA优化模型准备数据。

请注意,调用 my_tensor.to(device)返回my_tensorGPU上的新副本。

它不会覆盖my_tensor。

因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device('cuda'))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
请不要重复犯我在学习Python和Linux系统上的错误
Dec 12 Python
一个基于flask的web应用诞生 bootstrap框架美化(3)
Apr 11 Python
pandas.DataFrame选取/排除特定行的方法
Jul 03 Python
Python subprocess库的使用详解
Oct 26 Python
在IPython中执行Python程序文件的示例
Nov 01 Python
pygame游戏之旅 创建游戏窗口界面
Nov 20 Python
Python 3.6 中使用pdfminer解析pdf文件的实现
Sep 25 Python
win10环境下配置vscode python开发环境的教程详解
Oct 16 Python
Pandas操作CSV文件的读写实现方法
Nov 13 Python
python os.path.isfile()因参数问题判断错误的解决
Nov 29 Python
python opencv把一张图片嵌入(叠加)到另一张图片上的实现代码
Jun 11 Python
python中numpy数组与list相互转换实例方法
Jan 29 Python
解决pytorch-gpu 安装失败的记录
May 24 #Python
如何解决.cuda()加载用时很长的问题
一劳永逸彻底解决pip install慢的办法
May 24 #Python
Django实现翻页的示例代码
May 24 #Python
pytorch--之halfTensor的使用详解
pandas DataFrame.shift()函数的具体使用
May 24 #Python
教你怎么用python实现字符串转日期
May 24 #Python
You might like
2019年中国咖啡业现状与发展趋势
2021/03/04 咖啡文化
目录,文件操作详谈―PHP
2006/11/25 PHP
PHP Class&Object -- PHP 自排序二叉树的深入解析
2013/06/25 PHP
php 启动报错如何解决
2014/01/17 PHP
php实现用于验证所有类型的信用卡类
2015/03/24 PHP
PHP中应该避免使用同名变量(拆分临时变量)
2015/04/03 PHP
jQuery第三课 修改元素属性及内容的代码
2010/03/14 Javascript
使用JS进行目录上传(相当于批量上传)
2010/12/05 Javascript
Google (Local) Search API的简单使用介绍
2013/11/28 Javascript
AngularJS入门之动画
2016/07/27 Javascript
详解如何使用webpack打包Vue工程
2017/05/27 Javascript
vue2.0项目中使用Ueditor富文本编辑器示例代码
2017/08/14 Javascript
angularjs实现过滤并替换关键字小功能
2017/09/19 Javascript
在Vue中使用echarts的方法
2018/02/05 Javascript
react中实现搜索结果中关键词高亮显示
2018/07/31 Javascript
Windows下Node爬虫神器Puppeteer安装记
2019/01/09 Javascript
vue使用原生swiper代码实例
2020/02/05 Javascript
Vue路由的模块自动化与统一加载实现
2020/06/05 Javascript
ES6的循环与可迭代对象示例详解
2021/01/31 Javascript
Python中装饰器的一个妙用
2015/02/08 Python
python实现简单购物商城
2016/05/21 Python
老生常谈python中的重载
2018/11/11 Python
使用Python制作表情包实现换脸功能
2019/07/19 Python
Python小程序 控制鼠标循环点击代码实例
2019/10/08 Python
如何基于python实现归一化处理
2020/01/20 Python
Python3.7黑帽编程之病毒篇(基础篇)
2020/02/04 Python
Tensorflow全局设置可见GPU编号操作
2020/06/30 Python
利用scikitlearn画ROC曲线实例
2020/07/02 Python
Python用requests库爬取返回为空的解决办法
2021/02/21 Python
以下为Windows NT 下的32 位C++程序,请计算sizeof 的值
2016/12/07 面试题
内业资料员岗位职责
2014/01/04 职场文书
参观监狱警示教育心得体会
2016/01/15 职场文书
jQuery ajax - getScript() 方法和getJSON方法
2021/05/14 jQuery
python自动计算图像数据集的RGB均值
2021/06/18 Python
Python 数据科学 Matplotlib图库详解
2021/07/07 Python
Mybatis是这样防止sql注入的
2021/12/06 Java/Android