
Booster API

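Authors: Mingyan Jiang, Jianghai Chen, Baizhou Zhang

Prerequisites:

Example code

Introduction

In our new design, colossalai.booster replaces colossalai.initialize and seamlessly injects features into your training components (e.g., model, optimizer, dataloader). With the booster API, you can integrate our parallelism strategies into the model you want to train more easily. Calling colossalai.booster is the standard step before entering the training loop. The following sections explain how colossalai.booster works and the details to keep in mind when using it.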

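Booster Plugins

Booster plugins are the key component for managing parallel configurations (e.g., the gemini plugin encapsulates the Gemini acceleration scheme). The currently supported plugins are:

HybridParallelPlugin: encapsulates the hybrid parallel acceleration solution. Its interface allows any combination of tensor parallelism, pipeline parallelism, and two data parallelism methods (DDP and ZeRO).

GeminiPlugin: encapsulates the Gemini acceleration solution, i.e., ZeRO optimization with chunk-based memory management.

TorchDDPPlugin: encapsulates PyTorch's DDP acceleration scheme, which implements model-level data parallelism and can run across multiple machines.

LowLevelZeroPlugin: encapsulates stages 1 and 2 of the Zero Redundancy Optimizer (ZeRO). Stage 1: partitions the optimizer states across the parallel processes or GPUs. Stage 2: partitions both the optimizer states and the gradients across the parallel processes or GPUs.

TorchFSDPPlugin: encapsulates PyTorch's FSDP acceleration scheme and can be used for Zero Redundancy data parallel (ZeroDP) training.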
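To make the ZeRO stage 1/2 partitioning more concrete, here is a simplified pure-Python sketch of how per-parameter optimizer state could be split across data-parallel ranks. It is not ColossalAI's implementation: the helpers `partition_params` and `adam_state_elements` are hypothetical, and real ZeRO implementations typically slice flattened parameter buffers evenly instead of assigning whole parameters to ranks.

```python
# Hypothetical illustration of ZeRO-style state partitioning; not ColossalAI's code.
from typing import Dict, List


def partition_params(param_sizes: Dict[str, int], world_size: int) -> List[List[str]]:
    """Assign parameters to ranks so each rank owns roughly 1/world_size of the elements.

    In stage 1 a rank keeps optimizer states only for the parameters it owns;
    in stage 2 it additionally keeps only the reduced gradients for them.
    """
    shards: List[List[str]] = [[] for _ in range(world_size)]
    load = [0] * world_size
    # Place the largest parameters first to keep the shards balanced.
    for name, size in sorted(param_sizes.items(), key=lambda kv: -kv[1]):
        rank = load.index(min(load))  # least-loaded rank so far
        shards[rank].append(name)
        load[rank] += size
    return shards


def adam_state_elements(param_sizes: Dict[str, int], shard: List[str]) -> int:
    # Adam keeps two moments (exp_avg, exp_avg_sq) per parameter element.
    return 2 * sum(param_sizes[name] for name in shard)


sizes = {"embed": 4000, "layer1": 1000, "layer2": 1000, "head": 2000}
shards = partition_params(sizes, world_size=2)
print(shards)  # [['embed'], ['head', 'layer1', 'layer2']]
print([adam_state_elements(sizes, s) for s in shards])  # [8000, 8000]
```

Without partitioning, every rank would hold all 16000 moment elements; with two ranks each holds 8000.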

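For more details on how to use the plugins, please refer to the Booster Plugins section.

Some plugins support lazy initialization, which reduces memory usage when initializing large models. For details, please refer to Lazy Initialization.

Booster Interface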

class colossalai.booster.Booster(device: typing.Optional[str] = None, mixed_precision: typing.Union[colossalai.booster.mixed_precision.mixed_precision_base.MixedPrecision, str, NoneType] = None, plugin: typing.Optional[colossalai.booster.plugin.plugin_base.Plugin] = None)
Parameters
  • device (str or torch.device) -- The device to run the training on. Default: None. If no plugin is used or the plugin doesn't control the device, this argument is used as the training device ('cuda' is used if the argument is None).
  • mixed_precision (str or MixedPrecision) -- The mixed precision to run the training with. Default: None. If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'. 'fp16' uses PyTorch AMP, while 'fp16_apex' uses NVIDIA Apex.
  • plugin (Plugin) -- The plugin to run the training with. Default: None.
Description

Booster is a high-level API for training neural networks. It provides a unified interface for training with different precisions, accelerators, and plugins.

Example
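# Following is pseudocode: GPT2, train_dataset, LinearWarmupScheduler, GPTLMLoss
# and max_epochs are placeholders for your own model, data, scheduler, loss and settings.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch(...)
plugin = GeminiPlugin(...)
booster = Booster(mixed_precision='fp16', plugin=plugin)

model = GPT2()
optimizer = HybridAdam(model.parameters())
dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
lr_scheduler = LinearWarmupScheduler(optimizer, ...)
criterion = GPTLMLoss()

model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)

for epoch in range(max_epochs):
    for input_ids, attention_mask in dataloader:
        outputs = model(input_ids.cuda(), attention_mask.cuda())
        loss = criterion(outputs.logits, input_ids)
        # backward must go through the booster, not loss.backward()
        booster.backward(loss, optimizer)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()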
function backward(loss: Tensor, optimizer: Optimizer)
Parameters
  • loss (torch.Tensor) -- The loss to backpropagate.
  • optimizer (Optimizer) -- The optimizer to be updated.
Description
Executes the backward pass during a training step.
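One reason backward goes through booster.backward(loss, optimizer) rather than loss.backward() is that the boosted optimizer may need to intervene; for example, with fp16 mixed precision the loss is typically scaled before backpropagation. The pure-Python sketch below illustrates dynamic loss scaling in general. It is not ColossalAI's implementation, and the class name `DynamicLossScaler` and its hyperparameters are hypothetical.

```python
# Hypothetical sketch of dynamic loss scaling; not ColossalAI's code.
import math
from typing import List, Optional


class DynamicLossScaler:
    def __init__(self, init_scale: float = 2.0 ** 16, growth_interval: int = 2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self._good_steps = 0

    def scale_loss(self, loss: float) -> float:
        # Scale the loss up so that small fp16 gradients do not underflow to zero.
        return loss * self.scale

    def unscale_and_check(self, scaled_grads: List[float]) -> Optional[List[float]]:
        """Return the unscaled gradients, or None if the step must be skipped."""
        if not all(math.isfinite(g) for g in scaled_grads):
            # Overflow (inf/nan): skip the optimizer step and halve the scale.
            self.scale /= 2.0
            self._good_steps = 0
            return None
        grads = [g / self.scale for g in scaled_grads]  # unscale with the current scale
        self._good_steps += 1
        if self._good_steps % self.growth_interval == 0:
            # After a run of overflow-free steps, try a larger scale again.
            self.scale *= 2.0
        return grads


scaler = DynamicLossScaler(init_scale=1024.0, growth_interval=2)
print(scaler.unscale_and_check([512.0, -256.0]))  # [0.5, -0.25]
print(scaler.unscale_and_check([float("inf"), 1.0]), scaler.scale)  # None 512.0
```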
function boost(model: Module, optimizer: typing.Optional[torch.optim.optimizer.Optimizer] = None, criterion: typing.Optional[typing.Callable] = None, dataloader: typing.Optional[torch.utils.data.dataloader.DataLoader] = None, lr_scheduler: typing.Optional[torch.optim.lr_scheduler._LRScheduler] = None)
Parameters
  • model (nn.Module) -- The model to convert into a wrapped model for distributed training. After this method runs, the model may be decorated or partitioned according to the plugin's strategy.
  • optimizer (Optimizer, optional) -- The optimizer to convert into a wrapped optimizer for distributed training. After this method runs, the optimizer's param groups or states may be decorated or partitioned according to the plugin's strategy. Defaults to None.
  • criterion (Callable, optional) -- The function that calculates the loss. Defaults to None.
  • dataloader (DataLoader, optional) -- The prepared dataloader for training. Defaults to None.
  • lr_scheduler (LRScheduler, optional) -- The learning rate scheduler for training. Defaults to None.
Returns

List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: The list of boosted input arguments, in the order model, optimizer, criterion, dataloader, lr_scheduler.

Description

Wraps the passed-in model, optimizer, criterion, lr_scheduler, and dataloader and injects features into them.

function execute_pipeline(data_iter: typing.Iterator, model: Module, criterion: typing.Callable[[typing.Any, typing.Any], torch.Tensor], optimizer: typing.Optional[torch.optim.optimizer.Optimizer] = None, return_loss: bool = True, return_outputs: bool = False)
Parameters

  • data_iter (Iterator) -- The iterator for getting the next batch of data. There are usually two ways to obtain this argument:
      1. Wrap the dataloader into an iterator: iter(dataloader)
      2. Get the next batch from the dataloader and wrap this batch into an iterator: iter([batch])
  • model (nn.Module) -- The model to execute forward/backward. It should be a model wrapped by a plugin that supports pipeline parallelism.
  • criterion (Callable[[Any, Any], torch.Tensor]) -- The criterion to be used. It should take two arguments, the model outputs and the inputs, and return a loss tensor. 'lambda y, x: loss_fn(y)' turns a normal loss function into a valid two-argument criterion here.
  • optimizer (Optimizer, optional) -- The optimizer used for the backward pass. Can be None when only doing a forward pass (i.e. evaluation). Defaults to None.
  • return_loss (bool, optional) -- Whether to return the loss in the dict returned by this method. Defaults to True.
  • return_outputs (bool, optional) -- Whether to return Hugging Face style model outputs in the dict returned by this method. Defaults to False.
Returns

Dict[str, Any]: Output dict in the form of {'loss': ..., 'outputs': ...}. ret_dict['loss'] is the forward loss if return_loss is True, otherwise None. ret_dict['outputs'] is the Hugging Face style model outputs from the forward pass if return_outputs is True, otherwise None.

Description

Executes the forward and backward passes when using pipeline parallelism. Returns the loss or Hugging Face style model outputs if needed.

Warning: This function is tailored for pipeline parallelism. When training with pipeline parallelism through booster, do not run the forward/backward pass in the conventional way (model(input)/loss.backward()), as this will cause unexpected errors.
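The two ways of building data_iter and the criterion adapter described above can be illustrated without a GPU. In this sketch plain lists stand in for a DataLoader and for tensors, and `loss_fn` is a hypothetical one-argument loss; in real code the batch would come from a DataLoader and be passed to booster.execute_pipeline.

```python
# Pure-Python stand-ins: a list plays the DataLoader, lists play tensors.
from typing import Any, Callable, Dict, Iterator, List

dataloader: List[Dict[str, List[float]]] = [{"x": [1.0, 2.0]}, {"x": [3.0, 4.0]}]

# Way 1: wrap the whole dataloader; each pipeline step pulls the next batch.
data_iter: Iterator = iter(dataloader)

# Way 2: fetch one batch yourself and wrap just that batch.
batch = next(iter(dataloader))
single_batch_iter: Iterator = iter([batch])


def loss_fn(outputs: List[float]) -> float:
    # Hypothetical loss that only looks at the model outputs (a mean here).
    return sum(outputs) / len(outputs)


# execute_pipeline expects criterion(outputs, inputs); the lambda drops `inputs`.
criterion: Callable[[Any, Any], float] = lambda y, x: loss_fn(y)

print(next(data_iter))               # {'x': [1.0, 2.0]}
print(list(single_batch_iter))       # [{'x': [1.0, 2.0]}]
print(criterion([2.0, 4.0], batch))  # 3.0
```

Way 2 is handy when you already hold a batch (for example, to inspect it) but still need to hand execute_pipeline an iterator.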

function load_lr_scheduler(lr_scheduler: _LRScheduler, checkpoint: str)
Parameters
  • lr_scheduler (LRScheduler) -- An lr scheduler boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local file path.
Description
Load lr scheduler from checkpoint.
function load_model(model: typing.Union[torch.nn.modules.module.Module, colossalai.interface.model.ModelWrapper], checkpoint: str, strict: bool = True)
Parameters
  • model (nn.Module or ModelWrapper) -- A model boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
  • strict (bool, optional) -- Whether to strictly enforce that the keys in state_dict match the keys returned by this module's torch.nn.Module.state_dict() function. Defaults to True.
Description
Load model from checkpoint.
function load_optimizer(optimizer: Optimizer, checkpoint: str)
Parameters
  • optimizer (Optimizer) -- An optimizer boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
Description
Load optimizer from checkpoint.
function no_sync(model: Module = None, optimizer: OptimizerWrapper = None)
Parameters
  • model (nn.Module) -- The model for which gradient synchronization is disabled (for DDP).
  • optimizer (OptimizerWrapper) -- The optimizer for which gradient synchronization is disabled (for ZeRO-1).
Returns

contextmanager: Context in which gradient synchronization is disabled.

Description
Context manager that disables gradient synchronization across DP process groups. Currently supports torch DDP and Low Level ZeRO-1.
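A common use of disabling gradient synchronization is gradient accumulation: gradients are all-reduced only on the last micro-step of each accumulation window. The toy model below is a hypothetical, pure-Python illustration of that pattern (the class `ToyDDPModel` and function `run` are made up for this sketch and are not ColossalAI's implementation); it simply counts how many synchronizations happen.

```python
# Toy model (hypothetical, not ColossalAI's code): count how many gradient
# all-reduces happen when no_sync() wraps all but the last micro-step.
from contextlib import contextmanager, nullcontext


class ToyDDPModel:
    def __init__(self) -> None:
        self.require_sync = True
        self.num_allreduce = 0

    @contextmanager
    def no_sync(self):
        # Temporarily disable gradient synchronization, then restore it.
        previous = self.require_sync
        self.require_sync = False
        try:
            yield
        finally:
            self.require_sync = previous

    def backward(self) -> None:
        # Local gradients always accumulate; communication only happens when syncing.
        if self.require_sync:
            self.num_allreduce += 1


def run(model: ToyDDPModel, num_micro_steps: int, accum_steps: int) -> None:
    for step in range(num_micro_steps):
        is_last = (step + 1) % accum_steps == 0
        # Skip synchronization on every micro-step except the last of each window.
        with (nullcontext() if is_last else model.no_sync()):
            model.backward()
        # optimizer.step() and zero_grad() would run here when is_last is True.


model = ToyDDPModel()
run(model, num_micro_steps=8, accum_steps=4)
print(model.num_allreduce)  # 2
```

With the Booster API, the non-final micro-steps would instead wrap booster.backward(loss, optimizer) in booster.no_sync(model, optimizer), subject to the plugin support listed above.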
function save_lr_scheduler(lr_scheduler: _LRScheduler, checkpoint: str)
Parameters
  • lr_scheduler (LRScheduler) -- An lr scheduler boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local file path.
Description
Save lr scheduler to checkpoint.
function save_model(model: typing.Union[torch.nn.modules.module.Module, colossalai.interface.model.ModelWrapper], checkpoint: str, shard: bool = False, gather_dtensor: bool = True, prefix: typing.Optional[str] = None, size_per_shard: int = 1024, use_safetensors: bool = False)
Parameters
  • model (nn.Module or ModelWrapper) -- A model boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local path. It is a file path if shard=False. Otherwise, it is a directory path.
  • shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If True, the checkpoint will be a folder with the same format as a Hugging Face transformers checkpoint. Otherwise, it will be a single file. Defaults to False.
  • gather_dtensor (bool, optional) -- Whether to gather the distributed tensors to the first device. Default: True.
  • prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
  • size_per_shard (int, optional) -- Maximum size of a checkpoint shard file in MB. Only used when shard=True. Defaults to 1024.
  • use_safetensors (bool, optional) -- Whether to use safetensors. Default: False. If set to True, the checkpoint will be saved in the safetensors format.
Description
Save model to checkpoint.
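When shard=True, size_per_shard caps the size of each shard file. A simplified, greedy way to plan such shards is sketched below in pure Python: it only decides which state_dict keys go into which file and builds an index mapping keys to files, similar in spirit to the weight_map of a Hugging Face sharded checkpoint index. The function `plan_shards` and the file-name pattern are hypothetical, and ColossalAI's actual checkpoint IO may group tensors differently.

```python
# Hypothetical sketch of size-bounded checkpoint sharding; not ColossalAI's code.
from typing import Dict, List, Tuple

MB = 1024 * 1024


def plan_shards(
    tensor_bytes: Dict[str, int], size_per_shard: int = 1024
) -> Tuple[List[List[str]], Dict[str, str]]:
    """Group state_dict keys into shards of at most size_per_shard MB."""
    limit = size_per_shard * MB
    shards: List[List[str]] = []
    current: List[str] = []
    current_size = 0
    for key, nbytes in tensor_bytes.items():  # preserve state_dict order
        # Close the current shard if this tensor would push it over the limit.
        # A single tensor larger than the limit still gets a shard of its own.
        if current and current_size + nbytes > limit:
            shards.append(current)
            current, current_size = [], 0
        current.append(key)
        current_size += nbytes
    if current:
        shards.append(current)
    # Index mapping each key to its shard file (file naming is an assumption).
    total = len(shards)
    weight_map = {
        key: f"model-{i + 1:05d}-of-{total:05d}.bin"
        for i, keys in enumerate(shards)
        for key in keys
    }
    return shards, weight_map


sizes = {"a": 6 * MB, "b": 5 * MB, "c": 3 * MB, "d": 12 * MB}
shards, weight_map = plan_shards(sizes, size_per_shard=10)
print(shards)           # [['a'], ['b', 'c'], ['d']]
print(weight_map["c"])  # model-00002-of-00003.bin
```

Note that tensor "d" (12 MB) exceeds the 10 MB limit on its own, so it ends up alone in an oversized shard rather than being split.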
function save_optimizer(optimizer: Optimizer, checkpoint: str, shard: bool = False, gather_dtensor: bool = True, prefix: typing.Optional[str] = None, size_per_shard: int = 1024)
Parameters
  • optimizer (Optimizer) -- An optimizer boosted by Booster.
  • checkpoint (str) -- Path to the checkpoint. It must be a local path. It is a file path if shard=False. Otherwise, it is a directory path.
  • shard (bool, optional) -- Whether to save the checkpoint in a sharded way. If True, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
  • gather_dtensor (bool) -- Whether to gather the distributed tensors to the first device. Default: True.
  • prefix (str, optional) -- A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None.
  • size_per_shard (int, optional) -- Maximum size of a checkpoint shard file in MB. Only used when shard=True. Defaults to 1024.
Description

Save optimizer to checkpoint.

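Usage and Examples

When training with colossalai, first launch the distributed environment at the beginning of the training script and create the objects you need, such as the model, optimizer, loss function, and dataloader. Then call booster.boost to inject features into these objects, after which you can use our booster API for the rest of your training process.

The following pseudocode example shows how to train a model with our booster API: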

import torch
from torch.optim import SGD
from torchvision.models import resnet18

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# train() is meant to run once per process; rank, world_size and port come from your launcher
def train(rank, world_size, port):
    # launch colossalai
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    # create plugin and objects for training
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)
    model = resnet18()
    criterion = lambda x: x.mean()
    optimizer = SGD(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    # use booster.boost to wrap the training objects
    model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)

    # do training as normal, except that the backward should be called by booster
    x = torch.randn(4, 3, 224, 224)
    x = x.to('cuda')
    output = model(x)
    loss = criterion(output)
    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

    # checkpointing using booster api
    save_path = "./model"
    booster.save_model(model, save_path, shard=True, size_per_shard=10, use_safetensors=True)

    # load_model expects a boosted model, so boost the new model before loading
    new_model = resnet18()
    new_model, _, _, _, _ = booster.boost(new_model)
    booster.load_model(new_model, save_path)

For more details on the design of Booster, please refer to this page.