TensorFlow绘制loss/accuracy曲线的实例


Posted in Python onJanuary 21, 2020

1. 多曲线

1.1 使用pyplot方式

import numpy as np
import matplotlib.pyplot as plt
 
x = np.arange(1, 11, 1)
 
plt.plot(x, x * 2, label="First")
plt.plot(x, x * 3, label="Second")
plt.plot(x, x * 4, label="Third")
 
plt.legend(loc=0, ncol=1)  # 参数:loc设置显示的位置,0是自适应;ncol设置显示的列数
 
plt.show()

1.2 使用面向对象方式

import numpy as np
import matplotlib.pyplot as plt
 
x = np.arange(1, 11, 1)
 
fig = plt.figure()
ax = fig.add_subplot(111)
 
 
ax.plot(x, x * 2, label="First")
ax.plot(x, x * 3, label="Second")
 
ax.legend(loc=0)
# ax.plot(x, x * 2)
# ax.legend([”Demo“], loc=0)
 
plt.show()

TensorFlow绘制loss/accuracy曲线的实例

2. 双y轴曲线

双y轴曲线图例合并是一个棘手的操作,现以MNIST案例中loss/accuracy绘制曲线。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time
import matplotlib.pyplot as plt
import numpy as np
 
x_data = tf.placeholder(tf.float32, [None, 784])
y_data = tf.placeholder(tf.float32, [None, 10])
x_image = tf.reshape(x_data, [-1, 28, 28, 1])
 
# convolve layer 1
filter1 = tf.Variable(tf.truncated_normal([5, 5, 1, 6]))
bias1 = tf.Variable(tf.truncated_normal([6]))
conv1 = tf.nn.conv2d(x_image, filter1, strides=[1, 1, 1, 1], padding='SAME')
h_conv1 = tf.nn.sigmoid(conv1 + bias1)
maxPool2 = tf.nn.max_pool(h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 
# convolve layer 2
filter2 = tf.Variable(tf.truncated_normal([5, 5, 6, 16]))
bias2 = tf.Variable(tf.truncated_normal([16]))
conv2 = tf.nn.conv2d(maxPool2, filter2, strides=[1, 1, 1, 1], padding='SAME')
h_conv2 = tf.nn.sigmoid(conv2 + bias2)
maxPool3 = tf.nn.max_pool(h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 
# convolve layer 3
filter3 = tf.Variable(tf.truncated_normal([5, 5, 16, 120]))
bias3 = tf.Variable(tf.truncated_normal([120]))
conv3 = tf.nn.conv2d(maxPool3, filter3, strides=[1, 1, 1, 1], padding='SAME')
h_conv3 = tf.nn.sigmoid(conv3 + bias3)
 
# full connection layer 1
W_fc1 = tf.Variable(tf.truncated_normal([7 * 7 * 120, 80]))
b_fc1 = tf.Variable(tf.truncated_normal([80]))
h_pool2_flat = tf.reshape(h_conv3, [-1, 7 * 7 * 120])
h_fc1 = tf.nn.sigmoid(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
 
# full connection layer 2
W_fc2 = tf.Variable(tf.truncated_normal([80, 10]))
b_fc2 = tf.Variable(tf.truncated_normal([10]))
y_model = tf.nn.softmax(tf.matmul(h_fc1, W_fc2) + b_fc2)
 
cross_entropy = - tf.reduce_sum(y_data * tf.log(y_model))
 
train_step = tf.train.GradientDescentOptimizer(1e-3).minimize(cross_entropy)
 
sess = tf.InteractiveSession()
correct_prediction = tf.equal(tf.argmax(y_data, 1), tf.argmax(y_model, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.global_variables_initializer())
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
 
fig_loss = np.zeros([1000])
fig_accuracy = np.zeros([1000])
 
start_time = time.time()
for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(200)
  if i % 100 == 0:
    train_accuracy = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})
    print("step %d, train accuracy %g" % (i, train_accuracy))
    end_time = time.time()
    print("time:", (end_time - start_time))
    start_time = end_time
    print("********************************")
  train_step.run(feed_dict={x_data: batch_xs, y_data: batch_ys})
  fig_loss[i] = sess.run(cross_entropy, feed_dict={x_data: batch_xs, y_data: batch_ys})
  fig_accuracy[i] = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})
print("test accuracy %g" % sess.run(accuracy, feed_dict={x_data: mnist.test.images, y_data: mnist.test.labels}))
 
 
# 绘制曲线
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
lns1 = ax1.plot(np.arange(1000), fig_loss, label="Loss")
# 按一定间隔显示实现方法
# ax2.plot(200 * np.arange(len(fig_accuracy)), fig_accuracy, 'r')
lns2 = ax2.plot(np.arange(1000), fig_accuracy, 'r', label="Accuracy")
ax1.set_xlabel('iteration')
ax1.set_ylabel('training loss')
ax2.set_ylabel('training accuracy')
# 合并图例
lns = lns1 + lns2
labels = ["Loss", "Accuracy"]
# labels = [l.get_label() for l in lns]
plt.legend(lns, labels, loc=7)
plt.show()

注:数据集保存在MNIST_data文件夹下

其实就是三步:

1)分别定义loss/accuracy一维数组

fig_loss = np.zeros([1000])
fig_accuracy = np.zeros([1000])
# 按间隔定义方式:fig_accuracy = np.zeros(int(np.ceil(iteration / interval)))

2)填充真实数据

fig_loss[i] = sess.run(cross_entropy, feed_dict={x_data: batch_xs, y_data: batch_ys})
 fig_accuracy[i] = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})

3)绘制曲线

fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
lns1 = ax1.plot(np.arange(1000), fig_loss, label="Loss")
# 按一定间隔显示实现方法
# ax2.plot(200 * np.arange(len(fig_accuracy)), fig_accuracy, 'r')
lns2 = ax2.plot(np.arange(1000), fig_accuracy, 'r', label="Accuracy")
ax1.set_xlabel('iteration')
ax1.set_ylabel('training loss')
ax2.set_ylabel('training accuracy')
# 合并图例
lns = lns1 + lns2
labels = ["Loss", "Accuracy"]
# labels = [l.get_label() for l in lns]
plt.legend(lns, labels, loc=7)

TensorFlow绘制loss/accuracy曲线的实例

以上这篇TensorFlow绘制loss/accuracy曲线的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现多行注释的另类方法
Aug 22 Python
使用python实现tcp自动重连
Jul 02 Python
用Pygal绘制直方图代码示例
Dec 07 Python
python实现两个文件合并功能
Apr 01 Python
使用python接入微信聊天机器人
Mar 31 Python
PyQt5重写QComboBox的鼠标点击事件方法
Jun 25 Python
如何在Cloud Studio上执行Python代码?
Aug 09 Python
python基于event实现线程间通信控制
Jan 13 Python
Python模块_PyLibTiff读取tif文件的实例
Jan 13 Python
Pytorch maxpool的ceil_mode用法
Feb 18 Python
python 安装库几种方法之cmd,anaconda,pycharm详解
Apr 08 Python
浅谈Python3多线程之间的执行顺序问题
May 02 Python
NumPy统计函数的实现方法
Jan 21 #Python
TensorFlow实现打印每一层的输出
Jan 21 #Python
NumPy排序的实现
Jan 21 #Python
tensorflow实现在函数中用tf.Print输出中间值
Jan 21 #Python
Python实现随机生成任意数量车牌号
Jan 21 #Python
tensorflow模型继续训练 fineturn实例
Jan 21 #Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
You might like
PHP判断是否是微信打开,浏览器打开的方法
2018/03/14 PHP
php从数据库读取数据,并以json格式返回数据的方法
2018/08/21 PHP
PHP使用PDO操作sqlite数据库应用案例
2019/03/07 PHP
PHP 多进程与信号中断实现多任务常驻内存管理实例方法
2019/10/04 PHP
jQuery Div中加载其他页面的实现代码
2009/02/27 Javascript
基于jquery的横向滚动条(滑动条)
2011/02/24 Javascript
js获取当前月的第一天和最后一天的小例子
2013/11/18 Javascript
jquery实现更改表格行顺序示例
2014/04/30 Javascript
jQuery插件jFade实现鼠标经过的图片高亮其它变暗
2015/03/14 Javascript
Jquery1.9.1源码分析系列(十五)动画处理之外篇
2015/12/04 Javascript
jquery html5 视频播放控制代码
2016/11/06 Javascript
AngularJS解决ng界面长表达式(ui-set)的方法分析
2016/11/07 Javascript
JS中的phototype详解
2017/02/04 Javascript
react实现换肤功能的示例代码
2018/08/14 Javascript
详解vue中axios的使用与封装
2019/03/20 Javascript
微信小程序和H5页面间相互跳转代码实例
2019/09/19 Javascript
javascript实现切割轮播效果
2019/11/28 Javascript
Python3写入文件常用方法实例分析
2015/05/22 Python
使用Python神器对付12306变态验证码
2016/01/05 Python
简单讲解Python编程中namedtuple类的用法
2016/06/21 Python
详解如何用django实现redirect的几种方法总结
2018/11/22 Python
Python实现针对json中某个关键字段进行排序操作示例
2018/12/25 Python
Python参数传递机制传值和传引用原理详解
2020/05/22 Python
PythonPC客户端自动化实现原理(pywinauto)
2020/05/28 Python
python super()函数的基本使用
2020/09/10 Python
ProBikeKit澳大利亚:自行车套件,跑步和铁人三项装备
2016/11/30 全球购物
ZINVO手表官网:男士和女士手表
2019/03/10 全球购物
Pureology官网:为染色头发打造最好的产品
2019/09/13 全球购物
大学生自荐信
2013/12/11 职场文书
周年庆典邀请函范文
2014/01/24 职场文书
班主任对学生的评语
2014/04/26 职场文书
合同和协议有什么区别?
2014/10/08 职场文书
2014年挂职干部工作总结
2014/12/06 职场文书
2015年文秘个人工作总结
2015/10/14 职场文书
js中Map和Set的用法及区别实例详解
2022/02/15 Javascript
MySQL优化及索引解析
2022/03/17 MySQL