tensorflow1.0学习之模型的保存与恢复(Saver)


Posted in Python onApril 23, 2018

将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情。tf里面提供模型保存的是tf.train.Saver()模块。

模型保存,先要创建一个Saver对象:如

saver=tf.train.Saver()

在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:

saver=tf.train.Saver(max_to_keep=0)

但是这样做除了多占用硬盘,并没有实际多大的用处,因此不推荐。

当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即

saver=tf.train.Saver(max_to_keep=1)

创建完saver对象后,就可以保存训练好的模型了,如:

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

第一个参数sess,这个就不用说了。第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。

saver.save(sess, 'my-model', global_step=0) ==>      filename: 'my-model-0'
 ...
 saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

看一个mnist实例:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

saver=tf.train.Saver(max_to_keep=1)
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

代码中红色部分就是保存模型的代码,虽然我在每训练完一代的时候,都进行了保存,但后一次保存的模型会覆盖前一次的,最终只会保存最后一次。因此我们可以节省时间,将保存代码放到循环之外(仅适用max_to_keep=1,否则还是需要放在循环内).

在实验中,最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代,则加个中间变量和判断语句就可以了。

saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()

模型的恢复用的是restore()函数,它需要两个参数restore(sess, save_path),save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:

model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)

则程序后半段代码我们可以改为:

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
is_train=False
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
  max_acc=0
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

#验证阶段
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

标红的地方,就是与保存、恢复模型相关的代码。用一个bool型变量is_train来控制训练和验证两个阶段。

整个源程序:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

is_train=True
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
  max_acc=0
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

#验证阶段
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

参考文章:https://3water.com/article/138779.htm

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python list 合并连接字符串的方法
Mar 09 Python
Python中一些自然语言工具的使用的入门教程
Apr 13 Python
Python中List.count()方法的使用教程
May 20 Python
Python单链表简单实现代码
Apr 27 Python
Python实现的三层BP神经网络算法示例
Feb 07 Python
解决Python2.7读写文件中的中文乱码问题
Apr 12 Python
python实现公司年会抽奖程序
Jan 22 Python
pip指定python位置安装软件包的方法
Jul 12 Python
python wxpython 实现界面跳转功能
Dec 17 Python
Pandas读取csv时如何设置列名
Jun 02 Python
python 使用elasticsearch 实现翻页的三种方式
Jul 31 Python
python subprocess pipe 实时输出日志的操作
Dec 05 Python
tensorflow 使用flags定义命令行参数的方法
Apr 23 #Python
Tensorflow之Saver的用法详解
Apr 23 #Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
python 获取指定文件夹下所有文件名称并写入列表的实例
Apr 23 #Python
You might like
隐藏你的.php文件的实现方法
2007/03/19 PHP
解析PHP自带的进位制之间的转换函数
2013/06/08 PHP
php的laravel框架快速集成微信登录的方法
2016/12/12 PHP
PHP swoole和redis异步任务实现方法分析
2019/08/12 PHP
基础的prototype.js常用函数及其用法
2007/03/10 Javascript
jQuery 开天辟地入门篇一
2009/12/09 Javascript
基于jQuery的固定表格头部的代码(IE6,7,8测试通过)
2010/05/18 Javascript
jQuery Select(单选) 模拟插件 V1.3.62 改进版
2010/07/17 Javascript
基于mootools 1.3框架下的图片滑动效果代码
2011/04/22 Javascript
jquery实现商品拖动选择效果代码(自写)
2013/05/28 Javascript
使用AngularJS创建自定义的过滤器的方法
2015/06/18 Javascript
JavaScript中setTimeout的那些事儿
2016/11/14 Javascript
彻底解决 webpack 打包文件体积过大问题
2017/07/07 Javascript
vue动态路由配置及路由传参的方式
2018/05/23 Javascript
jQuery+CSS实现的标签页效果示例【测试可用】
2018/08/14 jQuery
小程序兼容安卓和IOS数据处理问题及坑
2018/09/18 Javascript
如何在基于vue-cli的项目自定义打包环境
2018/11/10 Javascript
vue动态绑定class选中当前列表变色的方法示例
2018/12/19 Javascript
理解JavaScript中的Proxy 与 Reflection API
2020/09/21 Javascript
Python中使用glob和rmtree删除目录子目录及所有文件的例子
2014/11/21 Python
对python指数、幂数拟合curve_fit详解
2018/12/29 Python
python多线程同步实例教程
2019/08/11 Python
分享PyCharm的几个使用技巧
2019/11/10 Python
python实现最短路径的实例方法
2020/07/19 Python
如何使用python自带IDLE的几种方法
2020/10/10 Python
英国最全面的橄榄球联盟门票网站:Live Rugby Tickets
2018/10/06 全球购物
Discard Protocol抛弃协议的作用是什么
2015/10/10 面试题
优秀幼教自荐信
2014/02/03 职场文书
市场营销管理毕业生自荐信
2014/03/03 职场文书
购房协议书范本
2014/04/11 职场文书
2015年七一建党节演讲稿
2015/03/19 职场文书
新郎父母婚礼答谢词
2015/09/29 职场文书
七年级之家长会发言稿范文
2019/09/04 职场文书
NGINX 权限控制文件预览和下载的实现原理
2022/01/18 Servers
代码复现python目标检测yolo3详解预测
2022/05/06 Python
了解MySQL查询语句执行过程(5大组件)
2022/08/14 MySQL