Version: v0.1.9

Model Checkpoint

Author: Guangyang Lu

Prerequisite:

Example Code:

This function is experimental.

Introduction

In this tutorial, you will learn how to save and load model checkpoints.

To leverage the parallel strategies in Colossal-AI, models and tensors have to be modified, so you cannot save or load model checkpoints directly with torch.save or torch.load. Colossal-AI therefore provides its own API for saving and loading checkpoints.
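To see why a plain torch.save call may not be enough, consider tensor parallelism, where each rank can hold only a slice of a weight matrix. The sketch below is an illustration under that assumption, not Colossal-AI's implementation: it uses plain Python lists in place of tensors, and the helpers `shard_rows` and `gather_rows` are hypothetical. Saving from one rank alone would capture only its slice, so the slices have to be gathered into the full weight first.

```python
from typing import List

Matrix = List[List[float]]


def shard_rows(weight: Matrix, world_size: int, rank: int) -> Matrix:
    """Return the contiguous block of rows owned by `rank` (hypothetical 1D row split).

    Assumes the number of rows divides evenly by world_size.
    """
    rows_per_rank = len(weight) // world_size
    start = rank * rows_per_rank
    return weight[start:start + rows_per_rank]


def gather_rows(shards: List[Matrix]) -> Matrix:
    """Concatenate per-rank row blocks back into the full weight, in rank order."""
    full: Matrix = []
    for shard in shards:
        full.extend(shard)
    return full


# A 4x2 weight split across 2 ranks: each rank holds only 2 rows.
weight = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
shards = [shard_rows(weight, world_size=2, rank=r) for r in range(2)]

# Saving shards[0] alone would lose half of the parameters;
# gathering restores the complete weight a checkpoint should contain.
full_weight = gather_rows(shards)
```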

Moreover, when loading a checkpoint, you are not required to use the same parallel strategy that was used to save it.
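One way such flexibility can work is to store each parameter in its full, unpartitioned form and re-split it at load time for whatever parallel size is in use. The following is a minimal sketch under that assumption, with a simple row-wise split and a hypothetical helper `load_for_world_size`; it uses a dict of plain lists rather than real tensors and does not describe Colossal-AI internals.

```python
from typing import Dict, List

Matrix = List[List[float]]


def load_for_world_size(full_state: Dict[str, Matrix], world_size: int, rank: int) -> Dict[str, Matrix]:
    """Slice every full parameter into the row block this rank owns under `world_size`."""
    local_state: Dict[str, Matrix] = {}
    for name, weight in full_state.items():
        # Assumes the row count divides evenly by world_size.
        rows_per_rank = len(weight) // world_size
        start = rank * rows_per_rank
        local_state[name] = weight[start:start + rows_per_rank]
    return local_state


# A checkpoint saved from, say, a 2-way run, stored unpartitioned.
full_state = {"linear.weight": [[1.0], [2.0], [3.0], [4.0]]}

# The same data can be split for a 4-way run ...
four_way = [load_for_world_size(full_state, world_size=4, rank=r) for r in range(4)]
# ... or used by a single process with no parallelism at all.
single = load_for_world_size(full_state, world_size=1, rank=0)
```

The design choice here is that the checkpoint format is independent of the parallel layout, and the layout is applied only when the weights are loaded.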

How to use

Save

There are two ways to train a model in Colossal-AI: with an engine or with a trainer. Be aware that only the state_dict is saved, so the model must be defined before a checkpoint can be loaded into it.
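Because only parameter values are stored, a checkpoint cannot rebuild the model on its own: the code that constructs the architecture has to run first, and the saved values are then copied into it by name. The sketch below mimics PyTorch's state_dict/load_state_dict pattern with a hypothetical `TinyModel` class written in plain Python, so it runs without PyTorch or Colossal-AI.

```python
import pickle
from typing import Dict, List


class TinyModel:
    """Hypothetical stand-in for an nn.Module with two named parameters."""

    def __init__(self) -> None:
        self.params: Dict[str, List[float]] = {"weight": [0.0, 0.0], "bias": [0.0]}

    def state_dict(self) -> Dict[str, List[float]]:
        # Only the parameter values are exported, not the class or its code.
        return {name: list(values) for name, values in self.params.items()}

    def load_state_dict(self, state: Dict[str, List[float]]) -> None:
        # Keys must match the parameters the freshly built model already has.
        if state.keys() != self.params.keys():
            raise KeyError(f"mismatched keys: {sorted(state.keys() ^ self.params.keys())}")
        for name, values in state.items():
            self.params[name] = list(values)


trained = TinyModel()
trained.params["weight"] = [0.5, -1.5]
blob = pickle.dumps(trained.state_dict())  # what would end up on disk

restored = TinyModel()  # the model must be constructed first ...
restored.load_state_dict(pickle.loads(blob))  # ... then filled with the saved values
```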

Save when using engine

import colossalai
from colossalai.utils import save_checkpoint

model = ...
engine, _, _, _ = colossalai.initialize(model=model, ...)
for epoch in range(num_epochs):
    ...  # do some training
    # save the model's state at the end of every epoch
    save_checkpoint('xxx.pt', epoch, model)
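The loop above passes the same path on every epoch. If the save function simply writes to the path it is given, as this sketch assumes, each call overwrites the previous checkpoint and only the latest one survives. The following stdlib-only sketch uses a hypothetical `save_epoch_checkpoint` helper (not the Colossal-AI function) that stores the epoch next to the weights, and builds a per-epoch file name so every snapshot is kept.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


def save_epoch_checkpoint(path: str, epoch: int, model_state: Dict[str, Any]) -> None:
    """Hypothetical stand-in for save_checkpoint: store the epoch next to the weights."""
    with open(path, "wb") as f:
        pickle.dump({"epoch": epoch, "model": model_state}, f)


ckpt_dir = tempfile.mkdtemp()
num_epochs = 3
state = {"weight": [0.0]}
for epoch in range(num_epochs):
    state["weight"][0] += 1.0  # placeholder for "do some training"
    # A per-epoch file name keeps every snapshot instead of overwriting one file.
    save_epoch_checkpoint(os.path.join(ckpt_dir, f"epoch_{epoch}.pt"), epoch, state)

saved = sorted(os.listdir(ckpt_dir))
```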

Save when using trainer

import colossalai
from colossalai.trainer import Trainer, hooks

model = ...
engine, _, _, _ = colossalai.initialize(model=model, ...)
trainer = Trainer(engine, ...)
hook_list = [
    # save a checkpoint of the model to 'xxx.pt' with an interval of 1
    hooks.SaveCheckpointHook(1, 'xxx.pt', model),
    ...
]

trainer.fit(
    ...,
    hooks=hook_list,
)
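A trainer hook such as SaveCheckpointHook is a callback that the trainer invokes at fixed points in its loop. The sketch below is a simplified model of that pattern in plain Python, assuming the first argument (1 above) is a save interval counted in epochs. The `IntervalSaveHook` and `MiniTrainer` classes and the `after_train_epoch` method name are hypothetical and chosen for illustration; they are not presented as Colossal-AI's actual hook interface.

```python
from typing import List


class IntervalSaveHook:
    """Hypothetical hook that records a save every `interval` epochs."""

    def __init__(self, interval: int) -> None:
        self.interval = interval
        self.saved_epochs: List[int] = []

    def after_train_epoch(self, epoch: int) -> None:
        # epoch is 0-based here, so epoch + 1 counts completed epochs.
        if (epoch + 1) % self.interval == 0:
            self.saved_epochs.append(epoch)  # a real hook would write the checkpoint here


class MiniTrainer:
    """Hypothetical trainer loop that calls every hook after each epoch."""

    def fit(self, epochs: int, hooks: List[IntervalSaveHook]) -> None:
        for epoch in range(epochs):
            # ... one epoch of training would run here ...
            for hook in hooks:
                hook.after_train_epoch(epoch)


every_epoch = IntervalSaveHook(interval=1)
every_other = IntervalSaveHook(interval=2)
MiniTrainer().fit(epochs=4, hooks=[every_epoch, every_other])
```

Keeping the saving logic in a hook means the training loop itself never changes; adding or removing checkpointing is just a matter of editing the hook list.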

Load

from colossalai.utils import load_checkpoint

model = ...  # build the same model architecture whose state_dict was saved
load_checkpoint('xxx.pt', model)
...  # train or test
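When loading is used to resume training rather than to test, the epoch stored alongside the weights tells the loop where to restart. This is a minimal sketch with hypothetical `save_ckpt` and `load_ckpt` helpers built on pickle; it assumes the checkpoint holds an epoch number and a state dict (the save call earlier passes both), and it makes no claim about what colossalai.utils.load_checkpoint itself returns.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, Tuple


def save_ckpt(path: str, epoch: int, state: Dict[str, Any]) -> None:
    """Hypothetical saver: bundle the epoch with the model state."""
    with open(path, "wb") as f:
        pickle.dump({"epoch": epoch, "model": state}, f)


def load_ckpt(path: str) -> Tuple[int, Dict[str, Any]]:
    """Hypothetical loader: return the saved epoch and model state."""
    with open(path, "rb") as f:
        ckpt = pickle.load(f)
    return ckpt["epoch"], ckpt["model"]


path = os.path.join(tempfile.mkdtemp(), "xxx.pt")
save_ckpt(path, epoch=4, state={"weight": [0.25]})

last_epoch, state = load_ckpt(path)
start_epoch = last_epoch + 1  # resume with the epoch after the one that was saved
remaining = list(range(start_epoch, 10))
```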