如何定义TensorFlow输入节点


Posted in Python onJanuary 23, 2020

TensorFlow中有如下几种定义输入节点的方法。

通过占位符定义:一般使用这种方式。

通过字典类型定义:一般用于输入比较多的情况。

直接定义:一般很少使用。

一 占位符定义

示例:

具体使用tf.placeholder函数,代码如下:

X = tf.placeholder("float")
Y = tf.placeholder("float")

二 字典类型定义

1 实例

通过字典类型定义输入节点

2 关键代码

# 创建模型
# 占位符
inputdict = {
  'x': tf.placeholder("float"),
  'y': tf.placeholder("float")
}

3 解释

通过字典定义的方式和第一种比较像,只不过是堆叠到一起。

4 全部代码

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
plotdata = { "batchsize":[], "loss":[] }
def moving_average(a, w=10):
  if len(a) < w:
    return a[:]  
  return [val if idx < w else sum(a[(idx-w):idx])/w for idx, val in enumerate(a)]
#生成模拟数据
train_X = np.linspace(-1, 1, 100)
train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.3 # y=2x,但是加入了噪声
#图形显示
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.legend()
plt.show()
# 创建模型
# 占位符
inputdict = {
  'x': tf.placeholder("float"),
  'y': tf.placeholder("float")
}
# 模型参数
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")
# 前向结构
z = tf.multiply(inputdict['x'], W)+ b
#反向优化
cost =tf.reduce_mean( tf.square(inputdict['y'] - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) #Gradient descent
# 初始化变量
init = tf.global_variables_initializer()
#参数设置
training_epochs = 20
display_step = 2
# 启动session
with tf.Session() as sess:
  sess.run(init)
  # Fit all training data
  for epoch in range(training_epochs):
    for (x, y) in zip(train_X, train_Y):
      sess.run(optimizer, feed_dict={inputdict['x']: x, inputdict['y']: y})
    #显示训练中的详细信息
    if epoch % display_step == 0:
      loss = sess.run(cost, feed_dict={inputdict['x']: train_X, inputdict['y']:train_Y})
      print ("Epoch:", epoch+1, "cost=", loss,"W=", sess.run(W), "b=", sess.run(b))
      if not (loss == "NA" ):
        plotdata["batchsize"].append(epoch)
        plotdata["loss"].append(loss)
  print (" Finished!")
  print ("cost=", sess.run(cost, feed_dict={inputdict['x']: train_X, inputdict['y']: train_Y}), "W=", sess.run(W), "b=", sess.run(b))
  #图形显示
  plt.plot(train_X, train_Y, 'ro', label='Original data')
  plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line')
  plt.legend()
  plt.show()
  
  plotdata["avgloss"] = moving_average(plotdata["loss"])
  plt.figure(1)
  plt.subplot(211)
  plt.plot(plotdata["batchsize"], plotdata["avgloss"], 'b--')
  plt.xlabel('Minibatch number')
  plt.ylabel('Loss')
  plt.title('Minibatch run vs. Training loss')
   
  plt.show()
  print ("x=0.2,z=", sess.run(z, feed_dict={inputdict['x']: 0.2}))

5 运行结果

如何定义TensorFlow输入节点

三 直接定义

1 实例

直接定义输入结果

2 解释

直接定义:将定义好的Python变量直接放到OP节点中参与输入的运算,将模拟数据的变量直接放到模型中训练。

3 代码

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
#生成模拟数据
train_X =np.float32( np.linspace(-1, 1, 100))
train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.3 # y=2x,但是加入了噪声
#图形显示
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.legend()
plt.show()
# 创建模型
# 模型参数
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")
# 前向结构
z = tf.multiply(W, train_X)+ b
#反向优化
cost =tf.reduce_mean( tf.square(train_Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) #Gradient descent
# 初始化变量
init = tf.global_variables_initializer()
#参数设置
training_epochs = 20
display_step = 2
# 启动session
with tf.Session() as sess:
  sess.run(init)
  # Fit all training data
  for epoch in range(training_epochs):
    for (x, y) in zip(train_X, train_Y):
      sess.run(optimizer)
    #显示训练中的详细信息
    if epoch % display_step == 0:
      loss = sess.run(cost)
      print ("Epoch:", epoch+1, "cost=", loss,"W=", sess.run(W), "b=", sess.run(b))
  print (" Finished!")
  print ("cost=", sess.run(cost), "W=", sess.run(W), "b=", sess.run(b))

4 运行结果

如何定义TensorFlow输入节点

以上这篇如何定义TensorFlow输入节点就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现获取客户机上指定文件并传输到服务器的方法
Mar 16 Python
使用Python搭建虚拟环境的配置方法
Feb 28 Python
Python Tkinter模块实现时钟功能应用示例
Jul 23 Python
Python 实现文件读写、坐标寻址、查找替换功能
Sep 11 Python
python获取array中指定元素的示例
Nov 26 Python
python读取csv文件指定行的2种方法详解
Feb 13 Python
django xadmin action兼容自定义model权限教程
Mar 30 Python
xadmin使用formfield_for_dbfield函数过滤下拉表单实例
Apr 07 Python
Python 使用 PyQt5 开发的关机小工具分享
Jul 16 Python
Python判断变量是否是None写法代码实例
Oct 09 Python
Python用Jira库来操作Jira
Dec 28 Python
python中mongodb包操作数据库
Apr 19 Python
django 文件上传功能的相关实例代码(简单易懂)
Jan 22 #Python
python动态文本进度条的实例代码
Jan 22 #Python
python wav模块获取采样率 采样点声道量化位数(实例代码)
Jan 22 #Python
使用Python实现Wake On Lan远程开机功能
Jan 22 #Python
python定义类self用法实例解析
Jan 22 #Python
通过实例解析python描述符原理作用
Jan 22 #Python
python基于property()函数定义属性
Jan 22 #Python
You might like
PHP开发中常用的8个小技巧
2008/08/27 PHP
php定义参数数量可变的函数用法实例
2015/03/16 PHP
必须收藏的php实用代码片段
2016/02/02 PHP
javascript网页关闭时提醒效果脚本
2008/10/22 Javascript
过虑特殊字符输入的js代码
2010/08/05 Javascript
浅谈Javascript嵌套函数及闭包
2010/11/09 Javascript
jquery 卷帘效果实现代码(不同方向)
2013/02/05 Javascript
设置jsf的选择框h:selectOneMenu为不可编辑状态的方法
2014/01/07 Javascript
js正则表达式中test,exec,match方法的区别说明
2014/01/29 Javascript
Dojo Javascript 编程规范 规范自己的JavaScript书写
2014/10/26 Javascript
60个很实用的jQuery代码开发技巧收集
2014/12/15 Javascript
TypeScript具有的几个不同特质
2015/04/07 Javascript
jquery对所有input type=text的控件赋值实现方法
2016/12/02 Javascript
jquery仿京东侧边栏导航效果
2017/03/02 Javascript
20个最常见的jQuery面试问题及答案
2018/05/23 jQuery
Electron-vue开发的客户端支付收款工具的实现
2019/05/24 Javascript
Javascript摸拟自由落体与上抛运动原理与实现方法详解
2020/04/08 Javascript
Node.js设置定时任务之node-schedule模块的使用详解
2020/04/28 Javascript
有关vue 开发钉钉 H5 微应用 dd.ready() 不执行问题及快速解决方案
2020/05/09 Javascript
vue页面跳转实现页面缓存操作
2020/07/22 Javascript
[10:42]Team Liquid Vs Newbee
2018/06/07 DOTA
python使用新浪微博api上传图片到微博示例
2014/01/10 Python
有关wxpython pyqt内存占用问题分析
2014/06/09 Python
python网络编程学习笔记(五):socket的一些补充
2014/06/09 Python
K-means聚类算法介绍与利用python实现的代码示例
2017/11/13 Python
使用 Python 实现微信群友统计器的思路详解
2018/09/26 Python
Python基于mysql实现学生管理系统
2019/02/21 Python
如何用Python制作微信好友个性签名词云图
2019/06/28 Python
python 自定义装饰器实例详解
2019/07/20 Python
Django实现web端tailf日志文件功能及实例详解
2019/07/28 Python
python中利用matplotlib读取灰度图的例子
2019/12/07 Python
django template实现定义临时变量,自定义赋值、自增实例
2020/07/12 Python
高一地理教学反思
2014/01/18 职场文书
2015年度班主任自我评价
2015/03/11 职场文书
世界名著读书笔记
2015/06/25 职场文书
创业计划书之都市休闲农庄
2019/12/28 职场文书