使用tensorflow根据输入更改tensor shape


Posted in Python onJune 23, 2020

涉及随机数以及类RNN的网络构建常常需要根据输入shape,决定中间变量的shape或步长。

tf.shape函数不同于tensor.shape.as_list()函数,后者返回的是常值list,而前者返回的是tensor。

使用tf.shape函数可以使得中间变量的tensor形状随输入变化,不需要在构建Graph的时候指定。但对于tf.Variable,因为需要提前分配固定空间,其shape无法通过上诉方法设定。

实例代码如下:

a = tf.placeholder(tf.float32,[None,])
b = tf.random_normal(tf.concat([tf.shape(a),[2,]],axis=0))

补充知识:pytorch中model=model.to(device)用法

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)

以上这篇使用tensorflow根据输入更改tensor shape就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python赋值操作方法分享
Mar 23 Python
Python计算三维矢量幅度的方法
Jun 15 Python
深入理解python中的闭包和装饰器
Jun 12 Python
python安装mysql-python简明笔记(ubuntu环境)
Jun 25 Python
Python插件virtualenv搭建虚拟环境
Nov 20 Python
python 2.7.14安装图文教程
Apr 08 Python
pyqt5中QThread在使用时出现重复emit的实例
Jun 21 Python
django多文件上传,form提交,多对多外键保存的实例
Aug 06 Python
关于PyTorch源码解读之torchvision.models
Aug 17 Python
python 解决mysql where in 对列表(list,,array)问题
Jun 06 Python
使用keras时input_shape的维度表示问题说明
Jun 29 Python
浅析Python中的套接字编程
Jun 22 Python
pytorch 计算ConvTranspose1d输出特征大小方式
Jun 23 #Python
Android Q之气泡弹窗的实现示例
Jun 23 #Python
pytorch判断是否cuda 判断变量类型方式
Jun 23 #Python
Pytorch 解决自定义子Module .cuda() tensor失败的问题
Jun 23 #Python
python如何查看安装了的模块
Jun 23 #Python
pytorch cuda上tensor的定义 以及减少cpu的操作详解
Jun 23 #Python
Python dict的常用方法示例代码
Jun 23 #Python
You might like
用PHP实现多级树型菜单
2006/10/09 PHP
php str_pad 函数用法简介
2009/07/11 PHP
php获取textarea的值并处理回车换行的方法
2014/10/20 PHP
php中smarty模板条件判断用法实例
2015/06/11 PHP
Laravel 创建可以传递参数 Console服务的例子
2019/10/14 PHP
一个符号插入器 中用到的js代码
2007/09/04 Javascript
childNodes.length与children.length的区别
2009/05/14 Javascript
IE6-IE9中tbody的innerHTML不能赋值的解决方法
2014/09/26 Javascript
javascript查询字符串参数的方法
2015/01/28 Javascript
ExtJs动态生成treepanel的Json格式
2015/07/19 Javascript
详解AngularJS中的filter过滤器用法
2016/01/04 Javascript
angularJs关于指令的一些冷门属性详解
2016/10/24 Javascript
vue.js学习笔记之绑定style样式和class列表
2016/10/31 Javascript
微信小程序 闭包写法详细介绍
2016/12/14 Javascript
深究AngularJS如何获取input的焦点(自定义指令)
2017/06/12 Javascript
python daemon守护进程实现
2016/08/27 Python
对python修改xml文件的节点值方法详解
2018/12/24 Python
用Python逐行分析文件方法
2019/01/28 Python
anaconda中更改python版本的方法步骤
2019/07/14 Python
python实现的按要求生成手机号功能示例
2019/10/08 Python
用python打开摄像头并把图像传回qq邮箱(Pyinstaller打包)
2020/05/17 Python
用Python实现童年贪吃蛇小游戏功能的实例代码
2020/12/07 Python
详解CSS3原生支持div铺满浏览器的方法
2018/08/30 HTML / CSS
德国古洛迷亚百货官网:GALERIA Kaufhof
2017/06/20 全球购物
Vans奥地利官方网站:美国原创极限运动潮牌
2018/09/30 全球购物
mysql_pconnect()和mysql_connect()有什么区别
2012/05/25 面试题
Linux如何修改文件和文件夹的权限
2012/06/27 面试题
历史系自荐信范文
2013/12/24 职场文书
初中三好学生自我鉴定
2014/04/07 职场文书
《陈涉世家》教学反思
2014/04/12 职场文书
美丽家庭事迹材料
2014/05/03 职场文书
巴西世界杯32强口号
2014/06/05 职场文书
2014年加油站工作总结
2014/12/04 职场文书
储备店长岗位职责
2015/04/14 职场文书
吧主申请感言怎么写
2015/08/03 职场文书
Spring Bean的实例化之属性注入源码剖析过程
2021/06/13 Java/Android