pytorch中的model=model.to(device)使用说明


Posted in Python onMay 24, 2021

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。

确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。

最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

补充:pytorch中model.to(device)和map_location=device的区别

一、简介

在已训练并保存在CPU上的GPU上加载模型时,加载模型时经常由于训练和保存模型时设备不同出现读取模型时出现错误,在对跨设备的模型读取时候涉及到两个参数的使用,分别是model.to(device)和map_location=devicel两个参数,简介一下两者的不同。

将map_location函数中的参数设置 torch.load()为 cuda:device_id。这会将模型加载到给定的GPU设备。

调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量,无论在cpu上训练还是gpu上训练,保存的模型参数都是参数张量不是cuda张量,因此,cpu设备上不需要使用torch.to(torch.device("cpu"))。

二、实例

了解了两者代表的意义,以下介绍两者的使用。

1、保存在GPU上,在CPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

解释:

在使用GPU训练的CPU上加载模型时,请传递 torch.device('cpu')给map_location函数中的 torch.load()参数,使用map_location参数将张量下面的存储器动态地重新映射到CPU设备 。

2、保存在GPU上,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在GPU上训练并保存在GPU上的模型时,只需将初始化model模型转换为CUDA优化模型即可model.to(torch.device('cuda'))。

此外,请务必.to(torch.device('cuda'))在所有模型输入上使用该 功能来准备模型的数据。

请注意,调用my_tensor.to(device) 返回my_tensorGPU上的新副本。

它不会覆盖 my_tensor。

因此,请记住手动覆盖张量: my_tensor = my_tensor.to(torch.device('cuda'))

3、保存在CPU,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在已训练并保存在CPU上的GPU上加载模型时,请将map_location函数中的参数设置 torch.load()为 cuda:device_id。

这会将模型加载到给定的GPU设备。

接下来,请务必调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量。

最后,确保.to(torch.device('cuda'))在所有模型输入上使用该 函数来为CUDA优化模型准备数据。

请注意,调用 my_tensor.to(device)返回my_tensorGPU上的新副本。

它不会覆盖my_tensor。

因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device('cuda'))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中sets模块的用法实例
Sep 30 Python
Python实现的多线程端口扫描工具分享
Jan 21 Python
python的Crypto模块实现AES加密实例代码
Jan 22 Python
python如何修改装饰器中参数
Mar 20 Python
不知道这5种下划线的含义,你就不算真的会Python!
Oct 09 Python
Python SMTP发送邮件遇到的一些问题及解决办法
Oct 24 Python
pytorch中的embedding词向量的使用方法
Aug 18 Python
python 实现return返回多个值
Nov 19 Python
Python OpenCV读取显示视频的方法示例
Feb 20 Python
详解python常用命令行选项与环境变量
Feb 20 Python
Python3 pickle对象串行化代码实例解析
Mar 23 Python
基于Django快速集成Echarts代码示例
Dec 01 Python
解决pytorch-gpu 安装失败的记录
May 24 #Python
如何解决.cuda()加载用时很长的问题
一劳永逸彻底解决pip install慢的办法
May 24 #Python
Django实现翻页的示例代码
May 24 #Python
pytorch--之halfTensor的使用详解
pandas DataFrame.shift()函数的具体使用
May 24 #Python
教你怎么用python实现字符串转日期
May 24 #Python
You might like
php中数据的批量导入(csv文件)
2006/10/09 PHP
linux下使用ThinkPHP需要注意大小写导致的问题
2011/08/02 PHP
测试JavaScript字符串处理性能的代码
2009/12/07 Javascript
一个简单的js渐显(fadeIn)渐隐(fadeOut)类
2010/06/19 Javascript
基于jquery的一个OutlookBar类,动态创建导航条
2010/11/19 Javascript
JQuery中dataGrid设置行的高度示例代码
2014/01/03 Javascript
jQuery过滤选择器:not()方法使用介绍
2014/04/20 Javascript
jQuery oLoader实现的加载图片和页面效果
2015/03/14 Javascript
jQuery实现带滑动条的菜单效果代码
2015/08/26 Javascript
js获取及修改网页背景色和字体色的方法
2015/12/29 Javascript
Jquery实现遮罩层的简单实例(就是弹出DIV周围都灰色不能操作)
2016/07/14 Javascript
微信小程序-图片、录音、音频播放、音乐播放、视频、文件代码实例
2016/11/22 Javascript
详解nodeJS之路径PATH模块
2017/05/31 NodeJs
关于vue单文件中引用路径的处理方法
2018/01/08 Javascript
vue中实现图片和文件上传的示例代码
2018/03/16 Javascript
Vue2.0实现调用摄像头进行拍照功能 exif.js实现图片上传功能
2018/04/28 Javascript
VUE:vuex 用户登录信息的数据写入与获取方式
2019/11/11 Javascript
[06:37]2014DOTA2国际邀请赛 昔日王者渴望重回巅峰
2014/07/12 DOTA
Python抓取京东图书评论数据
2014/08/31 Python
windows下安装python的C扩展编译环境(解决Unable to find vcvarsall.bat)
2018/02/21 Python
对Pyhon实现静态变量全局变量的方法详解
2019/01/11 Python
用OpenCV进行年龄和性别检测的实现示例
2021/01/29 Python
Pycharm 设置默认解释器路径和编码格式的操作
2021/02/05 Python
索尼巴西商店:Sony巴西
2019/06/21 全球购物
美国亚马逊旗下时尚女装网店:SHOPBOP(支持中文)
2020/10/17 全球购物
医务人员竞聘职务自我评价分享
2013/11/08 职场文书
小学信息技术教学反思
2014/02/10 职场文书
企业员工培训感言
2014/02/26 职场文书
学校创先争优活动总结
2014/08/28 职场文书
市贸粮局召开党的群众路线教育实践活动总结大会新闻稿
2014/10/21 职场文书
党的群众路线教育实践活动个人对照检查材料(四风)
2014/11/05 职场文书
模范教师材料大全
2014/12/16 职场文书
2016春季校长开学典礼致辞
2015/11/26 职场文书
三年级作文之小小梦想
2019/12/06 职场文书
HTML速写之Emmet语法规则的实现
2021/04/07 HTML / CSS
python实现简易自习室座位预约系统
2021/06/30 Python