tensorflow1.0学习之模型的保存与恢复(Saver)


Posted in Python onApril 23, 2018

将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情。tf里面提供模型保存的是tf.train.Saver()模块。

模型保存,先要创建一个Saver对象:如

saver=tf.train.Saver()

在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:

saver=tf.train.Saver(max_to_keep=0)

但是这样做除了多占用硬盘,并没有实际多大的用处,因此不推荐。

当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即

saver=tf.train.Saver(max_to_keep=1)

创建完saver对象后,就可以保存训练好的模型了,如:

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

第一个参数sess,这个就不用说了。第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。

saver.save(sess, 'my-model', global_step=0) ==>      filename: 'my-model-0'
 ...
 saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

看一个mnist实例:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

saver=tf.train.Saver(max_to_keep=1)
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

代码中红色部分就是保存模型的代码,虽然我在每训练完一代的时候,都进行了保存,但后一次保存的模型会覆盖前一次的,最终只会保存最后一次。因此我们可以节省时间,将保存代码放到循环之外(仅适用max_to_keep=1,否则还是需要放在循环内).

在实验中,最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代,则加个中间变量和判断语句就可以了。

saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()

模型的恢复用的是restore()函数,它需要两个参数restore(sess, save_path),save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:

model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)

则程序后半段代码我们可以改为:

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
is_train=False
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
  max_acc=0
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

#验证阶段
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

标红的地方,就是与保存、恢复模型相关的代码。用一个bool型变量is_train来控制训练和验证两个阶段。

整个源程序:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

is_train=True
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
  max_acc=0
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

#验证阶段
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

参考文章:https://3water.com/article/138779.htm

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
从零学Python之hello world
May 21 Python
centos系统升级python 2.7.3
Jul 03 Python
简单介绍Python中用于求最小值的min()方法
May 15 Python
Python文本统计功能之西游记用字统计操作示例
May 07 Python
python 递归深度优先搜索与广度优先搜索算法模拟实现
Oct 22 Python
python针对不定分隔符切割提取字符串的方法
Oct 26 Python
python RabbitMQ 使用详细介绍(小结)
Nov 08 Python
Python 共享变量加锁、释放详解
Aug 28 Python
TFRecord文件查看包含的所有Features代码
Feb 17 Python
keras 自定义loss model.add_loss的使用详解
Jun 22 Python
Python实现文件压缩和解压的示例代码
Aug 12 Python
Python常用模块函数代码汇总解析
Aug 31 Python
tensorflow 使用flags定义命令行参数的方法
Apr 23 #Python
Tensorflow之Saver的用法详解
Apr 23 #Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
python 获取指定文件夹下所有文件名称并写入列表的实例
Apr 23 #Python
You might like
Smarty foreach控制循环次数的实现详解
2013/07/03 PHP
php使用Jpgraph创建柱状图展示年度收支表效果示例
2017/02/15 PHP
PHP 并发场景的几种解决方案
2019/06/14 PHP
PHP filter_var() 函数, 验证判断EMAIL,URL等
2021/03/09 PHP
基于jquery的给文章加入关键字链接
2010/10/26 Javascript
jquery判断浏览器类型的代码
2012/11/05 Javascript
Javascript Web Slider 焦点图示例源码
2013/10/10 Javascript
使用Js让Html中特殊字符不被转义
2013/11/05 Javascript
详解JavaScript中setSeconds()方法的使用
2015/06/11 Javascript
JavaScript中的Promise使用详解
2015/06/24 Javascript
JavaScript实现给按钮加上双重动作的方法
2015/08/14 Javascript
jQuery Easyui实现左右布局
2016/01/26 Javascript
JS中sort函数排序用法实例分析
2016/06/16 Javascript
用JS动态设置CSS样式常见方法小结(推荐)
2016/11/10 Javascript
VueJs路由跳转——vue-router的使用详解
2017/01/10 Javascript
jquery与js实现全选功能的区别
2017/06/11 jQuery
angularjs 缓存的使用详解
2018/03/19 Javascript
vue-cli3.0 环境变量与模式配置方法
2018/11/08 Javascript
使用 vue 实现灭霸打响指英雄消失的效果附demo
2019/05/06 Javascript
微信小程序全局变量的设置、使用、修改过程解析
2019/09/24 Javascript
Vue实现仿iPhone悬浮球的示例代码
2020/03/13 Javascript
基于VUE实现简单的学生信息管理系统
2021/01/13 Vue.js
python应用程序在windows下不出现cmd窗口的办法
2014/05/29 Python
python利用urllib和urllib2访问http的GET/POST详解
2017/09/27 Python
Python爬虫框架Scrapy基本用法入门教程
2018/07/26 Python
python实现本地图片转存并重命名的示例代码
2018/10/27 Python
python url 参数修改方法
2018/12/26 Python
python用opencv 图像傅里叶变换
2021/01/04 Python
eBay法国购物网站:eBay.fr
2017/10/21 全球购物
电子信息专业学生自荐信
2013/11/09 职场文书
毕业生个人的求职信范文
2013/12/03 职场文书
失业者真诚求职信范文
2013/12/25 职场文书
银行介绍信范文
2014/01/10 职场文书
个人工作作风整改措施思想汇报
2014/10/13 职场文书
初中差生评语
2014/12/29 职场文书
python实现批量移动文件
2021/04/05 Python