pytorch中的model=model.to(device)使用说明


Posted in Python onMay 24, 2021

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。

确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。

最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

补充:pytorch中model.to(device)和map_location=device的区别

一、简介

在已训练并保存在CPU上的GPU上加载模型时,加载模型时经常由于训练和保存模型时设备不同出现读取模型时出现错误,在对跨设备的模型读取时候涉及到两个参数的使用,分别是model.to(device)和map_location=devicel两个参数,简介一下两者的不同。

将map_location函数中的参数设置 torch.load()为 cuda:device_id。这会将模型加载到给定的GPU设备。

调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量,无论在cpu上训练还是gpu上训练,保存的模型参数都是参数张量不是cuda张量,因此,cpu设备上不需要使用torch.to(torch.device("cpu"))。

二、实例

了解了两者代表的意义,以下介绍两者的使用。

1、保存在GPU上,在CPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

解释:

在使用GPU训练的CPU上加载模型时,请传递 torch.device('cpu')给map_location函数中的 torch.load()参数,使用map_location参数将张量下面的存储器动态地重新映射到CPU设备 。

2、保存在GPU上,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在GPU上训练并保存在GPU上的模型时,只需将初始化model模型转换为CUDA优化模型即可model.to(torch.device('cuda'))。

此外,请务必.to(torch.device('cuda'))在所有模型输入上使用该 功能来准备模型的数据。

请注意,调用my_tensor.to(device) 返回my_tensorGPU上的新副本。

它不会覆盖 my_tensor。

因此,请记住手动覆盖张量: my_tensor = my_tensor.to(torch.device('cuda'))

3、保存在CPU,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在已训练并保存在CPU上的GPU上加载模型时,请将map_location函数中的参数设置 torch.load()为 cuda:device_id。

这会将模型加载到给定的GPU设备。

接下来,请务必调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量。

最后,确保.to(torch.device('cuda'))在所有模型输入上使用该 函数来为CUDA优化模型准备数据。

请注意,调用 my_tensor.to(device)返回my_tensorGPU上的新副本。

它不会覆盖my_tensor。

因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device('cuda'))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python实现下载网易云音乐的高清MV
Mar 16 Python
基于python的Tkinter实现一个简易计算器
Dec 31 Python
python中使用ctypes调用so传参设置遇到的问题及解决方法
Jun 19 Python
python opencv 图像拼接的实现方法
Jun 27 Python
使用 Python 快速实现 HTTP 和 FTP 服务器的方法
Jul 22 Python
python super用法及原理详解
Jan 20 Python
Keras load_model 导入错误的解决方式
Jun 09 Python
Python常用类型转换实现代码实例
Jul 28 Python
Python脚本打包成可执行文件过程解析
Oct 20 Python
利用Python实现学生信息管理系统的完整实例
Dec 30 Python
Requests什么的通通爬不了的Python超强反爬虫方案!
May 20 Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 Python
解决pytorch-gpu 安装失败的记录
May 24 #Python
如何解决.cuda()加载用时很长的问题
一劳永逸彻底解决pip install慢的办法
May 24 #Python
Django实现翻页的示例代码
May 24 #Python
pytorch--之halfTensor的使用详解
pandas DataFrame.shift()函数的具体使用
May 24 #Python
教你怎么用python实现字符串转日期
May 24 #Python
You might like
php中simplexml_load_string使用实例分享
2014/02/13 PHP
php 问卷调查结果统计
2015/10/08 PHP
PHP安装GeoIP扩展根据IP获取地理位置及计算距离的方法
2016/07/01 PHP
laravel框架模型中非静态方法也能静态调用的原理分析
2019/11/23 PHP
php生成随机数/生成随机字符串的方法小结【5种方法】
2020/05/27 PHP
20个最新的jQuery插件
2012/01/13 Javascript
JavaScript作用域链示例分享
2014/05/27 Javascript
javascript动态修改Li节点值的方法
2015/01/20 Javascript
JavaScript中实现sprintf、printf函数
2015/01/27 Javascript
详解JavaScript中的表单验证
2015/06/16 Javascript
js实现图片上传并正常显示
2015/12/19 Javascript
javascript鼠标滑过显示二级菜单特效
2020/11/18 Javascript
Bootstrap table 定制提示语的加载过程
2017/02/20 Javascript
简单谈谈JS中的正则表达式
2017/09/11 Javascript
vue keep-alive请求数据的方法示例
2018/05/16 Javascript
iview通过Dropdown(下拉菜单)实现的右键菜单
2018/10/26 Javascript
Vue.js如何使用Socket.IO的示例代码
2019/09/05 Javascript
nodejs使用socket5进行代理请求的实现
2020/02/21 NodeJs
Python使用re模块正则提取字符串中括号内的内容示例
2018/06/01 Python
Python 批量刷博客园访问量脚本过程解析
2019/08/30 Python
python实现把二维列表变为一维列表的方法分析
2019/10/08 Python
Python使用循环神经网络解决文本分类问题的方法详解
2020/01/16 Python
python内打印变量之%和f的实例
2020/02/19 Python
树莓派升级python的具体步骤
2020/07/05 Python
什么是Python包的循环导入
2020/09/08 Python
HTML5本地存储和本地数据库实例详解
2017/09/05 HTML / CSS
联想新加坡官方网站:Lenovo Singapore
2017/10/24 全球购物
澳大利亚领先的时尚内衣零售商:Bras N Things
2020/07/28 全球购物
MIKI HOUSE美国官方网上商店:日本领先的婴儿和儿童高级时装品牌
2020/06/21 全球购物
瑞士男士时尚网上商店:Babista
2020/05/14 全球购物
一套软件测试笔试题
2014/07/25 面试题
《浅水洼里的小鱼》听课反思
2014/02/28 职场文书
《赠汪伦》教学反思
2014/04/12 职场文书
超市商业计划书
2014/05/04 职场文书
tomcat默认最大连接数及相关调整方法
2022/05/06 Servers
Android开发手册TextInputLayout样式使用示例
2022/06/10 Java/Android