如何定义TensorFlow输入节点


Posted in Python onJanuary 23, 2020

TensorFlow中有如下几种定义输入节点的方法。

通过占位符定义:一般使用这种方式。

通过字典类型定义:一般用于输入比较多的情况。

直接定义:一般很少使用。

一 占位符定义

示例:

具体使用tf.placeholder函数,代码如下:

X = tf.placeholder("float")
Y = tf.placeholder("float")

二 字典类型定义

1 实例

通过字典类型定义输入节点

2 关键代码

# 创建模型
# 占位符
inputdict = {
  'x': tf.placeholder("float"),
  'y': tf.placeholder("float")
}

3 解释

通过字典定义的方式和第一种比较像,只不过是堆叠到一起。

4 全部代码

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
plotdata = { "batchsize":[], "loss":[] }
def moving_average(a, w=10):
  if len(a) < w:
    return a[:]  
  return [val if idx < w else sum(a[(idx-w):idx])/w for idx, val in enumerate(a)]
#生成模拟数据
train_X = np.linspace(-1, 1, 100)
train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.3 # y=2x,但是加入了噪声
#图形显示
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.legend()
plt.show()
# 创建模型
# 占位符
inputdict = {
  'x': tf.placeholder("float"),
  'y': tf.placeholder("float")
}
# 模型参数
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")
# 前向结构
z = tf.multiply(inputdict['x'], W)+ b
#反向优化
cost =tf.reduce_mean( tf.square(inputdict['y'] - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) #Gradient descent
# 初始化变量
init = tf.global_variables_initializer()
#参数设置
training_epochs = 20
display_step = 2
# 启动session
with tf.Session() as sess:
  sess.run(init)
  # Fit all training data
  for epoch in range(training_epochs):
    for (x, y) in zip(train_X, train_Y):
      sess.run(optimizer, feed_dict={inputdict['x']: x, inputdict['y']: y})
    #显示训练中的详细信息
    if epoch % display_step == 0:
      loss = sess.run(cost, feed_dict={inputdict['x']: train_X, inputdict['y']:train_Y})
      print ("Epoch:", epoch+1, "cost=", loss,"W=", sess.run(W), "b=", sess.run(b))
      if not (loss == "NA" ):
        plotdata["batchsize"].append(epoch)
        plotdata["loss"].append(loss)
  print (" Finished!")
  print ("cost=", sess.run(cost, feed_dict={inputdict['x']: train_X, inputdict['y']: train_Y}), "W=", sess.run(W), "b=", sess.run(b))
  #图形显示
  plt.plot(train_X, train_Y, 'ro', label='Original data')
  plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line')
  plt.legend()
  plt.show()
  
  plotdata["avgloss"] = moving_average(plotdata["loss"])
  plt.figure(1)
  plt.subplot(211)
  plt.plot(plotdata["batchsize"], plotdata["avgloss"], 'b--')
  plt.xlabel('Minibatch number')
  plt.ylabel('Loss')
  plt.title('Minibatch run vs. Training loss')
   
  plt.show()
  print ("x=0.2,z=", sess.run(z, feed_dict={inputdict['x']: 0.2}))

5 运行结果

如何定义TensorFlow输入节点

三 直接定义

1 实例

直接定义输入结果

2 解释

直接定义:将定义好的Python变量直接放到OP节点中参与输入的运算,将模拟数据的变量直接放到模型中训练。

3 代码

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
#生成模拟数据
train_X =np.float32( np.linspace(-1, 1, 100))
train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.3 # y=2x,但是加入了噪声
#图形显示
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.legend()
plt.show()
# 创建模型
# 模型参数
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")
# 前向结构
z = tf.multiply(W, train_X)+ b
#反向优化
cost =tf.reduce_mean( tf.square(train_Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) #Gradient descent
# 初始化变量
init = tf.global_variables_initializer()
#参数设置
training_epochs = 20
display_step = 2
# 启动session
with tf.Session() as sess:
  sess.run(init)
  # Fit all training data
  for epoch in range(training_epochs):
    for (x, y) in zip(train_X, train_Y):
      sess.run(optimizer)
    #显示训练中的详细信息
    if epoch % display_step == 0:
      loss = sess.run(cost)
      print ("Epoch:", epoch+1, "cost=", loss,"W=", sess.run(W), "b=", sess.run(b))
  print (" Finished!")
  print ("cost=", sess.run(cost), "W=", sess.run(W), "b=", sess.run(b))

4 运行结果

如何定义TensorFlow输入节点

以上这篇如何定义TensorFlow输入节点就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python os模块介绍
Nov 30 Python
Python单体模式的几种常见实现方法详解
Jul 28 Python
python3 http提交json参数并获取返回值的方法
Dec 19 Python
在pycharm 中添加运行参数的操作方法
Jan 19 Python
python整小时 整天时间戳获取算法示例
Feb 20 Python
python如何爬取网站数据并进行数据可视化
Jul 08 Python
python中dict使用方法详解
Jul 17 Python
VSCode中自动为Python文件添加头部注释
Nov 14 Python
python实现Oracle查询分组的方法示例
Apr 30 Python
Pytorch之扩充tensor的操作
Mar 04 Python
还在手动盖楼抽奖?教你用Python实现自动评论盖楼抽奖(一)
Jun 07 Python
Python几种酷炫的进度条的方式
Apr 11 Python
django 文件上传功能的相关实例代码(简单易懂)
Jan 22 #Python
python动态文本进度条的实例代码
Jan 22 #Python
python wav模块获取采样率 采样点声道量化位数(实例代码)
Jan 22 #Python
使用Python实现Wake On Lan远程开机功能
Jan 22 #Python
python定义类self用法实例解析
Jan 22 #Python
通过实例解析python描述符原理作用
Jan 22 #Python
python基于property()函数定义属性
Jan 22 #Python
You might like
使用无限生命期Session的方法
2006/10/09 PHP
FirePHP 推荐一款PHP调试工具
2011/04/23 PHP
Windows下部署Apache+PHP+MySQL运行环境实战
2012/08/31 PHP
PHP实现的mysql操作类【MySQL与MySQLi方式】
2017/10/07 PHP
Mac系统下搭建Nginx+php-fpm实例讲解
2020/12/15 PHP
js中匿名函数的N种写法
2010/09/08 Javascript
jQuery页面图片伴随滚动条逐渐显示的小例子
2013/03/21 Javascript
js 窗口抖动示例
2013/09/04 Javascript
自动设置iframe大小的jQuery代码
2013/09/11 Javascript
juery框架写的弹窗效果适合新手
2013/11/27 Javascript
JS获取客户端IP地址、MAC和主机名的7个方法汇总
2014/07/21 Javascript
javascript实现简单的贪吃蛇游戏
2015/03/31 Javascript
JavaScript html5 canvas绘制时钟效果
2016/03/01 Javascript
EasyUI中在表单提交之前进行验证
2016/07/19 Javascript
JavaScript DOM节点操作方法总结
2016/08/23 Javascript
探究JavaScript中的五种事件处理程序方式
2016/12/07 Javascript
基于vue.js的分页插件详解
2017/11/27 Javascript
Js视频播放器插件Video.js使用方法详解
2020/02/04 Javascript
JavaScript实现省份城市的三级联动
2020/02/11 Javascript
Python基于递归算法实现的汉诺塔与Fibonacci数列示例
2018/04/18 Python
BP神经网络原理及Python实现代码
2018/12/18 Python
解决numpy矩阵相减出现的负值自动转正值的问题
2020/06/03 Python
基于python实现图片转字符画代码实例
2020/09/04 Python
jupyter notebook 写代码自动补全的实现
2020/11/02 Python
python3中celery异步框架简单使用+守护进程方式启动
2021/01/20 Python
Linux管理员面试经常问道的相关命令
2014/12/12 面试题
社会学专业学生职业规划书
2014/02/07 职场文书
华清池导游词
2015/02/02 职场文书
市场督导岗位职责
2015/04/10 职场文书
2015年银行大堂经理工作总结
2015/04/24 职场文书
学生病假条范文
2015/08/17 职场文书
关于五一放假的通知
2015/08/18 职场文书
解决Mysql的left join无效及使用的注意事项说明
2021/07/01 MySQL
MySQL 1130异常,无法远程登录解决方案详解
2021/08/23 MySQL
vue递归实现树形组件
2022/07/15 Vue.js
Win11 Beta 22621.601 和 22622.601今日发布 KB5017384修复内容汇总
2022/09/23 数码科技