pytorch中的model=model.to(device)使用说明


Posted in Python onMay 24, 2021

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。

确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。

最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

补充:pytorch中model.to(device)和map_location=device的区别

一、简介

在已训练并保存在CPU上的GPU上加载模型时,加载模型时经常由于训练和保存模型时设备不同出现读取模型时出现错误,在对跨设备的模型读取时候涉及到两个参数的使用,分别是model.to(device)和map_location=devicel两个参数,简介一下两者的不同。

将map_location函数中的参数设置 torch.load()为 cuda:device_id。这会将模型加载到给定的GPU设备。

调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量,无论在cpu上训练还是gpu上训练,保存的模型参数都是参数张量不是cuda张量,因此,cpu设备上不需要使用torch.to(torch.device("cpu"))。

二、实例

了解了两者代表的意义,以下介绍两者的使用。

1、保存在GPU上,在CPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

解释:

在使用GPU训练的CPU上加载模型时,请传递 torch.device('cpu')给map_location函数中的 torch.load()参数,使用map_location参数将张量下面的存储器动态地重新映射到CPU设备 。

2、保存在GPU上,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在GPU上训练并保存在GPU上的模型时,只需将初始化model模型转换为CUDA优化模型即可model.to(torch.device('cuda'))。

此外,请务必.to(torch.device('cuda'))在所有模型输入上使用该 功能来准备模型的数据。

请注意,调用my_tensor.to(device) 返回my_tensorGPU上的新副本。

它不会覆盖 my_tensor。

因此,请记住手动覆盖张量: my_tensor = my_tensor.to(torch.device('cuda'))

3、保存在CPU,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在已训练并保存在CPU上的GPU上加载模型时,请将map_location函数中的参数设置 torch.load()为 cuda:device_id。

这会将模型加载到给定的GPU设备。

接下来,请务必调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量。

最后,确保.to(torch.device('cuda'))在所有模型输入上使用该 函数来为CUDA优化模型准备数据。

请注意,调用 my_tensor.to(device)返回my_tensorGPU上的新副本。

它不会覆盖my_tensor。

因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device('cuda'))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中optionParser模块的使用方法实例教程
Aug 29 Python
一波神奇的Python语句、函数与方法的使用技巧总结
Dec 08 Python
Python多进程分块读取超大文件的方法
Apr 13 Python
Python黑帽编程 3.4 跨越VLAN详解
Sep 28 Python
使用apidocJs快速生成在线文档的实例讲解
Feb 07 Python
python输出100以内的质数与合数实例代码
Jul 08 Python
Python实现随机漫步功能
Jul 09 Python
Python实现 版本号对比功能的实例代码
Apr 18 Python
将python字符串转化成长表达式的函数eval实例
May 11 Python
有关pycharm登录github时有的时候会报错connection reset的问题
Sep 15 Python
使用Python实现音频双通道分离
Dec 25 Python
sklearn中的交叉验证的实现(Cross-Validation)
Feb 22 Python
解决pytorch-gpu 安装失败的记录
May 24 #Python
如何解决.cuda()加载用时很长的问题
一劳永逸彻底解决pip install慢的办法
May 24 #Python
Django实现翻页的示例代码
May 24 #Python
pytorch--之halfTensor的使用详解
pandas DataFrame.shift()函数的具体使用
May 24 #Python
教你怎么用python实现字符串转日期
May 24 #Python
You might like
php程序之die调试法 快速解决错误
2009/09/17 PHP
php二分查找二种实现示例
2014/03/12 PHP
php使用环形链表解决约瑟夫问题完整示例
2018/08/07 PHP
php实现快速对二维数组某一列进行组装的方法小结
2019/12/04 PHP
提升你网站水平的jQuery插件集合推荐
2011/04/19 Javascript
JavaScript动态设置div的样式的方法
2015/12/26 Javascript
Bootstrap入门书籍之(四)菜单、按钮及导航
2016/02/17 Javascript
总结javascript中的六种迭代器
2016/08/16 Javascript
详解基于angular路由的requireJs按需加载js
2017/01/20 Javascript
JQuery和html+css实现带小圆点和左右按钮的轮播图实例
2017/07/22 jQuery
原生javascript实现文件异步上传的实例讲解
2017/10/26 Javascript
修改UA在PC中访问只能在微信中打开的链接方法
2017/11/27 Javascript
详解刷新页面vuex数据不消失和不跳转页面的解决
2018/01/30 Javascript
微信小程序自定义tab实现多层tab嵌套功能
2018/06/15 Javascript
vue中设置height:100%无效的问题及解决方法
2018/07/27 Javascript
vue左侧菜单,树形图递归实现代码
2018/08/24 Javascript
JS实现方形抽奖效果
2018/08/27 Javascript
详解Vue2.0组件的继承与扩展
2018/11/23 Javascript
vue集成chart.js的实现方法
2019/08/20 Javascript
Vue3.0中的monorepo管理模式的实现
2019/10/14 Javascript
Python实现分段线性插值
2018/12/17 Python
Python图像处理实现两幅图像合成一幅图像的方法【测试可用】
2019/01/04 Python
对Python捕获控制台输出流的方法详解
2019/01/07 Python
python中使用 xlwt 操作excel的常见方法与问题
2019/01/13 Python
python matplotlib包图像配色方案分享
2020/03/14 Python
如何将PySpark导入Python的放实现(2种)
2020/04/26 Python
客服文员岗位职责
2013/11/29 职场文书
哈弗商学院毕业生求职信
2014/02/26 职场文书
大学生学习2014全国两会心得体会
2014/03/13 职场文书
志愿者活动总结
2014/04/28 职场文书
歌唱比赛策划方案
2014/06/06 职场文书
企业贷款委托书格式
2014/09/12 职场文书
现场施工员岗位职责
2015/04/11 职场文书
Java spring单点登录系统
2021/09/04 Java/Android
4种方法python批量修改替换列表中元素
2022/04/07 Python
python如何利用cv2.rectangle()绘制矩形框
2022/12/24 Python