使用tensorflow根据输入更改tensor shape


Posted in Python onJune 23, 2020

涉及随机数以及类RNN的网络构建常常需要根据输入shape,决定中间变量的shape或步长。

tf.shape函数不同于tensor.shape.as_list()函数,后者返回的是常值list,而前者返回的是tensor。

使用tf.shape函数可以使得中间变量的tensor形状随输入变化,不需要在构建Graph的时候指定。但对于tf.Variable,因为需要提前分配固定空间,其shape无法通过上诉方法设定。

实例代码如下:

a = tf.placeholder(tf.float32,[None,])
b = tf.random_normal(tf.concat([tf.shape(a),[2,]],axis=0))

补充知识:pytorch中model=model.to(device)用法

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)

以上这篇使用tensorflow根据输入更改tensor shape就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现快速排序和插入排序算法及自定义排序的示例
Feb 16 Python
python定时按日期备份MySQL数据并压缩
Apr 19 Python
Python实现 PS 图像调整中的亮度调整
Jun 28 Python
Python及Pycharm安装方法图文教程
Aug 05 Python
详解centos7+django+python3+mysql+阿里云部署项目全流程
Nov 15 Python
Python实现bilibili时间长度查询的示例代码
Jan 14 Python
tensorflow生成多个tfrecord文件实例
Feb 17 Python
Python3实现打印任意宽度的菱形代码
Apr 12 Python
Python pip 常用命令汇总
Oct 19 Python
详解用selenium来下载小姐姐图片并保存
Jan 26 Python
用python批量解压带密码的压缩包
May 31 Python
Python利用capstone实现反汇编
Apr 06 Python
pytorch 计算ConvTranspose1d输出特征大小方式
Jun 23 #Python
Android Q之气泡弹窗的实现示例
Jun 23 #Python
pytorch判断是否cuda 判断变量类型方式
Jun 23 #Python
Pytorch 解决自定义子Module .cuda() tensor失败的问题
Jun 23 #Python
python如何查看安装了的模块
Jun 23 #Python
pytorch cuda上tensor的定义 以及减少cpu的操作详解
Jun 23 #Python
Python dict的常用方法示例代码
Jun 23 #Python
You might like
写一个用户在线显示的程序
2006/10/09 PHP
PHP4和PHP5共存于一系统
2006/11/17 PHP
ThinkPHP3.2框架使用addAll()批量插入数据的方法
2017/03/16 PHP
在 PHP 和 Laravel 中使用 Traits的方法
2019/11/13 PHP
JavaScript 字符编码规则
2009/05/04 Javascript
Prototype Date对象 学习
2009/07/12 Javascript
DWZ table的原生分页浅谈
2013/03/01 Javascript
JS中获取数据库中的值的方法
2013/07/14 Javascript
jQuery取id有.的值的方法
2014/05/21 Javascript
使用jquery组件qrcode生成二维码及应用指南
2015/02/22 Javascript
javascript拖拽应用实例(二)
2016/03/25 Javascript
AngularJs表单验证实例代码解析
2016/11/29 Javascript
JavaScript实现的鼠标响应颜色渐变效果完整实例
2017/02/18 Javascript
Angular4开发解决跨域问题详解
2017/08/28 Javascript
解决vue-cli单页面手机应用input点击手机端虚拟键盘弹出盖住input问题
2018/08/25 Javascript
详解promise.then,process.nextTick, setTimeout 以及 setImmediate的执行顺序
2018/11/21 Javascript
详解React项目如何修改打包地址(编译输出文件地址)
2019/03/21 Javascript
Vue实现回到顶部和底部动画效果
2019/07/31 Javascript
详解Django之auth模块(用户认证)
2018/04/17 Python
Python GUI Tkinter简单实现个性签名设计
2018/06/19 Python
Python实现对字典分别按键(key)和值(value)进行排序的方法分析
2018/12/19 Python
django model的update时auto_now不被更新的原因及解决方式
2020/04/01 Python
python 检测图片是否有马赛克
2020/12/01 Python
python opencv实现图像配准与比较
2021/02/09 Python
SpringBoot首页设置解析(推荐)
2021/02/11 Python
美国男装连锁零售商:Men’s Wearhouse
2016/10/14 全球购物
VisionPros美国站:加拿大在线隐形眼镜和眼镜零售商
2020/02/11 全球购物
薇姿法国官网:Vichy法国
2021/01/28 全球购物
假日旅行社实习自我鉴定
2013/09/24 职场文书
普通大学毕业生自荐信
2013/11/04 职场文书
三年大学生活自我鉴定
2014/01/21 职场文书
《夕阳真美》教学反思
2014/04/27 职场文书
松材线虫病防治方案
2014/06/15 职场文书
2014村书记党建工作汇报材料
2014/11/02 职场文书
质量保证书格式模板
2015/02/27 职场文书
MySQL令人大跌眼镜的隐式转换
2021/08/23 MySQL