使用tensorflow根据输入更改tensor shape


Posted in Python onJune 23, 2020

涉及随机数以及类RNN的网络构建常常需要根据输入shape,决定中间变量的shape或步长。

tf.shape函数不同于tensor.shape.as_list()函数,后者返回的是常值list,而前者返回的是tensor。

使用tf.shape函数可以使得中间变量的tensor形状随输入变化,不需要在构建Graph的时候指定。但对于tf.Variable,因为需要提前分配固定空间,其shape无法通过上诉方法设定。

实例代码如下:

a = tf.placeholder(tf.float32,[None,])
b = tf.random_normal(tf.concat([tf.shape(a),[2,]],axis=0))

补充知识:pytorch中model=model.to(device)用法

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)

以上这篇使用tensorflow根据输入更改tensor shape就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中itertools模块用法详解
Sep 25 Python
Python中isnumeric()方法的使用简介
May 19 Python
python基于右递归解决八皇后问题的方法
May 25 Python
numpy中实现ndarray数组返回符合特定条件的索引方法
Apr 17 Python
python中的turtle库函数简单使用教程
Jul 23 Python
用Python和WordCloud绘制词云的实现方法(内附让字体清晰的秘笈)
Jan 08 Python
用python打印1~20的整数实例讲解
Jul 01 Python
Python画图实现同一结点多个柱状图的示例
Jul 07 Python
Python迭代器iterator生成器generator使用解析
Oct 24 Python
Cython编译python为so 代码加密示例
Dec 23 Python
django正续或者倒序查库实例
May 19 Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 Python
pytorch 计算ConvTranspose1d输出特征大小方式
Jun 23 #Python
Android Q之气泡弹窗的实现示例
Jun 23 #Python
pytorch判断是否cuda 判断变量类型方式
Jun 23 #Python
Pytorch 解决自定义子Module .cuda() tensor失败的问题
Jun 23 #Python
python如何查看安装了的模块
Jun 23 #Python
pytorch cuda上tensor的定义 以及减少cpu的操作详解
Jun 23 #Python
Python dict的常用方法示例代码
Jun 23 #Python
You might like
可快速识别放射性物质-国外大神教你diy一个开放式辐射探测器
2020/03/12 无线电
Yii中表单用法实例详解
2016/01/05 PHP
thinkPHP5.0框架安装教程
2017/03/25 PHP
Laravel 创建可以传递参数 Console服务的例子
2019/10/14 PHP
Thinkphp 3.2框架使用Redis的方法详解
2019/10/24 PHP
关于URL中的特殊符号使用介绍
2011/11/03 Javascript
js/jquery去掉空格,回车,换行示例代码
2013/11/05 Javascript
jQuery实现冻结表头的方法
2015/03/09 Javascript
Java遍历集合方法分析(实现原理、算法性能、适用场合)
2016/04/25 Javascript
JS中使用new Date(str)创建时间对象不兼容firefox和ie的解决方法(两种)
2016/12/14 Javascript
jQuery.Validate表单验证插件的使用示例详解
2017/01/04 Javascript
基于ExtJs在页面上window再调用Window的事件处理方法
2017/07/26 Javascript
javascript将url解析为json格式的两种方法
2017/08/18 Javascript
JavaScript设计模式之命令模式实例分析
2019/01/16 Javascript
vue实现后台管理权限系统及顶栏三级菜单显示功能
2019/06/19 Javascript
Angular8 简单表单验证的实现示例
2020/06/03 Javascript
vue双击事件2.0事件监听(点击-双击-鼠标事件)和事件修饰符操作
2020/07/27 Javascript
[02:59]2014DOTA2西雅图国际邀请赛 圆满落幕中国夺冠
2014/07/23 DOTA
python executemany的使用及注意事项
2017/03/13 Python
python dataframe常见操作方法:实现取行、列、切片、统计特征值
2018/06/09 Python
OpenCV 模板匹配
2019/07/10 Python
使用Python的turtle模块画国旗
2019/09/24 Python
Python3.8安装Pygame教程步骤详解
2020/08/14 Python
CSS3的新特性介绍
2008/10/31 HTML / CSS
html5+css3气泡组件的实现
2014/11/21 HTML / CSS
好莱坞百老汇御用王牌美妆:Koh Gen Do 江原道
2018/04/03 全球购物
德国旅游网站:weg.de
2018/06/03 全球购物
Jacques Lemans德国:奥地利钟表品牌
2019/12/26 全球购物
简述你对Statement,PreparedStatement,CallableStatement的理解
2013/03/25 面试题
社区工作者思想汇报
2014/01/13 职场文书
历史专业学生的自我评价
2014/02/28 职场文书
团购业务员岗位职责
2014/03/15 职场文书
妇联领导班子剖析材料
2014/08/21 职场文书
家庭贫困证明范本(经典版)
2014/09/22 职场文书
教你如何用Python实现人脸识别(含源代码)
2021/06/23 Python
win10电脑老是死机怎么办?win10系统老是死机的解决方法
2022/08/05 数码科技