Version: v0.1.9

Zero Redundancy Optimizer and Zero Offload

Author: Zhujie, Shenggui Li, Hongxin Liu

Prerequisite:

Example Code

Related Paper

Introduction

The Zero Redundancy Optimizer (ZeRO) removes memory redundancy across data-parallel processes. Instead of replicating the three model states (optimizer states, gradients, and parameters), it partitions them. Compared with classic data parallelism, this improves memory efficiency drastically while keeping the same computational granularity and communication efficiency.

  1. Shard Optimizer States: The optimizer states (e.g., for the Adam optimizer, the 32-bit weights and the first and second moment estimates) are partitioned across the processes, so each process updates only its own partition.
  2. Shard Gradients: After reduction inside the data-parallel process group, gradient tensors are also partitioned, so each process stores only the gradients that correspond to its partition of the optimizer states. Note that Colossal-AI converts gradients to fp32 before they are used to update parameters.
  3. Shard Parameters: The 16-bit model parameters are partitioned across the processes of a data-parallel group.
  4. Gemini: A dynamic heterogeneous memory space manager for parameters, gradients, and optimizer states.
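To get a feel for the savings, the ZeRO paper estimates model-state memory for mixed-precision Adam at 2 bytes for fp16 parameters, 2 bytes for fp16 gradients, and K = 12 bytes for fp32 optimizer states, all per parameter. The sketch below applies that accounting when optimizer states, then gradients, then parameters are sharded over N data-parallel ranks. Treat it as a rough estimate only. It ignores activations, buffers, and fragmentation, and Colossal-AI's real footprint differs: it keeps fp32 gradients, and Gemini or CPU offload move states off the GPU.

```python
def model_state_bytes(num_params: int, dp_size: int, stage: int, k: int = 12) -> float:
    """Per-GPU bytes of model states, using the ZeRO paper's mixed-precision Adam accounting.

    stage 0: plain data parallelism (everything replicated)
    stage 1: optimizer states sharded
    stage 2: optimizer states + gradients sharded
    stage 3: optimizer states + gradients + parameters sharded
    """
    params = 2.0 * num_params      # fp16 parameters
    grads = 2.0 * num_params       # fp16 gradients
    optim = float(k) * num_params  # fp32 weights, momentum and variance
    if stage >= 1:
        optim /= dp_size
    if stage >= 2:
        grads /= dp_size
    if stage >= 3:
        params /= dp_size
    return params + grads + optim


# The paper's running example: 7.5B parameters on 64 data-parallel GPUs
for stage in range(4):
    gb = model_state_bytes(7_500_000_000, 64, stage) / 1e9
    print(f"stage {stage}: {gb:.1f} GB per GPU")  # 120.0, 31.4, 16.6, 1.9
```

Each extra level of sharding divides one more term by the data-parallel degree. Once everything is sharded, per-GPU model-state memory scales as 1/N.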

When parameters, gradients, and optimizer states are all sharded and the tensor placement policy is set to "cpu", the training process can be illustrated with three figures:

Figure: Forward
Figure: Backward
Figure: Optimizer step

For more details about Gemini, see the Gemini documentation.

Usage

We provide two levels of API for using ZeRO:

  1. Low-level API: Use ShardedModelV2 and ShardedOptimizerV2 directly, and write your own training loop from scratch.
  2. High-level API: Use Engine and configure ZeRO in the configuration file. You can use Trainer or write your own training loop.

We provide the following shard strategies to control how your model is sharded:

```python
from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
```

TensorShardStrategy is a naive implementation that shards each tensor evenly over all ranks. BucketTensorShardStrategy first flattens all tensors belonging to an operator (e.g., the weight and bias of nn.Linear) into one buffer, then shards that buffer evenly over all ranks. This helps most for operators with a bias. A bias tensor is usually small, so gathering it on its own makes poor use of network bandwidth.
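The difference between the two strategies can be sketched with plain Python lists standing in for tensors. This is an illustration of the idea, not Colossal-AI's implementation. The functions are hypothetical, and zero-padding so that every rank holds an equal-size shard is an assumption about how uneven sizes are handled.

```python
import math
from typing import List, Sequence


def shard_evenly(flat: Sequence[float], world_size: int, rank: int) -> List[float]:
    """Return this rank's chunk of `flat`, zero-padded so every rank holds the same length."""
    chunk = math.ceil(len(flat) / world_size)
    padded = list(flat) + [0.0] * (chunk * world_size - len(flat))
    return padded[rank * chunk:(rank + 1) * chunk]


def tensor_shard(tensors: Sequence[Sequence[float]], world_size: int, rank: int) -> List[List[float]]:
    # TensorShardStrategy-like (hypothetical): each tensor is sharded on its own,
    # so restoring the operator later needs one gather per tensor.
    return [shard_evenly(t, world_size, rank) for t in tensors]


def bucket_tensor_shard(tensors: Sequence[Sequence[float]], world_size: int, rank: int) -> List[float]:
    # BucketTensorShardStrategy-like (hypothetical): flatten all tensors of the operator
    # into one buffer first, so a single, larger gather restores them all.
    flat = [x for t in tensors for x in t]
    return shard_evenly(flat, world_size, rank)


# nn.Linear(4, 2): a 2x4 weight (8 values) plus a bias of 2 values, sharded over 4 ranks
weight = [float(i) for i in range(8)]
bias = [100.0, 101.0]
world_size = 4

per_tensor = [tensor_shard([weight, bias], world_size, r) for r in range(world_size)]
bucketed = [bucket_tensor_shard([weight, bias], world_size, r) for r in range(world_size)]
print(per_tensor[0], bucketed[0])  # [[0.0, 1.0], [100.0]] [0.0, 1.0, 2.0]
```

In the per-tensor case, restoring the layer takes two gathers, and one of them moves a single bias element per rank. The bucketed version needs one gather of three elements per rank, at the cost of some padding.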

⚠ī¸ You have to initialize your model with colossalai.zero.init_ctx.ZeroInitContext.

Here is a simple example:

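```python
import torch
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

# Assumes the distributed environment is already launched, e.g. via colossalai.launch_from_torch(config={})
shard_strategy = TensorShardStrategy()
with ZeroInitContext(target_device=torch.cuda.current_device(),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    model = torch.nn.Linear(2, 2)
```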

See the API Reference for the full usage of ZeroInitContext.

If you use the high-level API, you must configure shard_strategy in the config file.

Next, we first give a configuration template for setting up ZeRO with the high-level API, and then an example of using the low-level API.

We now also provide colossalai.nn.optimizer.HybridAdam, which is faster than torch.optim.Adam. For more details, see the API Reference.

Configure ZeRO with high-level API

You can use Engine and configure ZeRO in the configuration file.

Here is a configuration template:

```python
from colossalai.zero.shard_utils import TensorShardStrategy

zero = dict(
    model_config=dict(
        shard_strategy=TensorShardStrategy(),
        reduce_scatter_bucket_size_mb=25,
        fp32_reduce_scatter=False,
        tensor_placement_policy="cuda",
        gradient_predivide_factor=1.0,
        reuse_fp16_shard=False
    ),
    optimizer_config=dict(
        gpu_margin_mem_ratio=0.8,
        initial_scale=2**5,
        min_scale=1,
        growth_factor=2,
        backoff_factor=0.5,
        growth_interval=1000,
        hysteresis=2,
        max_scale=2**32
    )
)
```

model_config and optimizer_config are the keyword arguments of ShardedModelV2 and ShardedOptimizerV2, respectively. For details of these arguments, see the ShardedModelV2 API Reference and the ShardedOptimizerV2 API Reference.

⚠ī¸ If you use gradient accumulation, make sure reuse_fp16_shard is False.

⚠ī¸ If you set tensor_placement_policy to "auto", make sure no other processes use CUDA during your training.

You can initialize your model in this way:

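```python
import torch
import colossalai
from colossalai.core import global_context as gpc
from colossalai.zero.init_ctx import ZeroInitContext

# Assumes colossalai has already been launched with a config file that defines `zero`
with ZeroInitContext(target_device=torch.cuda.current_device(),
                     shard_strategy=gpc.config.zero.model_config.shard_strategy,
                     shard_param=True):
    model = torch.nn.Linear(2, 2)
```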

Then you can use Engine as usual.

The complete example of training GPT with the high-level API can be found in the GPT example.

Train GPT with low-level API

In this example, we use Hugging Face Transformers. You have to install transformers before running this example. We will take GPT2 Medium as an example here.

This example is intended for showing you how to use ZeRO. For simplicity, we just use randomly generated data here.

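First, import the required libraries. The training loop below also logs memory usage with a helper, get_mem_info, which this page does not define. The minimal version included here reports allocated GPU memory and the process's resident CPU memory via psutil. It is an assumption, so install psutil or swap in your own helper.

```python
import colossalai
import psutil
import torch
import torch.nn as nn
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from transformers import GPT2Config, GPT2LMHeadModel


def get_mem_info(prefix=''):
    # Hypothetical helper: GPU memory allocated by PyTorch and CPU resident memory, in MB
    gpu_mem = torch.cuda.memory_allocated() / 1024**2
    cpu_mem = psutil.Process().memory_info().rss / 1024**2
    return f'{prefix}GPU memory usage: {gpu_mem:.2f} MB, CPU memory usage: {cpu_mem:.2f} MB'
```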

Then we wrap the Hugging Face GPT-2 model in a thin module:

```python
class GPTLMModel(nn.Module):
    def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257, checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = GPT2LMHeadModel(GPT2Config(n_embd=hidden_size, n_layer=num_layers,
                                                n_head=num_attention_heads, n_positions=max_seq_len,
                                                n_ctx=max_seq_len, vocab_size=vocab_size))
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits; the KV cache is disabled when gradient checkpointing is on
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]


def gpt2_medium(checkpoint=False):
    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
```
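A quick parameter count shows what ZeRO has to shard here. The function below is a hypothetical helper. It assumes the standard Hugging Face GPT-2 layout, with biases on every linear layer and the LM head weight tied to the token embedding (the GPT2Config default). num_attention_heads does not change the count because it only splits the hidden dimension. At the ZeRO paper's 16 bytes per parameter for mixed-precision Adam, GPT-2 Medium's model states come to about 5.7 GB before sharding, not counting activations.

```python
def gpt2_num_params(hidden_size, num_layers, vocab_size=50257, max_seq_len=1024):
    """Parameter count of a GPT-2 LM whose output head is tied to the token embedding."""
    h = hidden_size
    embeddings = vocab_size * h + max_seq_len * h  # token + position embeddings
    per_layer = (
        (3 * h * h + 3 * h)      # fused QKV projection (weight + bias)
        + (h * h + h)            # attention output projection
        + (h * 4 * h + 4 * h)    # MLP up-projection
        + (4 * h * h + h)        # MLP down-projection
        + 2 * (2 * h)            # two LayerNorms (weight + bias each)
    )
    final_norm = 2 * h
    return embeddings + num_layers * per_layer + final_norm


small = gpt2_num_params(hidden_size=768, num_layers=12)    # GPTLMModel defaults
medium = gpt2_num_params(hidden_size=1024, num_layers=24)  # gpt2_medium()
print(f"{small:,} / {medium:,}")  # 124,439,808 / 354,823,168
```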

Define our loss function:

```python
class GPTLMLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        # Position t predicts token t + 1: drop the last logit and the first label
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
```

Since this example pre-trains GPT, a simple causal language-modeling loss is enough: the logits at each position are scored against the next token of the input.
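The shift in GPTLMLoss is easiest to see on a toy sequence. Below is a plain-Python version of the same next-token cross-entropy, for illustration only. The real loss runs on batched tensors through nn.CrossEntropyLoss, and causal_lm_loss is a hypothetical name.

```python
import math


def causal_lm_loss(logits, labels):
    """Mean next-token cross-entropy.

    logits: seq_len rows of vocab_size scores; labels: seq_len token ids.
    Mirrors GPTLMLoss: drop the last logit row and the first label, so row t scores token t + 1.
    """
    shift_logits = logits[:-1]
    shift_labels = labels[1:]
    total = 0.0
    for row, target in zip(shift_logits, shift_labels):
        log_z = math.log(sum(math.exp(x) for x in row))  # log-sum-exp normalizer
        total += log_z - row[target]                      # -log softmax(row)[target]
    return total / len(shift_labels)


# With all-zero (uniform) logits over a 4-token vocabulary, every prediction costs log(4)
tokens = [0, 3, 1, 2]
uniform = [[0.0] * 4 for _ in tokens]
print(causal_lm_loss(uniform, tokens), math.log(4))
```

By the same reasoning, a freshly initialized model with near-uniform predictions over GPT-2's 50257-token vocabulary starts near a loss of log(50257), about 10.8.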

Write a function to get random inputs:

```python
def get_data(batch_size, seq_len, vocab_size):
    # Random token ids on the current GPU; every position is attended to
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask
```

Finally, we can define our training loop:

```python
def main():
    BATCH_SIZE = 8
    SEQ_LEN = 1024
    VOCAB_SIZE = 50257
    NUM_STEPS = 10
    disable_existing_loggers()
    colossalai.launch_from_torch(config={})
    logger = get_dist_logger()

    logger.info(get_mem_info(), ranks=[0])
    # build GPT model
    shard_strategy = TensorShardStrategy()
    with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, shard_param=True):
        model = gpt2_medium(checkpoint=True)
    # Set tensor_placement_policy='cpu' to offload params, grads and optimizer states to CPU memory
    model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True)
    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

    # build criterion
    criterion = GPTLMLoss()

    # build optimizer
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    optimizer = ShardedOptimizerV2(model, optimizer, initial_scale=2**5)
    logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])

    model.train()
    for n in range(NUM_STEPS):
        # we just use randomly generated data here
        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()
        outputs = model(input_ids, attn_mask)
        loss = criterion(outputs, input_ids)
        logger.info(get_mem_info(prefix=f'Forward [{n+1}/{NUM_STEPS}] '), ranks=[0])
        # call optimizer.backward(loss) rather than loss.backward(): the sharded optimizer scales the loss
        optimizer.backward(loss)
        logger.info(get_mem_info(prefix=f'Backward [{n+1}/{NUM_STEPS}] '), ranks=[0])
        optimizer.step()
        logger.info(get_mem_info(prefix=f'Optimizer step [{n+1}/{NUM_STEPS}] '), ranks=[0])


if __name__ == '__main__':
    main()
```

The complete example can be found in the ZeRO example.