pytorch中的model=model.to(device)使用说明


Posted in Python onMay 24, 2021

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。

确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。

最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

补充:pytorch中model.to(device)和map_location=device的区别

一、简介

在已训练并保存在CPU上的GPU上加载模型时,加载模型时经常由于训练和保存模型时设备不同出现读取模型时出现错误,在对跨设备的模型读取时候涉及到两个参数的使用,分别是model.to(device)和map_location=devicel两个参数,简介一下两者的不同。

将map_location函数中的参数设置 torch.load()为 cuda:device_id。这会将模型加载到给定的GPU设备。

调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量,无论在cpu上训练还是gpu上训练,保存的模型参数都是参数张量不是cuda张量,因此,cpu设备上不需要使用torch.to(torch.device("cpu"))。

二、实例

了解了两者代表的意义,以下介绍两者的使用。

1、保存在GPU上,在CPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

解释:

在使用GPU训练的CPU上加载模型时,请传递 torch.device('cpu')给map_location函数中的 torch.load()参数,使用map_location参数将张量下面的存储器动态地重新映射到CPU设备 。

2、保存在GPU上,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在GPU上训练并保存在GPU上的模型时,只需将初始化model模型转换为CUDA优化模型即可model.to(torch.device('cuda'))。

此外,请务必.to(torch.device('cuda'))在所有模型输入上使用该 功能来准备模型的数据。

请注意,调用my_tensor.to(device) 返回my_tensorGPU上的新副本。

它不会覆盖 my_tensor。

因此,请记住手动覆盖张量: my_tensor = my_tensor.to(torch.device('cuda'))

3、保存在CPU,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在已训练并保存在CPU上的GPU上加载模型时,请将map_location函数中的参数设置 torch.load()为 cuda:device_id。

这会将模型加载到给定的GPU设备。

接下来,请务必调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量。

最后,确保.to(torch.device('cuda'))在所有模型输入上使用该 函数来为CUDA优化模型准备数据。

请注意,调用 my_tensor.to(device)返回my_tensorGPU上的新副本。

它不会覆盖my_tensor。

因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device('cuda'))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python文件写入实例分析
Apr 08 Python
Python使用pygame模块编写俄罗斯方块游戏的代码实例
Dec 08 Python
深入理解NumPy简明教程---数组3(组合)
Dec 17 Python
python 读取txt,json和hdf5文件的实例
Jun 05 Python
Python3 jupyter notebook 服务器搭建过程
Nov 30 Python
在Python中字符串、列表、元组、字典之间的相互转换
Nov 15 Python
python:动态路由的Flask程序代码
Nov 22 Python
Python读取excel文件中带公式的值的实现
Apr 17 Python
python如何随机生成高强度密码
Aug 19 Python
如何用 Python 制作一个迷宫游戏
Feb 25 Python
Python爬虫爬取全球疫情数据并存储到mysql数据库的步骤
Mar 29 Python
Python Pygame实战在打砖块游戏的实现
Mar 17 Python
解决pytorch-gpu 安装失败的记录
May 24 #Python
如何解决.cuda()加载用时很长的问题
一劳永逸彻底解决pip install慢的办法
May 24 #Python
Django实现翻页的示例代码
May 24 #Python
pytorch--之halfTensor的使用详解
pandas DataFrame.shift()函数的具体使用
May 24 #Python
教你怎么用python实现字符串转日期
May 24 #Python
You might like
php中文本数据翻页(留言本翻页)
2006/10/09 PHP
PHP初学者最感迷茫的问题小结
2010/03/27 PHP
浅析PHP微信支付通知的处理方式
2014/05/25 PHP
PHP基于phpqrcode类生成二维码的方法详解
2018/03/14 PHP
零基础php编程好学吗
2019/10/11 PHP
laravel框架select2多选插件初始化默认选中项操作示例
2020/02/18 PHP
JQuery Tips(4) 一些关于提高JQuery性能的Tips
2009/12/19 Javascript
快速排序 php与javascript的不同之处
2011/02/22 Javascript
javascript特殊用法示例介绍
2013/11/29 Javascript
浮动的div自适应居中显示的js代码
2013/12/23 Javascript
jQuery中change事件用法实例
2014/12/26 Javascript
JavaScript实现在标题栏上显示当前日期的方法
2015/03/19 Javascript
jQuery日程管理插件fullcalendar使用详解
2017/01/07 Javascript
Angular JS 生成动态二维码的方法
2017/02/23 Javascript
VUE2 前端实现 静态二级省市联动选择select的示例
2018/02/09 Javascript
vue中轮训器的使用
2019/01/27 Javascript
ES6 Symbol数据类型的应用实例分析
2019/06/26 Javascript
在Uni中使用Vue的EventBus总线机制操作
2020/07/31 Javascript
python 字符串split的用法分享
2013/03/23 Python
使用Python3 编写简单信用卡管理程序
2016/12/21 Python
WIn10+Anaconda环境下安装PyTorch(避坑指南)
2019/01/30 Python
基于django 的orm中非主键自增的实现方式
2020/05/18 Python
python 写函数在一定条件下需要调用自身时的写法说明
2020/06/01 Python
PyQt5的相对布局管理的实现
2020/08/07 Python
HTML5混合开发二维码扫描以及调用本地摄像头
2017/12/27 HTML / CSS
佳能加拿大网上商店:Canon eStore Canada
2018/04/04 全球购物
英国排名第一的宠物店:PetPlanet
2020/02/02 全球购物
市政施工员自我鉴定
2014/01/15 职场文书
2014年安全生产责任书
2014/07/22 职场文书
秋冬农业生产标语
2014/10/09 职场文书
应聘教师求职信范文
2015/03/20 职场文书
小王子读书笔记
2015/06/29 职场文书
校长新学期寄语2016
2015/12/04 职场文书
HTML+CSS实现导航条下拉菜单的示例代码
2021/08/02 HTML / CSS
如何利用React实现图片识别App
2022/02/18 Javascript
如何优化vue打包文件过大
2022/04/13 Vue.js