TensorFlow绘制loss/accuracy曲线的实例


Posted in Python onJanuary 21, 2020

1. 多曲线

1.1 使用pyplot方式

import numpy as np
import matplotlib.pyplot as plt
 
x = np.arange(1, 11, 1)
 
plt.plot(x, x * 2, label="First")
plt.plot(x, x * 3, label="Second")
plt.plot(x, x * 4, label="Third")
 
plt.legend(loc=0, ncol=1)  # 参数:loc设置显示的位置,0是自适应;ncol设置显示的列数
 
plt.show()

1.2 使用面向对象方式

import numpy as np
import matplotlib.pyplot as plt
 
x = np.arange(1, 11, 1)
 
fig = plt.figure()
ax = fig.add_subplot(111)
 
 
ax.plot(x, x * 2, label="First")
ax.plot(x, x * 3, label="Second")
 
ax.legend(loc=0)
# ax.plot(x, x * 2)
# ax.legend([”Demo“], loc=0)
 
plt.show()

TensorFlow绘制loss/accuracy曲线的实例

2. 双y轴曲线

双y轴曲线图例合并是一个棘手的操作,现以MNIST案例中loss/accuracy绘制曲线。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time
import matplotlib.pyplot as plt
import numpy as np
 
x_data = tf.placeholder(tf.float32, [None, 784])
y_data = tf.placeholder(tf.float32, [None, 10])
x_image = tf.reshape(x_data, [-1, 28, 28, 1])
 
# convolve layer 1
filter1 = tf.Variable(tf.truncated_normal([5, 5, 1, 6]))
bias1 = tf.Variable(tf.truncated_normal([6]))
conv1 = tf.nn.conv2d(x_image, filter1, strides=[1, 1, 1, 1], padding='SAME')
h_conv1 = tf.nn.sigmoid(conv1 + bias1)
maxPool2 = tf.nn.max_pool(h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 
# convolve layer 2
filter2 = tf.Variable(tf.truncated_normal([5, 5, 6, 16]))
bias2 = tf.Variable(tf.truncated_normal([16]))
conv2 = tf.nn.conv2d(maxPool2, filter2, strides=[1, 1, 1, 1], padding='SAME')
h_conv2 = tf.nn.sigmoid(conv2 + bias2)
maxPool3 = tf.nn.max_pool(h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 
# convolve layer 3
filter3 = tf.Variable(tf.truncated_normal([5, 5, 16, 120]))
bias3 = tf.Variable(tf.truncated_normal([120]))
conv3 = tf.nn.conv2d(maxPool3, filter3, strides=[1, 1, 1, 1], padding='SAME')
h_conv3 = tf.nn.sigmoid(conv3 + bias3)
 
# full connection layer 1
W_fc1 = tf.Variable(tf.truncated_normal([7 * 7 * 120, 80]))
b_fc1 = tf.Variable(tf.truncated_normal([80]))
h_pool2_flat = tf.reshape(h_conv3, [-1, 7 * 7 * 120])
h_fc1 = tf.nn.sigmoid(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
 
# full connection layer 2
W_fc2 = tf.Variable(tf.truncated_normal([80, 10]))
b_fc2 = tf.Variable(tf.truncated_normal([10]))
y_model = tf.nn.softmax(tf.matmul(h_fc1, W_fc2) + b_fc2)
 
cross_entropy = - tf.reduce_sum(y_data * tf.log(y_model))
 
train_step = tf.train.GradientDescentOptimizer(1e-3).minimize(cross_entropy)
 
sess = tf.InteractiveSession()
correct_prediction = tf.equal(tf.argmax(y_data, 1), tf.argmax(y_model, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.global_variables_initializer())
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
 
fig_loss = np.zeros([1000])
fig_accuracy = np.zeros([1000])
 
start_time = time.time()
for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(200)
  if i % 100 == 0:
    train_accuracy = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})
    print("step %d, train accuracy %g" % (i, train_accuracy))
    end_time = time.time()
    print("time:", (end_time - start_time))
    start_time = end_time
    print("********************************")
  train_step.run(feed_dict={x_data: batch_xs, y_data: batch_ys})
  fig_loss[i] = sess.run(cross_entropy, feed_dict={x_data: batch_xs, y_data: batch_ys})
  fig_accuracy[i] = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})
print("test accuracy %g" % sess.run(accuracy, feed_dict={x_data: mnist.test.images, y_data: mnist.test.labels}))
 
 
# 绘制曲线
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
lns1 = ax1.plot(np.arange(1000), fig_loss, label="Loss")
# 按一定间隔显示实现方法
# ax2.plot(200 * np.arange(len(fig_accuracy)), fig_accuracy, 'r')
lns2 = ax2.plot(np.arange(1000), fig_accuracy, 'r', label="Accuracy")
ax1.set_xlabel('iteration')
ax1.set_ylabel('training loss')
ax2.set_ylabel('training accuracy')
# 合并图例
lns = lns1 + lns2
labels = ["Loss", "Accuracy"]
# labels = [l.get_label() for l in lns]
plt.legend(lns, labels, loc=7)
plt.show()

注:数据集保存在MNIST_data文件夹下

其实就是三步:

1)分别定义loss/accuracy一维数组

fig_loss = np.zeros([1000])
fig_accuracy = np.zeros([1000])
# 按间隔定义方式:fig_accuracy = np.zeros(int(np.ceil(iteration / interval)))

2)填充真实数据

fig_loss[i] = sess.run(cross_entropy, feed_dict={x_data: batch_xs, y_data: batch_ys})
 fig_accuracy[i] = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})

3)绘制曲线

fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
lns1 = ax1.plot(np.arange(1000), fig_loss, label="Loss")
# 按一定间隔显示实现方法
# ax2.plot(200 * np.arange(len(fig_accuracy)), fig_accuracy, 'r')
lns2 = ax2.plot(np.arange(1000), fig_accuracy, 'r', label="Accuracy")
ax1.set_xlabel('iteration')
ax1.set_ylabel('training loss')
ax2.set_ylabel('training accuracy')
# 合并图例
lns = lns1 + lns2
labels = ["Loss", "Accuracy"]
# labels = [l.get_label() for l in lns]
plt.legend(lns, labels, loc=7)

TensorFlow绘制loss/accuracy曲线的实例

以上这篇TensorFlow绘制loss/accuracy曲线的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中实现字符串类型与字典类型相互转换的方法
Aug 18 Python
分析Python编程时利用wxPython来支持多线程的方法
Apr 07 Python
Python使用numpy模块创建数组操作示例
Jun 20 Python
使用GitHub和Python实现持续部署的方法
May 09 Python
Python安装与基本数据类型教程详解
May 29 Python
django 通过url实现简单的权限控制的例子
Aug 16 Python
python网络编程 使用UDP、TCP协议收发信息详解
Aug 29 Python
详解字符串在Python内部是如何省内存的
Feb 03 Python
pytorch之Resize()函数具体使用详解
Feb 27 Python
python自定义函数def的应用详解
Jun 03 Python
Python面向对象实现方法总结
Aug 12 Python
python数据处理之Pandas类型转换
Apr 28 Python
NumPy统计函数的实现方法
Jan 21 #Python
TensorFlow实现打印每一层的输出
Jan 21 #Python
NumPy排序的实现
Jan 21 #Python
tensorflow实现在函数中用tf.Print输出中间值
Jan 21 #Python
Python实现随机生成任意数量车牌号
Jan 21 #Python
tensorflow模型继续训练 fineturn实例
Jan 21 #Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
You might like
PHP 程序员也要学会使用“异常”
2009/06/16 PHP
MySql数据库查询结果用表格输出PHP代码示例
2015/03/20 PHP
eclipse php wamp配置教程
2016/06/30 PHP
javascript闭包的理解和实例
2010/08/12 Javascript
解决火狐浏览器下JS setTimeout函数不兼容失效不执行的方法
2012/11/14 Javascript
js左侧三级菜单导航实例代码
2013/09/13 Javascript
Jquery如何实现点击时高亮显示代码
2014/01/22 Javascript
比例尺、缩略图、平移缩放之百度地图添加控件方法
2015/08/03 Javascript
jQuery easyUI datagrid 增加求和统计行的实现代码
2016/06/01 Javascript
Angular2.js实现表单验证详解
2017/06/23 Javascript
手机注册发送验证码倒计时的简单实例
2017/11/15 Javascript
如何理解Vue的v-model指令的使用方法
2018/07/19 Javascript
微信小程序如何刷新当前界面的实现方法
2019/06/07 Javascript
node.js使用fs读取文件出错的解决方案
2019/10/23 Javascript
vue点击自增和求和的实例代码
2019/11/06 Javascript
在vue和element-ui的table中实现分页复选功能
2019/12/04 Javascript
Vue多选列表组件深入详解
2021/03/02 Vue.js
11个并不被常用但对开发非常有帮助的Python库
2015/03/31 Python
在RedHat系Linux上部署Python的Celery框架的教程
2015/04/07 Python
解决phantomjs截图失败,phantom.exit位置的问题
2018/05/17 Python
Pytorch中index_select() 函数的实现理解
2019/11/19 Python
Python HTMLTestRunner库安装过程解析
2020/05/25 Python
tensorflow下的图片标准化函数per_image_standardization用法
2020/06/30 Python
用HTML5制作一个简单的桌球游戏的教程
2015/05/12 HTML / CSS
欧迪办公美国官网:Office Depot
2016/08/22 全球购物
PHP如何防止SQL注入
2014/05/03 面试题
将"引用"作为函数参数有哪些特点
2013/04/05 面试题
汉语言文学专业求职信
2014/06/19 职场文书
保险专业求职信
2014/07/07 职场文书
演讲比赛的活动方案
2014/08/28 职场文书
印刷技术专业自荐信
2014/09/18 职场文书
党员对照检查材料
2014/09/22 职场文书
爱岗敬业先进典型事迹材料(2016推荐版)
2016/02/26 职场文书
委托开发合同书(标准版)
2019/08/07 职场文书
JAVA API 实用类 String详解
2021/10/05 Java/Android
彩虹社八名人气艺人全新周边限时推出,性转女装男装一次拥有!
2022/04/01 日漫