一、TensorFlow常规模型加载方法
保存模型
tf.train.Saver()类,.save(sess, ckpt文件目录)方法
参数名称 | 功能说明 | 默认值 |
var_list | Saver中存储变量集合 | 全局变量集合 |
reshape | 加载时是否恢复变量形状 | True |
sharded | 是否将变量轮循放在所有设备上 | True |
max_to_keep | 保留最近检查点个数 | 5 |
restore_sequentially | 是否按顺序恢复变量,模型较大时顺序恢复内存消耗小 | True |
var_list是字典形式{变量名字符串: 变量符号},相对应的restore也根据同样形式的字典将ckpt中的字符串对应的变量加载给程序中的符号。
如果Saver给定了字典作为加载方式,则按照字典来,如:saver = tf.train.Saver({"v/ExponentialMovingAverage":v}),否则每个变量寻找自己的name属性在ckpt中的对应值进行加载。
加载模型
当我们基于checkpoint文件(ckpt)加载参数时,实际上我们使用Saver.restore取代了initializer的初始化
checkpoint文件会记录保存信息,通过它可以定位最新保存的模型:
ckpt = tf.train.get_checkpoint_state('./model/')print(ckpt.model_checkpoint_path)
.meta文件保存了当前图结构
.index文件保存了当前参数名
.data文件保存了当前参数值
tf.train.import_meta_graph函数给出model.ckpt-n.meta的路径后会加载图结构,并返回saver对象
ckpt = tf.train.get_checkpoint_state('./model/')
tf.train.Saver函数会返回加载默认图的saver对象,saver对象初始化时可以指定变量映射方式,根据名字映射变量(『TensorFlow』滑动平均)
saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
saver.restore函数给出model.ckpt-n的路径后会自动寻找参数名-值文件进行加载
saver.restore(sess,'./model/model.ckpt-0')saver.restore(sess,ckpt.model_checkpoint_path)
1.不加载图结构,只加载参数
由于实际上我们参数保存的都是Variable变量的值,所以其他的参数值(例如batch_size)等,我们在restore时可能希望修改,但是图结构在train时一般就已经确定了,所以我们可以使用tf.Graph().as_default()新建一个默认图(建议使用上下文环境),利用这个新图修改和变量无关的参值大小,从而达到目的。
新闻热点
疑难解答