Scaling LLM Training: Deep Dive into ZeRO and FSDP for Multi-GPU Systems

Authors
  • Nino, Senior Tech Editor

As Large Language Models (LLMs) like DeepSeek-V3 and Claude 3.5 Sonnet scale into hundreds of billions of parameters (DeepSeek-V3 alone has 671B), the primary bottleneck for developers is not raw compute but memory. Training, or even fine-tuning, these models on a single GPU is usually impossible because the model states exceed the VRAM of even high-end hardware like the H100. To get around this, technical teams rely on distributed training techniques. While platforms like n1n.ai simplify access to these models via high-speed APIs, understanding the underlying infrastructure is crucial for engineers building custom solutions.

The Memory Wall in LLM Training

When training a model with N parameters, memory consumption goes well beyond the N parameters themselves. With mixed-precision training and the Adam optimizer, we must account for:

  1. Model Parameters: In FP16/BF16, this is 2N bytes.
  2. Gradients: Another 2N bytes in FP16/BF16.
  3. Optimizer States: For the Adam optimizer, we store an FP32 copy of the weights (4N), the momentum (4N), and the variance (4N), totaling 12N bytes.

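Adding these up gives 2N + 2N + 12N = 16N bytes. For a 7B-parameter model, the model states alone therefore require roughly 16 bytes × 7B ≈ 112GB of VRAM, before counting activations, which is more than the 80GB available on the largest A100. This is where ZeRO and FSDP become essential.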
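The accounting above can be checked with a few lines of Python. This is a minimal sketch under the same mixed-precision Adam assumptions; the function name model_state_bytes is ours, and activations, temporary buffers, and memory fragmentation are deliberately left out.

```python
def model_state_bytes(n_params: int) -> dict[str, int]:
    """Bytes needed for model states under mixed-precision Adam.

    Activations, temporary buffers, and fragmentation are NOT included.
    """
    states = {
        "params_bf16": 2 * n_params,  # FP16/BF16 weights
        "grads_bf16": 2 * n_params,   # FP16/BF16 gradients
        # FP32 master weights + momentum + variance, 4 bytes each
        "adam_fp32": (4 + 4 + 4) * n_params,
    }
    # The right-hand side is evaluated before "total" is added to the dict
    states["total"] = sum(states.values())
    return states


usage = model_state_bytes(7_000_000_000)
print(usage["total"] / 1e9, "GB")  # 112.0 GB
```

Note that these are decimal gigabytes (10^9 bytes); in GiB the same total is about 104.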

Understanding ZeRO: Zero Redundancy Optimizer

Introduced by Microsoft Research as part of the DeepSpeed library, ZeRO eliminates memory redundancy across data-parallel processes. In traditional Distributed Data Parallel (DDP), every GPU keeps a full copy of the model states. ZeRO breaks this redundancy in three stages:

ZeRO-1: Optimizer State Sharding

The optimizer states, which occupy the largest share of memory (12N of the 16N bytes), are partitioned across all available GPUs. With 8 GPUs, each GPU stores only 1/8 of the optimizer states. This reduces memory usage significantly without increasing communication volume compared with DDP.
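To make the partitioning concrete, here is a minimal sketch of how a flat optimizer-state buffer could be split into equal per-rank slices. The helper shard_ranges is hypothetical; conceptually padding the buffer to a multiple of the world size is one simple way to keep chunks equal, and is an assumption here rather than a description of DeepSpeed's internals.

```python
def shard_ranges(numel: int, world_size: int) -> list[tuple[int, int]]:
    """Split a flat buffer of `numel` elements into `world_size` equal chunks.

    The buffer is conceptually padded up to a multiple of world_size so every
    rank owns a chunk of the same length; ranges are clipped to `numel` so
    they index only real elements (the last rank may own fewer).
    """
    chunk = -(-numel // world_size)  # ceiling division
    return [
        (min(r * chunk, numel), min((r + 1) * chunk, numel))
        for r in range(world_size)
    ]


# 10 elements over 4 GPUs -> chunks of 3; the last rank owns 1 real element
print(shard_ranges(10, 4))  # [(0, 3), (3, 6), (6, 9), (9, 10)]

# With 8 GPUs, each rank's Adam state for a 7B model is 12 bytes x its slice
start, end = shard_ranges(7_000_000_000, 8)[0]
print(12 * (end - start) / 1e9, "GB")  # 10.5 GB
```

In ZeRO-1, each rank runs the Adam update only on its own slice; the updated parameters are then all-gathered so that every GPU again holds a full copy for the next forward pass.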

ZeRO-2: Gradient Sharding

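Building on ZeRO-1, this stage also partitions the gradients. Instead of all-reducing the full gradient as DDP does, ZeRO-2 uses a reduce-scatter, so each GPU keeps only the reduced gradients for the parameters whose optimizer states it owns and the rest of the gradient memory can be released. This further slashes the memory footprint.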
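The reduce-scatter at the heart of ZeRO-2 can be simulated in plain Python. This sketch uses a hypothetical single-process helper, sim_reduce_scatter (it is not torch.distributed.reduce_scatter, which operates on tensors across real processes), and assumes the gradient length is already padded to a multiple of the number of ranks.

```python
def sim_reduce_scatter(per_rank_grads: list[list[float]]) -> list[list[float]]:
    """Simulate a sum reduce-scatter over len(per_rank_grads) ranks.

    Every rank contributes a full-length gradient computed on its own
    micro-batch; rank r receives only the element-wise sum for chunk r.
    """
    world_size = len(per_rank_grads)
    numel = len(per_rank_grads[0])
    assert numel % world_size == 0, "pad gradients to a multiple of world_size"
    chunk = numel // world_size
    out = []
    for r in range(world_size):
        start, end = r * chunk, (r + 1) * chunk
        # Sum this chunk across all ranks; other chunks are never kept here
        out.append([sum(g[i] for g in per_rank_grads) for i in range(start, end)])
    return out


grads = [
    [1.0, 2.0, 3.0, 4.0],      # rank 0's local gradient
    [10.0, 20.0, 30.0, 40.0],  # rank 1's local gradient
]
print(sim_reduce_scatter(grads))  # [[11.0, 22.0], [33.0, 44.0]]
```

Each rank ends up holding half of the reduced gradient instead of all of it. Real frameworks typically average rather than sum (dividing by the number of ranks) and run the collective on the GPU.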

ZeRO-3: Parameter Sharding

ZeRO-3 is the most aggressive form. It shards the actual model parameters across GPUs. During the forward and backward passes, the necessary parameters are fetched via all-gather operations and discarded immediately after use. This allows for training models with trillions of parameters, provided the aggregate memory of the cluster is sufficient.
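Putting the three stages side by side, the per-GPU memory for model states follows directly from the 2N/2N/12N byte counts. The sketch below encodes the per-device formulas described in the ZeRO paper for mixed-precision Adam; per_gpu_model_state_bytes is a hypothetical helper, "stage 0" stands for plain DDP, and activations are again excluded.

```python
BYTES_PARAM = 2   # FP16/BF16 weights
BYTES_GRAD = 2    # FP16/BF16 gradients
BYTES_OPTIM = 12  # FP32 master weights + Adam momentum + Adam variance


def per_gpu_model_state_bytes(n_params: int, n_gpus: int, stage: int) -> float:
    """Per-GPU bytes for params + grads + Adam states (no activations)."""
    p = BYTES_PARAM * n_params
    g = BYTES_GRAD * n_params
    o = BYTES_OPTIM * n_params
    if stage == 0:  # plain DDP: everything replicated on every GPU
        return p + g + o
    if stage == 1:  # ZeRO-1: shard optimizer states only
        return p + g + o / n_gpus
    if stage == 2:  # ZeRO-2: shard optimizer states and gradients
        return p + (g + o) / n_gpus
    if stage == 3:  # ZeRO-3 / FSDP: shard parameters too
        return (p + g + o) / n_gpus
    raise ValueError(f"unknown ZeRO stage: {stage}")


for stage in range(4):
    gb = per_gpu_model_state_bytes(7_000_000_000, 8, stage) / 1e9
    print(f"stage {stage}: {gb:.2f} GB per GPU")
# stage 0: 112.00 GB per GPU
# stage 1: 38.50 GB per GPU
# stage 2: 26.25 GB per GPU
# stage 3: 14.00 GB per GPU
```

For a 7B model on 8 GPUs, only ZeRO-3 brings model states well under an 80GB card's capacity while leaving room for activations.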

For developers who prefer to focus on application logic rather than managing GPU clusters, n1n.ai provides a unified interface to access models that have already been trained using these advanced techniques.

PyTorch FSDP: The Industry Standard

Fully Sharded Data Parallel (FSDP) is PyTorch's native implementation of the ZeRO-3 concept. It provides a more integrated experience for the PyTorch ecosystem. FSDP shards parameters, gradients, and optimizer states, while also offering features like offloading to CPU for even larger models.

Implementing FSDP in PyTorch

To use FSDP, you wrap your model with the FullyShardedDataParallel class. Below is a simplified implementation structure:

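import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


def setup_fsdp_model(model: nn.Module, device_id: int) -> FSDP:
    # Assumes torch.distributed.init_process_group() has already been called
    # (e.g. the script was launched with torchrun) and that
    # torch.cuda.set_device(device_id) was called on this rank.

    # Auto-wrap policy: every submodule with at least min_num_params
    # parameters becomes its own FSDP unit, so its parameters are
    # all-gathered and freed independently of the rest of the model.
    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=10000
    )

    # Wrap the model: parameters and gradients are sharded across ranks
    sharded_model = FSDP(
        model,
        auto_wrap_policy=my_auto_wrap_policy,
        # Keep parameter shards in CPU memory when not in use
        # (saves VRAM at the cost of host-device transfers)
        cpu_offload=CPUOffload(offload_params=True),
        # Prefetch the next unit's parameters during the backward pass
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=device_id,
    )
    # Build the optimizer only after wrapping so its states are sharded too,
    # e.g. torch.optim.AdamW(sharded_model.parameters(), lr=1e-4)
    return sharded_model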

Implementation Logic from Scratch

If we were to conceptualize a ZeRO-style sharding from scratch, the logic follows a specific communication pattern:

  1. Partitioning: Divide the flat parameter tensor into K chunks (where K is the number of GPUs), and assign one chunk to each GPU.
  2. Forward Pass: The GPUs exchange their shards via an all-gather, so that every GPU temporarily holds the full parameters of the current layer and can compute the forward pass on its own data batch.
  3. Backward Pass: Because the gathered parameters were freed after the forward pass, they are all-gathered again to compute gradients.
  4. Reduction: Gradients are reduced with a reduce-scatter, so that each GPU receives only the gradients for the parameters it owns.
  5. Update: Each GPU updates its local optimizer states and parameter shards.

This cycle ensures that no single GPU ever holds the entire model and optimizer state at once: only the parameters of the layer currently being computed are materialized in full, and only transiently.
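The five-step cycle can be simulated end to end in a single process. This is a toy sketch, not a real implementation: the collectives are hypothetical pure-Python stand-ins (sim_all_gather, sim_reduce_scatter), the "model" is a single linear layer with a squared-error loss, and plain SGD replaces Adam to keep the update short.

```python
def sim_all_gather(shards: list[list[float]]) -> list[list[float]]:
    """Simulated all-gather: every rank receives the concatenation of all shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def sim_reduce_scatter(per_rank: list[list[float]]) -> list[list[float]]:
    """Simulated sum reduce-scatter: rank r keeps only the summed chunk r."""
    k = len(per_rank)
    chunk = len(per_rank[0]) // k
    return [
        [sum(g[i] for g in per_rank) for i in range(r * chunk, (r + 1) * chunk)]
        for r in range(k)
    ]


def local_grad(w: list[float], batch) -> list[float]:
    """Gradient of the mean of 0.5 * (w . x - y)^2 over one rank's batch."""
    g = [0.0] * len(w)
    for x, y in batch:
        err = sum(wi * xi for wi, xi in zip(w, x)) - y
        for i, xi in enumerate(x):
            g[i] += err * xi / len(batch)
    return g


def zero3_step(shards, batches, lr):
    """One ZeRO-3-style step with plain SGD (Adam omitted for brevity)."""
    k = len(shards)
    # Steps 2-3: gather the full parameters. A real implementation gathers
    # layer by layer in both the forward and backward passes.
    full = sim_all_gather(shards)
    grads = [local_grad(full[r], batches[r]) for r in range(k)]
    del full  # gathered parameters are discarded after use
    # Step 4: reduce-scatter, so each rank holds only gradients for its shard
    grad_shards = sim_reduce_scatter(grads)
    # Step 5: each rank updates only its own shard (divide by k to average)
    return [
        [w - lr * g / k for w, g in zip(shards[r], grad_shards[r])]
        for r in range(k)
    ]


# Two simulated GPUs, four parameters, one sample per rank
shards = [[0.0, 0.0], [0.0, 0.0]]
batches = [
    [([1.0, 0.0, 0.0, 0.0], 1.0)],  # rank 0's data
    [([0.0, 0.0, 0.0, 1.0], 2.0)],  # rank 1's data
]
new_shards = zero3_step(shards, batches, lr=0.5)
print(sim_all_gather(new_shards)[0])  # [0.25, 0.0, 0.0, 0.5]
```

Because both ranks use the same batch size, the result matches a single-GPU SGD step on the combined batch: sharding changes where the state lives, not the math of the update.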

Pro Tip: Communication vs. Computation

While ZeRO-3/FSDP solves the memory problem, it introduces communication latency. To optimize performance, developers should enable backward prefetching (backward_prefetch=BackwardPrefetch.BACKWARD_PRE) so that the all-gather of the next layer's parameters overlaps with the current layer's gradient computation. High-speed interconnects like NVLink are vital here. If your local infrastructure cannot support this, using an API aggregator like n1n.ai allows you to leverage these optimizations indirectly through their high-performance inference endpoints.
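Why prefetching helps can be seen with an idealized timing model. The sketch below assumes a single prefetch buffer and perfect overlap of communication with computation, and it ignores bandwidth contention; step_time and the millisecond figures are illustrative, not measurements.

```python
def step_time(comm_ms: list[float], compute_ms: list[float], prefetch: bool) -> float:
    """Idealized time to run a stack of FSDP units (one pass).

    Without prefetch, each unit waits for its all-gather before computing,
    so the two costs add up. With prefetch, the next unit's all-gather runs
    while the current unit computes, so after the first gather each unit
    costs whichever of the two is slower.
    """
    if not prefetch:
        return sum(c + k for c, k in zip(comm_ms, compute_ms))
    total = comm_ms[0]  # the first gather cannot be hidden
    for i, k in enumerate(compute_ms):
        next_comm = comm_ms[i + 1] if i + 1 < len(comm_ms) else 0.0
        total += max(k, next_comm)
    return total


comm = [2.0, 2.0, 2.0, 2.0]     # hypothetical all-gather time per unit (ms)
compute = [3.0, 3.0, 3.0, 3.0]  # hypothetical compute time per unit (ms)
print(step_time(comm, compute, prefetch=False))  # 20.0
print(step_time(comm, compute, prefetch=True))   # 14.0
```

When communication per unit exceeds computation, the step stays communication-bound even with prefetching, which is why fast interconnects matter so much for ZeRO-3.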

Comparison Table: DDP vs. ZeRO vs. FSDP

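Feature                    | DDP              | ZeRO-1/2                 | ZeRO-3 / FSDP
---------------------------|------------------|--------------------------|-----------------------------
Parameter Redundancy       | High (full copy) | High (full copy)         | Zero (sharded)
Optimizer State Redundancy | High             | Zero (sharded)           | Zero (sharded)
Communication Overhead     | Low              | Low (same volume as DDP) | High (about 1.5× DDP volume)
Max Model Size             | Limited by 1 GPU | Medium                   | Limited by cluster total
Implementation Complexity  | Low              | Medium                   | High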

Conclusion

Mastering ZeRO and FSDP is essential for any engineer working on the frontier of AI. By sharding model states, we break the physical limits of individual GPUs, enabling the training of the next generation of LLMs. However, the complexity of managing such clusters is non-trivial. For many production use cases, it is more efficient to utilize the optimized infrastructure of n1n.ai to access state-of-the-art models without the overhead of hardware management.
