TensorFlow绘制loss/accuracy曲线的实例


Posted in Python onJanuary 21, 2020

1. 多曲线

1.1 使用pyplot方式

import numpy as np
import matplotlib.pyplot as plt
 
x = np.arange(1, 11, 1)
 
plt.plot(x, x * 2, label="First")
plt.plot(x, x * 3, label="Second")
plt.plot(x, x * 4, label="Third")
 
plt.legend(loc=0, ncol=1)  # 参数:loc设置显示的位置,0是自适应;ncol设置显示的列数
 
plt.show()

1.2 使用面向对象方式

import numpy as np
import matplotlib.pyplot as plt
 
x = np.arange(1, 11, 1)
 
fig = plt.figure()
ax = fig.add_subplot(111)
 
 
ax.plot(x, x * 2, label="First")
ax.plot(x, x * 3, label="Second")
 
ax.legend(loc=0)
# ax.plot(x, x * 2)
# ax.legend([”Demo“], loc=0)
 
plt.show()

TensorFlow绘制loss/accuracy曲线的实例

2. 双y轴曲线

双y轴曲线图例合并是一个棘手的操作,现以MNIST案例中loss/accuracy绘制曲线。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time
import matplotlib.pyplot as plt
import numpy as np
 
x_data = tf.placeholder(tf.float32, [None, 784])
y_data = tf.placeholder(tf.float32, [None, 10])
x_image = tf.reshape(x_data, [-1, 28, 28, 1])
 
# convolve layer 1
filter1 = tf.Variable(tf.truncated_normal([5, 5, 1, 6]))
bias1 = tf.Variable(tf.truncated_normal([6]))
conv1 = tf.nn.conv2d(x_image, filter1, strides=[1, 1, 1, 1], padding='SAME')
h_conv1 = tf.nn.sigmoid(conv1 + bias1)
maxPool2 = tf.nn.max_pool(h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 
# convolve layer 2
filter2 = tf.Variable(tf.truncated_normal([5, 5, 6, 16]))
bias2 = tf.Variable(tf.truncated_normal([16]))
conv2 = tf.nn.conv2d(maxPool2, filter2, strides=[1, 1, 1, 1], padding='SAME')
h_conv2 = tf.nn.sigmoid(conv2 + bias2)
maxPool3 = tf.nn.max_pool(h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 
# convolve layer 3
filter3 = tf.Variable(tf.truncated_normal([5, 5, 16, 120]))
bias3 = tf.Variable(tf.truncated_normal([120]))
conv3 = tf.nn.conv2d(maxPool3, filter3, strides=[1, 1, 1, 1], padding='SAME')
h_conv3 = tf.nn.sigmoid(conv3 + bias3)
 
# full connection layer 1
W_fc1 = tf.Variable(tf.truncated_normal([7 * 7 * 120, 80]))
b_fc1 = tf.Variable(tf.truncated_normal([80]))
h_pool2_flat = tf.reshape(h_conv3, [-1, 7 * 7 * 120])
h_fc1 = tf.nn.sigmoid(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
 
# full connection layer 2
W_fc2 = tf.Variable(tf.truncated_normal([80, 10]))
b_fc2 = tf.Variable(tf.truncated_normal([10]))
y_model = tf.nn.softmax(tf.matmul(h_fc1, W_fc2) + b_fc2)
 
cross_entropy = - tf.reduce_sum(y_data * tf.log(y_model))
 
train_step = tf.train.GradientDescentOptimizer(1e-3).minimize(cross_entropy)
 
sess = tf.InteractiveSession()
correct_prediction = tf.equal(tf.argmax(y_data, 1), tf.argmax(y_model, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.global_variables_initializer())
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
 
fig_loss = np.zeros([1000])
fig_accuracy = np.zeros([1000])
 
start_time = time.time()
for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(200)
  if i % 100 == 0:
    train_accuracy = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})
    print("step %d, train accuracy %g" % (i, train_accuracy))
    end_time = time.time()
    print("time:", (end_time - start_time))
    start_time = end_time
    print("********************************")
  train_step.run(feed_dict={x_data: batch_xs, y_data: batch_ys})
  fig_loss[i] = sess.run(cross_entropy, feed_dict={x_data: batch_xs, y_data: batch_ys})
  fig_accuracy[i] = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})
print("test accuracy %g" % sess.run(accuracy, feed_dict={x_data: mnist.test.images, y_data: mnist.test.labels}))
 
 
# 绘制曲线
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
lns1 = ax1.plot(np.arange(1000), fig_loss, label="Loss")
# 按一定间隔显示实现方法
# ax2.plot(200 * np.arange(len(fig_accuracy)), fig_accuracy, 'r')
lns2 = ax2.plot(np.arange(1000), fig_accuracy, 'r', label="Accuracy")
ax1.set_xlabel('iteration')
ax1.set_ylabel('training loss')
ax2.set_ylabel('training accuracy')
# 合并图例
lns = lns1 + lns2
labels = ["Loss", "Accuracy"]
# labels = [l.get_label() for l in lns]
plt.legend(lns, labels, loc=7)
plt.show()

注:数据集保存在MNIST_data文件夹下

其实就是三步:

1)分别定义loss/accuracy一维数组

fig_loss = np.zeros([1000])
fig_accuracy = np.zeros([1000])
# 按间隔定义方式:fig_accuracy = np.zeros(int(np.ceil(iteration / interval)))

2)填充真实数据

fig_loss[i] = sess.run(cross_entropy, feed_dict={x_data: batch_xs, y_data: batch_ys})
 fig_accuracy[i] = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})

3)绘制曲线

fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
lns1 = ax1.plot(np.arange(1000), fig_loss, label="Loss")
# 按一定间隔显示实现方法
# ax2.plot(200 * np.arange(len(fig_accuracy)), fig_accuracy, 'r')
lns2 = ax2.plot(np.arange(1000), fig_accuracy, 'r', label="Accuracy")
ax1.set_xlabel('iteration')
ax1.set_ylabel('training loss')
ax2.set_ylabel('training accuracy')
# 合并图例
lns = lns1 + lns2
labels = ["Loss", "Accuracy"]
# labels = [l.get_label() for l in lns]
plt.legend(lns, labels, loc=7)

TensorFlow绘制loss/accuracy曲线的实例

以上这篇TensorFlow绘制loss/accuracy曲线的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用PDB模式调试Python程序介绍
Apr 05 Python
python简单获取本机计算机名和IP地址的方法
Jun 03 Python
Python 爬虫模拟登陆知乎
Sep 23 Python
详解Python 数据库 (sqlite3)应用
Dec 07 Python
Python升级导致yum、pip报错的解决方法
Sep 06 Python
Python使用itchat模块实现群聊转发,自动回复功能示例
Aug 26 Python
基于python实现文件加密功能
Jan 06 Python
python新手学习可变和不可变对象
Jun 11 Python
Tensorflow之MNIST CNN实现并保存、加载模型
Jun 17 Python
使用Tensorflow-GPU禁用GPU设置(CPU与GPU速度对比)
Jun 30 Python
python 实现朴素贝叶斯算法的示例
Sep 30 Python
python 利用Pyinstaller打包Web项目
Oct 23 Python
NumPy统计函数的实现方法
Jan 21 #Python
TensorFlow实现打印每一层的输出
Jan 21 #Python
NumPy排序的实现
Jan 21 #Python
tensorflow实现在函数中用tf.Print输出中间值
Jan 21 #Python
Python实现随机生成任意数量车牌号
Jan 21 #Python
tensorflow模型继续训练 fineturn实例
Jan 21 #Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
You might like
透析PHP的配置文件php.ini
2006/10/09 PHP
关于url地址传参数时字符串有回车造成页面脚本赋值失败的解决方法
2013/06/28 PHP
php的ZipArchive类用法实例
2014/10/20 PHP
php 静态属性和静态方法区别详解
2017/04/09 PHP
使用一个for循环将N*N的二维数组的所有值置1实现方法
2017/05/29 PHP
IE8 原生JSON支持
2009/04/13 Javascript
JavaScript类属性的访问方式详解
2014/02/11 Javascript
js简单实现表单中点击按钮动态增加输入框数量的方法
2015/08/18 Javascript
谈谈我对JavaScript原型和闭包系列理解(随手笔记8)
2015/12/24 Javascript
js判断某个字符出现的次数的简单实例
2016/06/03 Javascript
jquery实用技巧之输入框提示语句
2016/07/28 Javascript
利用Javascript实现BMI计算器
2016/08/16 Javascript
AngularJS equal比较对象实例详解
2016/09/14 Javascript
livereload工具实现前端可视化开发【推荐】
2016/12/23 Javascript
微信小程序实现获取自己所处位置的经纬度坐标功能示例
2017/11/30 Javascript
nodeJs实现基于连接池连接mysql的方法示例
2018/02/10 NodeJs
express框架中使用jwt实现验证的方法
2019/08/25 Javascript
layui use 定义js外部引用函数的方法
2019/09/26 Javascript
js实现随机点名程序
2020/09/17 Javascript
element-ui如何防止重复提交的方法步骤
2019/12/09 Javascript
6种JavaScript继承方式及优缺点(小结)
2020/02/06 Javascript
vue2和vue3的v-if与v-for优先级对比学习
2020/10/10 Javascript
python魔法方法-自定义序列详解
2016/07/21 Python
CentOS 6.X系统下升级Python2.6到Python2.7 的方法
2016/10/12 Python
使用Python的package机制如何简化utils包设计详解
2017/12/11 Python
python3.5+tesseract+adb实现西瓜视频或头脑王者辅助答题
2018/01/17 Python
python实现kNN算法识别手写体数字的示例代码
2019/08/16 Python
玩具反斗城西班牙网上商城:ToysRUs西班牙
2017/01/19 全球购物
英国高级健康和美容产品零售商:Life and Looks
2019/08/01 全球购物
印度排名第一的蛋糕、鲜花和礼品送货:Winni
2019/08/02 全球购物
工程师求职简历的自我评价分享
2013/10/10 职场文书
师说教学反思
2014/02/07 职场文书
阳光体育活动总结
2014/04/30 职场文书
入党积极分子半年考察意见
2015/06/02 职场文书
关于战胜挫折的名言警句大全!
2019/07/05 职场文书
React + Threejs + Swiper 实现全景图效果的完整代码
2021/06/28 Javascript