TensorFlow绘制loss/accuracy曲线的实例


Posted in Python onJanuary 21, 2020

1. 多曲线

1.1 使用pyplot方式

import numpy as np
import matplotlib.pyplot as plt
 
x = np.arange(1, 11, 1)
 
plt.plot(x, x * 2, label="First")
plt.plot(x, x * 3, label="Second")
plt.plot(x, x * 4, label="Third")
 
plt.legend(loc=0, ncol=1)  # 参数:loc设置显示的位置,0是自适应;ncol设置显示的列数
 
plt.show()

1.2 使用面向对象方式

import numpy as np
import matplotlib.pyplot as plt
 
x = np.arange(1, 11, 1)
 
fig = plt.figure()
ax = fig.add_subplot(111)
 
 
ax.plot(x, x * 2, label="First")
ax.plot(x, x * 3, label="Second")
 
ax.legend(loc=0)
# ax.plot(x, x * 2)
# ax.legend([”Demo“], loc=0)
 
plt.show()

TensorFlow绘制loss/accuracy曲线的实例

2. 双y轴曲线

双y轴曲线图例合并是一个棘手的操作,现以MNIST案例中loss/accuracy绘制曲线。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time
import matplotlib.pyplot as plt
import numpy as np
 
x_data = tf.placeholder(tf.float32, [None, 784])
y_data = tf.placeholder(tf.float32, [None, 10])
x_image = tf.reshape(x_data, [-1, 28, 28, 1])
 
# convolve layer 1
filter1 = tf.Variable(tf.truncated_normal([5, 5, 1, 6]))
bias1 = tf.Variable(tf.truncated_normal([6]))
conv1 = tf.nn.conv2d(x_image, filter1, strides=[1, 1, 1, 1], padding='SAME')
h_conv1 = tf.nn.sigmoid(conv1 + bias1)
maxPool2 = tf.nn.max_pool(h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 
# convolve layer 2
filter2 = tf.Variable(tf.truncated_normal([5, 5, 6, 16]))
bias2 = tf.Variable(tf.truncated_normal([16]))
conv2 = tf.nn.conv2d(maxPool2, filter2, strides=[1, 1, 1, 1], padding='SAME')
h_conv2 = tf.nn.sigmoid(conv2 + bias2)
maxPool3 = tf.nn.max_pool(h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 
# convolve layer 3
filter3 = tf.Variable(tf.truncated_normal([5, 5, 16, 120]))
bias3 = tf.Variable(tf.truncated_normal([120]))
conv3 = tf.nn.conv2d(maxPool3, filter3, strides=[1, 1, 1, 1], padding='SAME')
h_conv3 = tf.nn.sigmoid(conv3 + bias3)
 
# full connection layer 1
W_fc1 = tf.Variable(tf.truncated_normal([7 * 7 * 120, 80]))
b_fc1 = tf.Variable(tf.truncated_normal([80]))
h_pool2_flat = tf.reshape(h_conv3, [-1, 7 * 7 * 120])
h_fc1 = tf.nn.sigmoid(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
 
# full connection layer 2
W_fc2 = tf.Variable(tf.truncated_normal([80, 10]))
b_fc2 = tf.Variable(tf.truncated_normal([10]))
y_model = tf.nn.softmax(tf.matmul(h_fc1, W_fc2) + b_fc2)
 
cross_entropy = - tf.reduce_sum(y_data * tf.log(y_model))
 
train_step = tf.train.GradientDescentOptimizer(1e-3).minimize(cross_entropy)
 
sess = tf.InteractiveSession()
correct_prediction = tf.equal(tf.argmax(y_data, 1), tf.argmax(y_model, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.global_variables_initializer())
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
 
fig_loss = np.zeros([1000])
fig_accuracy = np.zeros([1000])
 
start_time = time.time()
for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(200)
  if i % 100 == 0:
    train_accuracy = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})
    print("step %d, train accuracy %g" % (i, train_accuracy))
    end_time = time.time()
    print("time:", (end_time - start_time))
    start_time = end_time
    print("********************************")
  train_step.run(feed_dict={x_data: batch_xs, y_data: batch_ys})
  fig_loss[i] = sess.run(cross_entropy, feed_dict={x_data: batch_xs, y_data: batch_ys})
  fig_accuracy[i] = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})
print("test accuracy %g" % sess.run(accuracy, feed_dict={x_data: mnist.test.images, y_data: mnist.test.labels}))
 
 
# 绘制曲线
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
lns1 = ax1.plot(np.arange(1000), fig_loss, label="Loss")
# 按一定间隔显示实现方法
# ax2.plot(200 * np.arange(len(fig_accuracy)), fig_accuracy, 'r')
lns2 = ax2.plot(np.arange(1000), fig_accuracy, 'r', label="Accuracy")
ax1.set_xlabel('iteration')
ax1.set_ylabel('training loss')
ax2.set_ylabel('training accuracy')
# 合并图例
lns = lns1 + lns2
labels = ["Loss", "Accuracy"]
# labels = [l.get_label() for l in lns]
plt.legend(lns, labels, loc=7)
plt.show()

注:数据集保存在MNIST_data文件夹下

其实就是三步:

1)分别定义loss/accuracy一维数组

fig_loss = np.zeros([1000])
fig_accuracy = np.zeros([1000])
# 按间隔定义方式:fig_accuracy = np.zeros(int(np.ceil(iteration / interval)))

2)填充真实数据

fig_loss[i] = sess.run(cross_entropy, feed_dict={x_data: batch_xs, y_data: batch_ys})
 fig_accuracy[i] = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})

3)绘制曲线

fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
lns1 = ax1.plot(np.arange(1000), fig_loss, label="Loss")
# 按一定间隔显示实现方法
# ax2.plot(200 * np.arange(len(fig_accuracy)), fig_accuracy, 'r')
lns2 = ax2.plot(np.arange(1000), fig_accuracy, 'r', label="Accuracy")
ax1.set_xlabel('iteration')
ax1.set_ylabel('training loss')
ax2.set_ylabel('training accuracy')
# 合并图例
lns = lns1 + lns2
labels = ["Loss", "Accuracy"]
# labels = [l.get_label() for l in lns]
plt.legend(lns, labels, loc=7)

TensorFlow绘制loss/accuracy曲线的实例

以上这篇TensorFlow绘制loss/accuracy曲线的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用百度翻译进行中翻英示例
Apr 14 Python
Python列出一个文件夹及其子目录的所有文件
Jun 30 Python
将tensorflow的ckpt模型存储为npy的实例
Jul 09 Python
Python 实现Windows开机运行某软件的方法
Oct 14 Python
Django框架模板注入操作示例【变量传递到模板】
Dec 19 Python
python绘制评估优化算法性能的测试函数
Jun 25 Python
Python求离散序列导数的示例
Jul 10 Python
浅谈Django中view对数据库的调用方法
Jul 18 Python
Django中ajax发送post请求 报403错误CSRF验证失败解决方案
Aug 13 Python
在Windows上安装和配置 Jupyter Lab 作为桌面级应用程序教程
Apr 22 Python
在主流系统之上安装Pygame的方法
May 20 Python
单身狗福利?Python爬取某婚恋网征婚数据
Jun 03 Python
NumPy统计函数的实现方法
Jan 21 #Python
TensorFlow实现打印每一层的输出
Jan 21 #Python
NumPy排序的实现
Jan 21 #Python
tensorflow实现在函数中用tf.Print输出中间值
Jan 21 #Python
Python实现随机生成任意数量车牌号
Jan 21 #Python
tensorflow模型继续训练 fineturn实例
Jan 21 #Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
You might like
谈谈PHP的输入输出流
2007/02/14 PHP
PHP的范围解析操作符(::)的含义分析说明
2011/07/03 PHP
PHP禁止个别IP访问网站
2013/10/30 PHP
简单的php新闻发布系统教程
2014/05/09 PHP
php版微信返回用户text输入的方法
2016/11/14 PHP
TP5框架使用QueryList采集框架爬小说操作示例
2020/03/26 PHP
Gambit vs ForZe BO3 第二场 2.13
2021/03/10 DOTA
JQuery 无废话系列教程(二) jquery实战篇上
2009/06/23 Javascript
Javascript的构造函数和constructor属性
2010/01/09 Javascript
Node.js开源应用框架HapiJS介绍
2015/01/14 Javascript
sso跨域写cookie的一段js脚本(推荐)
2016/05/25 Javascript
全面解析Javascript无限添加QQ好友原理
2016/06/15 Javascript
仿Angular Bootstrap TimePicker创建分钟数-秒数的输入控件
2016/07/01 Javascript
JS实现鼠标移上去显示图片或微信二维码
2016/12/14 Javascript
js实现横向拖拽导航条功能
2017/02/17 Javascript
JS实现快递单打印功能【推荐】
2018/06/21 Javascript
微信小程序webview实现长按点击识别二维码功能示例
2019/01/24 Javascript
利用python实现数据分析
2017/01/11 Python
python读取图片并修改格式与大小的方法
2018/07/24 Python
PyQT实现菜单中的复制,全选和清空的功能的方法
2019/06/17 Python
python+selenium 鼠标事件操作方法
2019/08/24 Python
python+opencv实现移动侦测(帧差法)
2020/03/20 Python
使用Python将Exception异常错误堆栈信息写入日志文件
2020/04/08 Python
python使用hdfs3模块对hdfs进行操作详解
2020/06/06 Python
挪威户外活动服装和装备购物网站:Bergfreunde挪威
2016/10/20 全球购物
公司新员工的演讲稿注意事项
2014/01/01 职场文书
仓库管理制度
2014/01/21 职场文书
2014年学雷锋活动总结
2014/06/26 职场文书
英语自我介绍演讲稿
2014/09/01 职场文书
酒店保洁员岗位职责
2015/02/26 职场文书
工程合作意向书范本
2015/05/09 职场文书
运动会新闻稿
2015/07/17 职场文书
公司员工奖惩制度
2015/08/04 职场文书
(开源)微信小程序+mqtt,esp8266温湿度读取
2021/04/02 Javascript
最新最全的手机号验证正则表达式
2022/02/24 Javascript
分享CSS盒子模型隐藏的几种方式
2022/02/28 HTML / CSS