跳到主要内容

懒惰初始化

作者: Hongxin Liu

前置教程:

简介

懒惰初始化延迟了模型的初始化。它能够节省在大模型初始化时的内存占用。

如果你的模型有 N 十亿个参数并且你的内存(或显存)为 M GB, 我们推荐您在 4N >= M 时使用懒惰初始化。否则,懒惰初始化不是必须的。

使用

懒惰初始化必须与 booster 一起使用。

API 参考

class
 

colossalai.lazy.LazyInitContext

(tensor_cls: typing.Union[colossalai.lazy.lazy_init._MyTensor, colossalai.lazy.lazy_init.LazyTensor] = <class 'colossalai.lazy.lazy_init.LazyTensor'>, default_device: typing.Union[str, torch.device, int, NoneType] = None)
Parameters
  • tensor_cls (Union[_MyTensor, LazyTensor], optional) -- This is only for test. Defaults to LazyTensor.
  • default_device (Optional[Union[torch.device, str, int]], optional) -- Defalt device for initialization. If it's cuda, initilization will be accelerated, but cuda memory will be allocated. By default, it's cpu. Defaults to None.
Description
Context manager for lazy initialization. Enables initializing the model without allocating real memory.
function
 

materialize

(module: Module, verbose: bool = False)
Parameters
  • module (nn.Module) -- Target nn.Module
  • verbose (bool) -- Whether to print lazy initialization rate. Defaults to False.
Description
Initialize all `Parameter` from `LazyTensor`. This function will modify the module in-place.

例子

import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

from transformers import LlamaForCausalLM, LlamaConfig, BertForPreTraining

colossalai.launch({})
plugin = GeminiPlugin()
booster = Booster(plugin)

# 1. Initialize model from scratch
# Initialization on cuda will accelerate the initialization process but take more GPU memory.
with LazyInitContext(default_device="cuda"):
model = LlamaForCausalLM(LlamaConfig(hidden_size=64, intermediate_size=172, num_hidden_layers=4, num_attention_heads=4))
model, *_ = booster.boost(model)

# 2. Initialize model from pretrained
with LazyInitContext():
model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny")
model, *_ = booster.boost(model)

⚠️ 使用懒惰初始化加载预训练模型在 colossalai>0.3.3 或主分支上支持。

限制

我们提到,懒惰初始化必须与 booster 一起使用。只有几个插件支持它。

插件支持情况备注
Gemini
Hybrid Parallel
Low Level Zero不需要
Torch DDP不兼容
Torch FSDP不兼容

不是所有的模型都可以懒惰初始化。在某些情况下,一部分参数/缓冲区可能会被提前初始化。但是不用担心,这部分通常只占整个模型的一小部分。

并且一些模型完全不支持,会引发错误。我们测试了 torchvision, diffusers, timm, transformers, torchaudio 和 torchrec 中的模型。以下模型不受支持:

模型分类
wav2vec2_basetorchaudio
hubert_basetorchaudio
ViTModeltransformers
ViTForMaskedImageModelingtransformers
ViTForImageClassificationtransformers
Blip2Modeltransformers
Blip2ForConditionalGenerationtransformers