Many modern LLMs don’t fit in the memory of a single GPU. For example, with each parameter stored in FP16 (2 bytes), a 100B-parameter LLM needs 200 GB of memory just to store its model parameters! At the time of writing, this exceeds the memory of even the most advanced GPUs, including the NVIDIA B200, which has 192 GB of HBM3e memory. And that calculation doesn’t yet include gradients and optimizer states; here is a more comprehensive overview:
Quantity stored in memory | Description | Training | Inference |
---|---|---|---|
Model parameters | Needed for the forward pass. | ✅ | ✅ |
Gradients | One per model parameter. Gradients are computed during the backward pass and are used to update the model parameters. | ✅ | ❌ |
Optimizer states | For each parameter, the Adam optimizer stores (a) the exponentially weighted moving average of the gradient, and (b) the exponentially weighted moving average of the squared gradient. Thus, at the same precision, the Adam states alone take 2x the memory of the parameters, and parameters plus optimizer states take 3x. | ✅ | ❌ |
Activations | Stored during the forward pass and used during the backward pass. These need to be stored for all layers until the backward pass is complete. The memory requirement scales linearly with the batch size. | ✅ | ❌ |
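To make the table concrete, here is a rough back-of-the-envelope sketch of the training footprint. It assumes every quantity is stored at the same precision (2 bytes, as in the FP16 example), uses Adam, and leaves out activations because they depend on batch size and architecture. The function names are made up for illustration, and real mixed-precision setups will give different numbers.

```python
def training_memory_gb(num_params, bytes_per_value=2):
    """Rough memory in GB (1 GB = 1e9 bytes) for parameters, gradients, and Adam states.

    Assumption: all quantities use the same precision. Activations are excluded.
    """
    params = num_params * bytes_per_value
    grads = num_params * bytes_per_value            # one gradient per parameter
    adam_states = 2 * num_params * bytes_per_value  # first and second moment estimates
    total = params + grads + adam_states
    return {
        "parameters": params / 1e9,
        "gradients": grads / 1e9,
        "adam_states": adam_states / 1e9,
        "total": total / 1e9,
    }


def per_gpu_gb(total_gb, world_size):
    """Per-GPU share if the quantities are sharded evenly across world_size GPUs."""
    return total_gb / world_size


estimate = training_memory_gb(100 * 10**9)  # the 100B-parameter example
print(estimate)
print(per_gpu_gb(estimate["total"], world_size=8))
```

Under these assumptions, the parameters, gradients, and Adam states of the 100B-parameter model add up to four times the parameter memory, which is why sharding all three matters.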
Thus, large models need to be “split” across multiple GPUs so that the necessary quantities fit in GPU memory. This is achieved with sharding. In this post, I’ll discuss the sharding strategy known as fully sharded data parallel (FSDP), specifically the PyTorch implementation.
FSDP shards model parameters, gradients, and optimizer states across GPUs, which makes it extremely memory efficient. It breaks a model into so-called “FSDP units”. A unit can consist of a single layer or multiple layers; how the model is partitioned into units is up to the user.
The model is decomposed into FSDP units. Source: Figure 1 in the PyTorch FSDP paper.
Within each unit, the associated model parameters are flattened, concatenated, and sharded across GPUs. During the forward (or backward) pass, the sharded parameters are communicated and recovered on demand just before they are needed for computation, then discarded right after. This ensures that FSDP only has to materialize the full parameters of one unit at a time, which drastically reduces peak memory consumption on every GPU.
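The flatten, concatenate, and shard steps can be sketched in plain Python. This is a toy illustration with nested lists standing in for tensors, not PyTorch's actual implementation; the helper names, the zero padding value, and the 4×3 weight with a 3-element bias are assumptions chosen so that 15 elements pad up to 16 ranks.

```python
def flatten(t):
    """Recursively flatten a nested list (a stand-in for a tensor) into a flat list."""
    if isinstance(t, (list, tuple)):
        return [x for item in t for x in flatten(item)]
    return [t]


def flatten_concat_chunk(tensors, world_size):
    """Flatten each tensor, concatenate, right-pad, and split into equal shards."""
    flat = [x for t in tensors for x in flatten(t)]
    pad = (-len(flat)) % world_size      # elements needed to make the length divisible
    flat = flat + [0.0] * pad            # assumed padding value: zero
    shard_size = len(flat) // world_size
    shards = [flat[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]
    return shards, pad


def all_gather(shards):
    """Simulate an all-gather: every rank ends up with the concatenation of all shards."""
    return [x for shard in shards for x in shard]


# Hypothetical unit: a linear layer with a 4x3 weight and a 3-element bias.
weight = [[float(3 * r + c) for c in range(3)] for r in range(4)]
bias = [12.0, 13.0, 14.0]

shards, pad = flatten_concat_chunk([weight, bias], world_size=16)
full = all_gather(shards)                # recovered on demand before computation
unpadded = full[:len(full) - pad]        # drop the padding again
print(len(shards), pad, unpadded)
del full, unpadded                       # discarded right after use
```

Between uses, each rank keeps only its own shard, so the persistent cost per GPU is the shard size rather than the full parameter count of the unit.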
Consider the FSDP unit in the diagram below:
Full sharding, across 16 GPUs, of an FSDP unit comprising a single `nn.Linear` layer whose parameters are a weight matrix and a bias. Every GPU holds one element of the unit’s `FlatParameter`, with the last rank holding the padded value. Source: Figure 3 in the PyTorch FSDP paper.
Let’s walk through the flatten-concat-chunk algorithm: