首页 > 编程 > Python > 正文

tensorflow2.0保存和恢复模型3种方法

2020-02-15 21:26:34
字体:
来源:转载
供稿:网友

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights( filepath, overwrite=True, save_format=None)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights( filepath, by_name=False)

实例1:

import tensorflow as tffrom tensorflow import kerasfrom tensorflow.keras import datasets, layers, optimizers # step1 加载训练集和测试集合mnist = tf.keras.datasets.mnist(x_train, y_train),(x_test, y_test) = mnist.load_data()x_train, x_test = x_train / 255.0, x_test / 255.0  # step2 创建模型def create_model(): return tf.keras.models.Sequential([ tf.keras.layers.Flatten(input_shape=(28, 28)), tf.keras.layers.Dense(512, activation='relu'), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(10, activation='softmax') ])model = create_model() # step3 编译模型 主要是确定优化方法,损失函数等model.compile(optimizer='adam',  loss='sparse_categorical_crossentropy',  metrics=['accuracy']) # step4 模型训练 训练一个epochsmodel.fit(x=x_train,  y=y_train,  epochs=1,  ) # step5 模型测试loss, acc = model.evaluate(x_test, y_test)print("train model, accuracy:{:5.2f}%".format(100 * acc)) # step6 保存模型的权重和偏置model.save_weights('./save_weights/my_save_weights') # step7 删除模型del model # step8 重新创建模型model = create_model()model.compile(optimizer='adam',  loss='sparse_categorical_crossentropy',  metrics=['accuracy']) # step9 恢复权重model.load_weights('./save_weights/my_save_weights') # step10 测试模型loss, acc = model.evaluate(x_test, y_test)print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表