TensorFlow绘制loss/accuracy曲线的实例


Posted in Python onJanuary 21, 2020

1. 多曲线

1.1 使用pyplot方式

import numpy as np
import matplotlib.pyplot as plt
 
x = np.arange(1, 11, 1)
 
plt.plot(x, x * 2, label="First")
plt.plot(x, x * 3, label="Second")
plt.plot(x, x * 4, label="Third")
 
plt.legend(loc=0, ncol=1)  # 参数:loc设置显示的位置,0是自适应;ncol设置显示的列数
 
plt.show()

1.2 使用面向对象方式

import numpy as np
import matplotlib.pyplot as plt
 
x = np.arange(1, 11, 1)
 
fig = plt.figure()
ax = fig.add_subplot(111)
 
 
ax.plot(x, x * 2, label="First")
ax.plot(x, x * 3, label="Second")
 
ax.legend(loc=0)
# ax.plot(x, x * 2)
# ax.legend([”Demo“], loc=0)
 
plt.show()

TensorFlow绘制loss/accuracy曲线的实例

2. 双y轴曲线

双y轴曲线图例合并是一个棘手的操作,现以MNIST案例中loss/accuracy绘制曲线。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time
import matplotlib.pyplot as plt
import numpy as np
 
x_data = tf.placeholder(tf.float32, [None, 784])
y_data = tf.placeholder(tf.float32, [None, 10])
x_image = tf.reshape(x_data, [-1, 28, 28, 1])
 
# convolve layer 1
filter1 = tf.Variable(tf.truncated_normal([5, 5, 1, 6]))
bias1 = tf.Variable(tf.truncated_normal([6]))
conv1 = tf.nn.conv2d(x_image, filter1, strides=[1, 1, 1, 1], padding='SAME')
h_conv1 = tf.nn.sigmoid(conv1 + bias1)
maxPool2 = tf.nn.max_pool(h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 
# convolve layer 2
filter2 = tf.Variable(tf.truncated_normal([5, 5, 6, 16]))
bias2 = tf.Variable(tf.truncated_normal([16]))
conv2 = tf.nn.conv2d(maxPool2, filter2, strides=[1, 1, 1, 1], padding='SAME')
h_conv2 = tf.nn.sigmoid(conv2 + bias2)
maxPool3 = tf.nn.max_pool(h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 
# convolve layer 3
filter3 = tf.Variable(tf.truncated_normal([5, 5, 16, 120]))
bias3 = tf.Variable(tf.truncated_normal([120]))
conv3 = tf.nn.conv2d(maxPool3, filter3, strides=[1, 1, 1, 1], padding='SAME')
h_conv3 = tf.nn.sigmoid(conv3 + bias3)
 
# full connection layer 1
W_fc1 = tf.Variable(tf.truncated_normal([7 * 7 * 120, 80]))
b_fc1 = tf.Variable(tf.truncated_normal([80]))
h_pool2_flat = tf.reshape(h_conv3, [-1, 7 * 7 * 120])
h_fc1 = tf.nn.sigmoid(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
 
# full connection layer 2
W_fc2 = tf.Variable(tf.truncated_normal([80, 10]))
b_fc2 = tf.Variable(tf.truncated_normal([10]))
y_model = tf.nn.softmax(tf.matmul(h_fc1, W_fc2) + b_fc2)
 
cross_entropy = - tf.reduce_sum(y_data * tf.log(y_model))
 
train_step = tf.train.GradientDescentOptimizer(1e-3).minimize(cross_entropy)
 
sess = tf.InteractiveSession()
correct_prediction = tf.equal(tf.argmax(y_data, 1), tf.argmax(y_model, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.global_variables_initializer())
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
 
fig_loss = np.zeros([1000])
fig_accuracy = np.zeros([1000])
 
start_time = time.time()
for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(200)
  if i % 100 == 0:
    train_accuracy = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})
    print("step %d, train accuracy %g" % (i, train_accuracy))
    end_time = time.time()
    print("time:", (end_time - start_time))
    start_time = end_time
    print("********************************")
  train_step.run(feed_dict={x_data: batch_xs, y_data: batch_ys})
  fig_loss[i] = sess.run(cross_entropy, feed_dict={x_data: batch_xs, y_data: batch_ys})
  fig_accuracy[i] = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})
print("test accuracy %g" % sess.run(accuracy, feed_dict={x_data: mnist.test.images, y_data: mnist.test.labels}))
 
 
# 绘制曲线
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
lns1 = ax1.plot(np.arange(1000), fig_loss, label="Loss")
# 按一定间隔显示实现方法
# ax2.plot(200 * np.arange(len(fig_accuracy)), fig_accuracy, 'r')
lns2 = ax2.plot(np.arange(1000), fig_accuracy, 'r', label="Accuracy")
ax1.set_xlabel('iteration')
ax1.set_ylabel('training loss')
ax2.set_ylabel('training accuracy')
# 合并图例
lns = lns1 + lns2
labels = ["Loss", "Accuracy"]
# labels = [l.get_label() for l in lns]
plt.legend(lns, labels, loc=7)
plt.show()

注:数据集保存在MNIST_data文件夹下

其实就是三步:

1)分别定义loss/accuracy一维数组

fig_loss = np.zeros([1000])
fig_accuracy = np.zeros([1000])
# 按间隔定义方式:fig_accuracy = np.zeros(int(np.ceil(iteration / interval)))

2)填充真实数据

fig_loss[i] = sess.run(cross_entropy, feed_dict={x_data: batch_xs, y_data: batch_ys})
 fig_accuracy[i] = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})

3)绘制曲线

fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
lns1 = ax1.plot(np.arange(1000), fig_loss, label="Loss")
# 按一定间隔显示实现方法
# ax2.plot(200 * np.arange(len(fig_accuracy)), fig_accuracy, 'r')
lns2 = ax2.plot(np.arange(1000), fig_accuracy, 'r', label="Accuracy")
ax1.set_xlabel('iteration')
ax1.set_ylabel('training loss')
ax2.set_ylabel('training accuracy')
# 合并图例
lns = lns1 + lns2
labels = ["Loss", "Accuracy"]
# labels = [l.get_label() for l in lns]
plt.legend(lns, labels, loc=7)

TensorFlow绘制loss/accuracy曲线的实例

以上这篇TensorFlow绘制loss/accuracy曲线的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现螺旋矩阵的填充算法示例
Dec 28 Python
Python使用Selenium+BeautifulSoup爬取淘宝搜索页
Feb 24 Python
python实现快速排序的示例(二分法思想)
Mar 12 Python
python 剪切移动文件的实现代码
Aug 02 Python
Python实现操纵控制windows注册表的方法分析
May 24 Python
在Qt中正确的设置窗体的背景图片的几种方法总结
Jun 19 Python
django用户登录验证的完整示例代码
Jul 21 Python
flask框架蓝图和子域名配置详解
Jan 25 Python
在Django中自定义filter并在template中的使用详解
May 19 Python
Python如何向SQLServer存储二进制图片
Jun 08 Python
Python代码执行时间测量模块timeit用法解析
Jul 01 Python
Python中Matplotlib的点、线形状、颜色以及绘制散点图
Apr 07 Python
NumPy统计函数的实现方法
Jan 21 #Python
TensorFlow实现打印每一层的输出
Jan 21 #Python
NumPy排序的实现
Jan 21 #Python
tensorflow实现在函数中用tf.Print输出中间值
Jan 21 #Python
Python实现随机生成任意数量车牌号
Jan 21 #Python
tensorflow模型继续训练 fineturn实例
Jan 21 #Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
You might like
PHP中在数据库中保存Checkbox数据(2)
2006/10/09 PHP
PHP实现用户认证及管理完全源码
2007/03/11 PHP
php foreach 参数强制类型转换的问题
2010/12/10 PHP
php 数组排序 array_multisort与uasort的区别
2011/03/24 PHP
php+mysql实现简单登录注册修改密码网页
2016/11/30 PHP
PHP中include()与require()的区别说明
2017/02/14 PHP
Laravel5.5 手动分页和自定义分页样式的简单实现
2019/10/15 PHP
JavaScript 高效运行代码分析
2010/03/18 Javascript
原创javascript小游戏实现代码
2010/08/19 Javascript
js去除输入框中所有的空格和禁止输入空格的方法
2014/06/09 Javascript
Jquery遍历Json数据的方法
2015/04/20 Javascript
js删除Array数组中指定元素的两种方法
2016/08/03 Javascript
node.JS md5加密中文与php结果不一致的解决方法
2017/05/05 Javascript
Angularjs的启动过程分析
2017/07/18 Javascript
NodeJS使用七牛云存储上传文件的方法
2017/07/24 NodeJs
vue-cli 3.x 修改dist路径的方法
2018/09/19 Javascript
如何封装了一个vue移动端下拉加载下一页数据的组件
2019/01/06 Javascript
AI小程序之语音听写来了,十分钟掌握百度大脑语音听写全攻略
2020/03/13 Javascript
JSON获取属性值方法代码实例
2020/06/30 Javascript
[03:01]DOTA2英雄基础教程 露娜
2014/01/07 DOTA
python类:class创建、数据方法属性及访问控制详解
2016/07/25 Python
Python实现基于PIL和tesseract的验证码识别功能示例
2018/07/11 Python
使用Python进行体育竞技分析(预测球队成绩)
2019/05/16 Python
介绍一款python类型检查工具pyright(推荐)
2019/07/03 Python
Python考拉兹猜想输出序列代码实践
2019/07/05 Python
Python基础之列表常见操作经典实例详解
2020/02/26 Python
Python+Kepler.gl轻松制作酷炫路径动画的实现示例
2020/06/02 Python
HTML5触摸事件(touchstart、touchmove和touchend)的实现
2020/05/08 HTML / CSS
两道JAVA笔试题
2016/09/14 面试题
平安校园建设方案
2014/05/02 职场文书
小学语文教学经验交流材料
2014/06/02 职场文书
2015羊年春节慰问信
2015/02/14 职场文书
运动会闭幕式通讯稿
2015/07/18 职场文书
MySQL分库分表与分区的入门指南
2021/04/22 MySQL
MySQL数据库中varchar类型的数字比较大小的方法
2021/11/17 MySQL
Python如何利用pandas读取csv数据并绘图
2022/07/07 Python