查看: 1966|回复: 1

tensorflow中使用save和restore保存和恢复模型

[复制链接]

166

主题

616

帖子

1万

积分

xdtech

Rank: 5Rank: 5

积分
11792
发表于 2019-11-1 00:31:36 | 显示全部楼层 |阅读模式
我们在训练模型过程中,有时训练一段时间后,往往想要在验证集上验证一下,模型是否存在过拟合,然后视验证情况,再选择继续训练还是修改模型参数。这时tensorflow提供的Saver类,就能很好的帮助到我们。
当我们保存一个模型到指定路径后,还目录下将会出现四种类型的文件:

checkpoint: 具有最近检查点列表的协议缓存区
.data: 保存模型中的变量
.index: 标识检查点
.meta: 保存模型中计算图的结构信息

1、tf.train.Saver( )

首先需要在程序中定义一个saver操作,该定义在会话结构之外。

import tensorflow as tf
...
saver = tf.train.Saver()
...
with tf.Session() as sess:
    ...
1
2
3
4
5
6
7
这样一个saver操作就定义好了。tf.train.Saver( )有几个我们平时常用到的参数,具体如下:

max_to_keep: 设置保存最近的检查点文件的个数,例如max_to_keep=4,就是只保存最新的四个模型。
keep_checkpoint_every_n_hours: 设置每隔多长时间保存一次模型。
savable_variables: 可以设置将要保存的tensor。如tf.train.Saver([w1, w2]),就是只保存w1和w2。如果不指定任何想要保存的tensor,saver默认保存所有的tensor。

2、saver.save( )

在使用tf.train.Saver( )创建了saver操作之后,我们就可以在一个会话中保存我们的模型。

...
with tf.Session() as sess:
    ...
    for epoch in range(10):
        ...
        saver.save(sess, model_path, global_step=epoch, write_meta_graph=False)
        ...
1
2
3
4
5
6
7
使用上面代码中的saver.save( )就可以按照我们的要求保存模型。其中参数说明如下:

sess: 会话对象
model_path: 模型保存的路径
global_step=epoch: 可选,在我们保存的文件名字中,加上迭代次数,以方便我们区分保 存的文件是经过多少次的训练迭代。如global_step = 2,则我们保存的文件名字为-2.data-00000-of-00001,-2.index,-2.meta。
write_meta_graph: 可选,False: 只保存一次.meta文件;True:根据我们设置的保存次 数,保存多次.meta文件。这里对这个参数加一点说明:因为模型一旦建立好之后,计算图的结构就确定了,所以每次保存的.meta文件都是一样的,有时为了节省存储空间,我们选择只保存一次.meta文件。

3、saver.restore( )

在保存了一个模型之后,我们使用saver.restore( )来恢复模型。恢复操作也需要在session会话中。我们可以创建一个新的会话:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('model_path/-2.meta') # 以.meta文件名为-2.meta为例
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    ...
1
2
3
4
我们首先要通过saver = tf.train.import_meta_graph(‘model_path/-2.meta’)加载模型的计算图结构,然后通过saver.restore(sess, tf.train.latest_checkpoint(‘model_path’))来恢复我们保存的所有变量和操作。其中tf.train.latest_checkpoint(‘model_path’)是从最近的检查点中恢复模型。

以上就是保存和恢复全部模型的操作。在实际进行模型优化时,有时我们会对原来的模型进行修改,如增加网络的深度,重新定义一些权重变量,重新定义精确度指标等。这时,我们就可以通过变量或操作的名字来加载指定的变量或操作。

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('model_path/-2.meta') # 以.meta文件名为-2.meta为例
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    graph = tf.get_default_graph()
    # 加载网络权重变量w1和w2,"w1:0"中,weight1为定义w1变量时指定的名字,当此tensor没有重复时,后面加上0
    w1 = graph.get_tensor_by_name("weight1:0")
    w2 = graph.get_tensor_by_name("weight2:0")
    # 恢复网络中的第七全连接层,fully_connected7为定义fc7时指定的名字,当此tensor没有重复时,后面加上0
    fc7 = graph.get_tensor_by_name("fully_connected7:0")
————————————————


回复

使用道具 举报

166

主题

616

帖子

1万

积分

xdtech

Rank: 5Rank: 5

积分
11792
 楼主| 发表于 2019-11-1 00:31:45 | 显示全部楼层
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

快速回复 返回顶部 返回列表