tensorflow1.0学习之模型的保存与恢复(Saver)


Posted in Python onApril 23, 2018

将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情。tf里面提供模型保存的是tf.train.Saver()模块。

模型保存,先要创建一个Saver对象:如

saver=tf.train.Saver()

在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:

saver=tf.train.Saver(max_to_keep=0)

但是这样做除了多占用硬盘,并没有实际多大的用处,因此不推荐。

当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即

saver=tf.train.Saver(max_to_keep=1)

创建完saver对象后,就可以保存训练好的模型了,如:

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

第一个参数sess,这个就不用说了。第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。

saver.save(sess, 'my-model', global_step=0) ==>      filename: 'my-model-0'
 ...
 saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

看一个mnist实例:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

saver=tf.train.Saver(max_to_keep=1)
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

代码中红色部分就是保存模型的代码,虽然我在每训练完一代的时候,都进行了保存,但后一次保存的模型会覆盖前一次的,最终只会保存最后一次。因此我们可以节省时间,将保存代码放到循环之外(仅适用max_to_keep=1,否则还是需要放在循环内).

在实验中,最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代,则加个中间变量和判断语句就可以了。

saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()

模型的恢复用的是restore()函数,它需要两个参数restore(sess, save_path),save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:

model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)

则程序后半段代码我们可以改为:

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
is_train=False
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
  max_acc=0
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

#验证阶段
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

标红的地方,就是与保存、恢复模型相关的代码。用一个bool型变量is_train来控制训练和验证两个阶段。

整个源程序:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

is_train=True
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
  max_acc=0
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

#验证阶段
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

参考文章:https://3water.com/article/138779.htm

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python脚本操作MongoDB的教程
Apr 16 Python
python 链接和操作 memcache方法
Mar 04 Python
windows环境下tensorflow安装过程详解
Mar 30 Python
Python可变参数*args和**kwargs用法实例小结
Apr 27 Python
python实现生成字符串大小写字母和数字的各种组合
Jan 01 Python
python矩阵/字典实现最短路径算法
Jan 17 Python
Python图像处理PIL各模块详细介绍(推荐)
Jul 17 Python
利用python实现短信和电话提醒功能的例子
Aug 08 Python
Python unittest框架操作实例解析
Apr 13 Python
Python接口测试文件上传实例解析
May 22 Python
解决Python中导入自己写的类,被划红线,但不影响执行的问题
Jul 13 Python
Python3.8官网文档之类的基础语法阅读
Sep 04 Python
tensorflow 使用flags定义命令行参数的方法
Apr 23 #Python
Tensorflow之Saver的用法详解
Apr 23 #Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
python 获取指定文件夹下所有文件名称并写入列表的实例
Apr 23 #Python
You might like
php xfocus防注入资料
2008/04/27 PHP
php 验证码(倾斜,正弦干扰线,黏贴,旋转)
2013/06/29 PHP
Zend Framework生成验证码并实现验证码验证功能(附demo源码下载)
2016/03/22 PHP
PHP session垃圾回收机制实例分析
2019/06/28 PHP
Laravel 5.4前后台分离,通过不同的二级域名访问方法
2019/10/13 PHP
PHP PDO和消息队列的个人理解与应用实例分析
2019/11/25 PHP
jQuery :nth-child前有无空格的区别分析
2011/07/11 Javascript
分享一个自己写的简单的javascript分页组件
2015/02/15 Javascript
jquery实现简单手风琴菜单效果实例
2015/06/13 Javascript
图片旋转、鼠标滚轮缩放、镜像、切换图片js代码
2020/12/13 Javascript
javascript实现右侧弹出“分享到”窗口效果
2016/02/01 Javascript
jQuery中$.each()函数的用法引申实例
2016/05/12 Javascript
JS hashMap实例详解
2016/05/26 Javascript
jQuery调用Webservice传递json数组的方法
2016/08/06 Javascript
使用Bootstrap4 + Vue2实现分页查询的示例代码
2017/12/21 Javascript
基于vue实现可搜索下拉框定制组件
2020/03/26 Javascript
JavaScript实现串行请求的示例代码
2020/09/14 Javascript
[07:49]2014DOTA2国际邀请赛 Newbee夺冠后采访xiao8坦言奖金会上交
2014/07/23 DOTA
Python with的用法
2014/08/22 Python
整理Python 常用string函数(收藏)
2016/05/30 Python
python的列表List求均值和中位数实例
2020/03/03 Python
python爬虫开发之使用Python爬虫库requests多线程抓取猫眼电影TOP100实例
2020/03/10 Python
Python应用实现处理excel数据过程解析
2020/06/19 Python
详解Scrapy Redis入门实战
2020/11/18 Python
PHP面试题及答案一
2012/06/18 面试题
现在输入n个数字,以逗号,分开;然后可选择升或者降序排序;按提交键就在另一页面显示按什么排序,结果为,提供reset
2012/11/09 面试题
专题组织生活会方案
2014/06/15 职场文书
仓库管理员岗位职责
2015/02/03 职场文书
2015年度物业公司工作总结
2015/04/27 职场文书
辩护词格式
2015/05/22 职场文书
2015年计算机教师工作总结
2015/07/22 职场文书
中国梦宣传标语口号
2015/12/26 职场文书
mysql数据库入门第一步之创建表
2021/05/14 MySQL
JavaScript实现登录窗体
2021/06/22 Javascript
斗罗大陆八大特殊魂兽,龙族始祖排榜首,第五最残忍(翠魔鸟)
2022/03/18 国漫
日本十大血腥动漫,那些被禁播的动漫盘点
2022/03/21 日漫