使用tensorflow根据输入更改tensor shape


Posted in Python onJune 23, 2020

涉及随机数以及类RNN的网络构建常常需要根据输入shape,决定中间变量的shape或步长。

tf.shape函数不同于tensor.shape.as_list()函数,后者返回的是常值list,而前者返回的是tensor。

使用tf.shape函数可以使得中间变量的tensor形状随输入变化,不需要在构建Graph的时候指定。但对于tf.Variable,因为需要提前分配固定空间,其shape无法通过上诉方法设定。

实例代码如下:

a = tf.placeholder(tf.float32,[None,])
b = tf.random_normal(tf.concat([tf.shape(a),[2,]],axis=0))

补充知识:pytorch中model=model.to(device)用法

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)

以上这篇使用tensorflow根据输入更改tensor shape就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用urllib模块开发的多线程豆瓣小站mp3下载器
Jan 16 Python
Python抓取电影天堂电影信息的代码
Apr 07 Python
python实现给微信公众号发送消息的方法
Jun 30 Python
Python学习pygal绘制线图代码分享
Dec 09 Python
Python中将变量按行写入txt文本中的方法
Apr 03 Python
使用python实现快速搭建简易的FTP服务器
Sep 12 Python
numpy向空的二维数组中添加元素的方法
Nov 01 Python
Python判断一个三位数是否为水仙花数的示例
Nov 13 Python
flask的orm框架SQLAlchemy查询实现解析
Dec 12 Python
django之从html页面表单获取输入的数据实例
Mar 16 Python
Python如何获取文件指定行的内容
May 27 Python
python利用platform模块获取系统信息
Oct 09 Python
pytorch 计算ConvTranspose1d输出特征大小方式
Jun 23 #Python
Android Q之气泡弹窗的实现示例
Jun 23 #Python
pytorch判断是否cuda 判断变量类型方式
Jun 23 #Python
Pytorch 解决自定义子Module .cuda() tensor失败的问题
Jun 23 #Python
python如何查看安装了的模块
Jun 23 #Python
pytorch cuda上tensor的定义 以及减少cpu的操作详解
Jun 23 #Python
Python dict的常用方法示例代码
Jun 23 #Python
You might like
PHP开发不能违背的安全规则 过滤用户输入
2011/05/01 PHP
Laravel 5框架学习之表单验证
2015/04/08 PHP
PHP+swoole实现简单多人在线聊天群发
2016/01/19 PHP
golang与php实现计算两个经纬度之间距离的方法
2016/07/22 PHP
php pdo连接数据库操作示例
2019/11/18 PHP
关于Blog顶部的滚动导航条代码
2006/09/25 Javascript
静态的动态续篇之来点XML
2006/12/23 Javascript
Jquery中$.get(),$.post(),$.ajax(),$.getJSON()的用法总结
2013/11/14 Javascript
简单的js图片轮换代码(js图片轮播)
2014/05/06 Javascript
js数组操作常用方法
2014/05/08 Javascript
javascript表格的渲染组件
2015/07/03 Javascript
jQuery过滤HTML标签并高亮显示关键字的方法
2015/08/07 Javascript
Jquery zTree 树控件异步加载操作
2016/02/25 Javascript
Javascript数组Array基础介绍
2016/03/13 Javascript
Node.js配合node-http-proxy解决本地开发ajax跨域问题
2016/08/31 Javascript
chrome浏览器如何断点调试异步加载的JS
2016/09/05 Javascript
require.js 加载 vue组件 r.js 合并压缩的实例
2016/10/14 Javascript
Agularjs妙用双向数据绑定实现手风琴效果
2017/05/26 Javascript
vue-自定义组件传值的实例讲解
2018/09/18 Javascript
Node.js API详解之 tty功能与用法实例分析
2020/04/27 Javascript
[01:00:17]DOTA2-DPC中国联赛 正赛 SAG vs Dynasty BO3 第二场 1月25日
2021/03/11 DOTA
Flask框架各种常见装饰器示例
2018/07/17 Python
Python字典创建 遍历 添加等实用基础操作技巧
2018/09/13 Python
Python自动发送邮件的方法实例总结
2018/12/08 Python
Python 中的参数传递、返回值、浅拷贝、深拷贝
2019/06/25 Python
Anaconda配置pytorch-gpu虚拟环境的图文教程
2020/04/16 Python
CSS Grid布局教程之什么是网格布局
2014/12/30 HTML / CSS
html5的canvas方法使用指南
2014/12/15 HTML / CSS
goodhealth官方海外旗舰店:新西兰国民营养师
2017/12/15 全球购物
《小小雨点》教学反思
2014/02/18 职场文书
聘用意向书范本
2014/04/01 职场文书
2015届大学生就业推荐表自我评价
2014/09/27 职场文书
交通事故协议书范文
2014/10/23 职场文书
网聊搭讪开场白
2015/05/28 职场文书
24年收藏2000多部退役军用电台
2022/02/18 无线电
恶魔之树最顶端的三颗果实 震震果实上榜,第一可以制造岩浆
2022/03/18 日漫