TensorFlow绘制loss/accuracy曲线的实例


Posted in Python onJanuary 21, 2020

1. 多曲线

1.1 使用pyplot方式

import numpy as np
import matplotlib.pyplot as plt
 
x = np.arange(1, 11, 1)
 
plt.plot(x, x * 2, label="First")
plt.plot(x, x * 3, label="Second")
plt.plot(x, x * 4, label="Third")
 
plt.legend(loc=0, ncol=1)  # 参数:loc设置显示的位置,0是自适应;ncol设置显示的列数
 
plt.show()

1.2 使用面向对象方式

import numpy as np
import matplotlib.pyplot as plt
 
x = np.arange(1, 11, 1)
 
fig = plt.figure()
ax = fig.add_subplot(111)
 
 
ax.plot(x, x * 2, label="First")
ax.plot(x, x * 3, label="Second")
 
ax.legend(loc=0)
# ax.plot(x, x * 2)
# ax.legend([”Demo“], loc=0)
 
plt.show()

TensorFlow绘制loss/accuracy曲线的实例

2. 双y轴曲线

双y轴曲线图例合并是一个棘手的操作,现以MNIST案例中loss/accuracy绘制曲线。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time
import matplotlib.pyplot as plt
import numpy as np
 
x_data = tf.placeholder(tf.float32, [None, 784])
y_data = tf.placeholder(tf.float32, [None, 10])
x_image = tf.reshape(x_data, [-1, 28, 28, 1])
 
# convolve layer 1
filter1 = tf.Variable(tf.truncated_normal([5, 5, 1, 6]))
bias1 = tf.Variable(tf.truncated_normal([6]))
conv1 = tf.nn.conv2d(x_image, filter1, strides=[1, 1, 1, 1], padding='SAME')
h_conv1 = tf.nn.sigmoid(conv1 + bias1)
maxPool2 = tf.nn.max_pool(h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 
# convolve layer 2
filter2 = tf.Variable(tf.truncated_normal([5, 5, 6, 16]))
bias2 = tf.Variable(tf.truncated_normal([16]))
conv2 = tf.nn.conv2d(maxPool2, filter2, strides=[1, 1, 1, 1], padding='SAME')
h_conv2 = tf.nn.sigmoid(conv2 + bias2)
maxPool3 = tf.nn.max_pool(h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 
# convolve layer 3
filter3 = tf.Variable(tf.truncated_normal([5, 5, 16, 120]))
bias3 = tf.Variable(tf.truncated_normal([120]))
conv3 = tf.nn.conv2d(maxPool3, filter3, strides=[1, 1, 1, 1], padding='SAME')
h_conv3 = tf.nn.sigmoid(conv3 + bias3)
 
# full connection layer 1
W_fc1 = tf.Variable(tf.truncated_normal([7 * 7 * 120, 80]))
b_fc1 = tf.Variable(tf.truncated_normal([80]))
h_pool2_flat = tf.reshape(h_conv3, [-1, 7 * 7 * 120])
h_fc1 = tf.nn.sigmoid(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
 
# full connection layer 2
W_fc2 = tf.Variable(tf.truncated_normal([80, 10]))
b_fc2 = tf.Variable(tf.truncated_normal([10]))
y_model = tf.nn.softmax(tf.matmul(h_fc1, W_fc2) + b_fc2)
 
cross_entropy = - tf.reduce_sum(y_data * tf.log(y_model))
 
train_step = tf.train.GradientDescentOptimizer(1e-3).minimize(cross_entropy)
 
sess = tf.InteractiveSession()
correct_prediction = tf.equal(tf.argmax(y_data, 1), tf.argmax(y_model, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.global_variables_initializer())
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
 
fig_loss = np.zeros([1000])
fig_accuracy = np.zeros([1000])
 
start_time = time.time()
for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(200)
  if i % 100 == 0:
    train_accuracy = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})
    print("step %d, train accuracy %g" % (i, train_accuracy))
    end_time = time.time()
    print("time:", (end_time - start_time))
    start_time = end_time
    print("********************************")
  train_step.run(feed_dict={x_data: batch_xs, y_data: batch_ys})
  fig_loss[i] = sess.run(cross_entropy, feed_dict={x_data: batch_xs, y_data: batch_ys})
  fig_accuracy[i] = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})
print("test accuracy %g" % sess.run(accuracy, feed_dict={x_data: mnist.test.images, y_data: mnist.test.labels}))
 
 
# 绘制曲线
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
lns1 = ax1.plot(np.arange(1000), fig_loss, label="Loss")
# 按一定间隔显示实现方法
# ax2.plot(200 * np.arange(len(fig_accuracy)), fig_accuracy, 'r')
lns2 = ax2.plot(np.arange(1000), fig_accuracy, 'r', label="Accuracy")
ax1.set_xlabel('iteration')
ax1.set_ylabel('training loss')
ax2.set_ylabel('training accuracy')
# 合并图例
lns = lns1 + lns2
labels = ["Loss", "Accuracy"]
# labels = [l.get_label() for l in lns]
plt.legend(lns, labels, loc=7)
plt.show()

注:数据集保存在MNIST_data文件夹下

其实就是三步:

1)分别定义loss/accuracy一维数组

fig_loss = np.zeros([1000])
fig_accuracy = np.zeros([1000])
# 按间隔定义方式:fig_accuracy = np.zeros(int(np.ceil(iteration / interval)))

2)填充真实数据

fig_loss[i] = sess.run(cross_entropy, feed_dict={x_data: batch_xs, y_data: batch_ys})
 fig_accuracy[i] = sess.run(accuracy, feed_dict={x_data: batch_xs, y_data: batch_ys})

3)绘制曲线

fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
lns1 = ax1.plot(np.arange(1000), fig_loss, label="Loss")
# 按一定间隔显示实现方法
# ax2.plot(200 * np.arange(len(fig_accuracy)), fig_accuracy, 'r')
lns2 = ax2.plot(np.arange(1000), fig_accuracy, 'r', label="Accuracy")
ax1.set_xlabel('iteration')
ax1.set_ylabel('training loss')
ax2.set_ylabel('training accuracy')
# 合并图例
lns = lns1 + lns2
labels = ["Loss", "Accuracy"]
# labels = [l.get_label() for l in lns]
plt.legend(lns, labels, loc=7)

TensorFlow绘制loss/accuracy曲线的实例

以上这篇TensorFlow绘制loss/accuracy曲线的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈Python程序与C++程序的联合使用
Apr 07 Python
Python编程之列表操作实例详解【创建、使用、更新、删除】
Jul 22 Python
浅谈Python中的bs4基础
Oct 21 Python
python导入坐标点的具体操作
May 10 Python
python制作英语翻译小工具代码实例
Sep 09 Python
用Python写一个自动木马程序
Sep 17 Python
python pycharm的安装及其使用
Oct 11 Python
python调用Matplotlib绘制分布点图
Oct 18 Python
基于Python中random.sample()的替代方案
May 23 Python
Django实现内容缓存实例方法
Jun 30 Python
python中K-means算法基础知识点
Jan 25 Python
基于tensorflow __init__、build 和call的使用小结
Feb 26 Python
NumPy统计函数的实现方法
Jan 21 #Python
TensorFlow实现打印每一层的输出
Jan 21 #Python
NumPy排序的实现
Jan 21 #Python
tensorflow实现在函数中用tf.Print输出中间值
Jan 21 #Python
Python实现随机生成任意数量车牌号
Jan 21 #Python
tensorflow模型继续训练 fineturn实例
Jan 21 #Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
You might like
PHP微信支付结果通知与回调策略分析
2019/01/10 PHP
PHP工厂模式的日常使用
2019/03/20 PHP
PHP进阶学习之垃圾回收机制详解
2019/06/18 PHP
PHP实现随机发放扑克牌
2020/04/21 PHP
JS 文件本身编码转换 图文教程
2009/10/12 Javascript
JS简单实现String转Date的方法
2016/03/02 Javascript
JavaScript面向对象程序设计教程
2016/03/29 Javascript
如何使用Bootstrap创建表单
2017/03/29 Javascript
解决bootstrap下拉菜单点击立即隐藏bug的方法
2017/06/13 Javascript
js模块加载方式浅析
2017/08/12 Javascript
javascript填充默认头像方法
2018/02/22 Javascript
layui实现数据表格table分页功能(ajax异步)
2019/07/27 Javascript
小程序中this.setData的使用和注意事项
2019/08/28 Javascript
微信小程序点击顶部导航栏切换样式代码实例
2019/11/12 Javascript
基于redis的小程序登录实现方法流程分析
2020/05/25 Javascript
关于vue 项目中浏览器跨域的配置问题
2020/11/10 Javascript
[08:29]DOTA2每周TOP10 精彩击杀集锦vol.7
2014/06/25 DOTA
Python备份Mysql脚本
2008/08/11 Python
matplotlib作图添加表格实例代码
2018/01/23 Python
Flask解决跨域的问题示例代码
2018/02/12 Python
win8下python3.4安装和环境配置图文教程
2018/07/31 Python
python pandas库的安装和创建
2019/01/10 Python
python开发游戏的前期准备
2019/05/05 Python
python制作英语翻译小工具代码实例
2019/09/09 Python
Python高级特性——详解多维数组切片(Slice)
2019/11/26 Python
python读取yaml文件后修改写入本地实例
2020/04/27 Python
俄语地区最大的中国商品在线购物网站之一:Umka Mall
2019/11/03 全球购物
商得四方公司面试题(gid+)
2014/04/30 面试题
如何清空Session
2015/02/23 面试题
个人能力自我鉴赏
2014/01/25 职场文书
年度考核自我评价
2014/01/25 职场文书
初中三好学生自我鉴定
2014/04/07 职场文书
实习单位评语
2014/04/26 职场文书
安全教育的主题班会
2015/08/13 职场文书
使用HttpSessionListener监听器实战
2022/03/17 Java/Android
Vue 打包后相对路径的引用问题
2022/06/05 Vue.js