使用tensorflow根据输入更改tensor shape


Posted in Python onJune 23, 2020

涉及随机数以及类RNN的网络构建常常需要根据输入shape,决定中间变量的shape或步长。

tf.shape函数不同于tensor.shape.as_list()函数,后者返回的是常值list,而前者返回的是tensor。

使用tf.shape函数可以使得中间变量的tensor形状随输入变化,不需要在构建Graph的时候指定。但对于tf.Variable,因为需要提前分配固定空间,其shape无法通过上诉方法设定。

实例代码如下:

a = tf.placeholder(tf.float32,[None,])
b = tf.random_normal(tf.concat([tf.shape(a),[2,]],axis=0))

补充知识:pytorch中model=model.to(device)用法

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)

以上这篇使用tensorflow根据输入更改tensor shape就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的正则表达式功能入门教程【经典】
Jun 05 Python
Python Socket使用实例
Dec 18 Python
Python中实现变量赋值传递时的引用和拷贝方法
Apr 29 Python
Python实现的爬虫刷回复功能示例
Jun 07 Python
python使用if语句实现一个猜拳游戏详解
Aug 27 Python
Pytorch技巧:DataLoader的collate_fn参数使用详解
Jan 08 Python
python pandas利用fillna方法实现部分自动填充功能
Mar 16 Python
使用Python对Dicom文件进行读取与写入的实现
Apr 20 Python
keras的三种模型实现与区别说明
Jul 03 Python
python中plt.imshow与cv2.imshow显示颜色问题
Jul 16 Python
基于django和dropzone.js实现上传文件
Nov 24 Python
Python机器学习之底层实现KNN
Jun 20 Python
pytorch 计算ConvTranspose1d输出特征大小方式
Jun 23 #Python
Android Q之气泡弹窗的实现示例
Jun 23 #Python
pytorch判断是否cuda 判断变量类型方式
Jun 23 #Python
Pytorch 解决自定义子Module .cuda() tensor失败的问题
Jun 23 #Python
python如何查看安装了的模块
Jun 23 #Python
pytorch cuda上tensor的定义 以及减少cpu的操作详解
Jun 23 #Python
Python dict的常用方法示例代码
Jun 23 #Python
You might like
php 输出双引号"与单引号'的方法
2010/05/09 PHP
PHP基于MySQL数据库实现对象持久层的方法
2015/06/17 PHP
jQuery实现随意改变div任意属性的名称和值(部分原生js实现)
2013/05/28 Javascript
jQuery Mobile的loading对话框显示/隐藏方法分享
2013/11/26 Javascript
5个数组Array方法: indexOf、filter、forEach、map、reduce使用实例
2015/01/29 Javascript
JS去除iframe滚动条的方法
2015/04/01 Javascript
AngularJS 中文API参考手册
2016/07/28 Javascript
jQuery中Datatables增加跳转到指定页功能
2017/02/08 Javascript
JQuery和html+css实现带小圆点和左右按钮的轮播图实例
2017/07/22 jQuery
微信小程序实现手势图案锁屏功能
2018/01/30 Javascript
vue router+vuex实现首页登录验证判断逻辑
2018/05/17 Javascript
你应该了解的JavaScript Array.map()五种用途小结
2018/11/14 Javascript
如何解决.vue文件url引用文件的问题
2019/01/18 Javascript
vue表单验证你真的会了吗?vue表单验证(form)validate
2019/04/07 Javascript
vue基本使用--refs获取组件或元素的实例
2019/11/07 Javascript
基于jsbarcode 生成条形码并将生成的条码保存至本地+源码
2020/04/27 Javascript
如何利用JavaScript编写更好的条件语句详解
2020/08/10 Javascript
如何阻止移动端浏览器点击图片浏览
2020/08/29 Javascript
python搜索指定目录的方法
2015/04/29 Python
Python数据类型详解(三)元祖:tuple
2016/05/08 Python
python3.6连接MySQL和表的创建与删除实例代码
2017/12/28 Python
Django1.9 加载通过ImageField上传的图片方法
2018/05/25 Python
利用python numpy+matplotlib绘制股票k线图的方法
2019/06/26 Python
python实现beta分布概率密度函数的方法
2019/07/08 Python
Numpy数组array和矩阵matrix转换方法
2019/08/05 Python
自学python用什么系统好
2020/06/23 Python
英国豪华针织品牌John Smedley的在线销售商:The Outlet by John Smedley
2018/04/08 全球购物
Fanatics法国官网:美国体育电商
2019/08/27 全球购物
Carolina Lemke Berlin澳大利亚官网:时尚太阳镜品牌
2019/09/17 全球购物
小学毕业感言150字
2014/02/05 职场文书
十八届三中全会学习方案
2014/02/16 职场文书
2014两会优秀的心得体会范文
2014/03/17 职场文书
领导班子四风表现材料
2014/08/23 职场文书
2014年政风行风自查自纠报告
2014/10/21 职场文书
酒店人事主管岗位职责
2015/04/11 职场文书
银行安全保卫工作总结
2015/08/10 职场文书