Scaling LLM Training: Deep Dive into ZeRO and FSDP for Multi-GPU Systems
By Nino, Senior Tech Editor
As Large Language Models (LLMs) like DeepSeek-V3 and Claude 3.5 Sonnet continue to scale into hundreds of billions of parameters, the primary bottleneck for developers isn't just raw compute power: it is memory. Training or even fine-tuning these models on a single GPU is often impossible because the model states exceed the VRAM capacity of even the highest-end hardware, such as an H100. To solve this, technical teams rely on advanced distributed training techniques. While platforms like n1n.ai simplify access to these models via high-speed APIs, understanding the underlying infrastructure is crucial for engineers building custom solutions.
The Memory Wall in LLM Training
When training a model with Ψ parameters, the memory consumption is not just the parameters themselves. We must account for:
- Model Parameters: In FP16/BF16, this is 2Ψ bytes.
- Gradients: Another 2Ψ bytes in FP16/BF16.
- Optimizer States: For the Adam optimizer, we store a copy of the weights in FP32 (4Ψ bytes), momentum (4Ψ bytes), and variance (4Ψ bytes), totaling 12Ψ bytes.
For a 7B-parameter model, the model states alone require roughly 16 × 7B ≈ 112 GB of VRAM, more than the 80 GB available on an A100. This is where ZeRO and FSDP become essential.
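To make the arithmetic concrete, here is a minimal Python sketch of the estimate above (assuming FP16/BF16 weights and gradients plus FP32 Adam states; activation and buffer memory are excluded):

```python
# Rough model-state memory for mixed-precision Adam training.
# Assumption: FP16/BF16 params + grads, FP32 master weights/momentum/variance.
def model_state_bytes(num_params: int) -> int:
    fp16_params = 2 * num_params      # weights in FP16/BF16
    fp16_grads = 2 * num_params       # gradients in FP16/BF16
    fp32_optimizer = 12 * num_params  # FP32 weight copy + momentum + variance
    return fp16_params + fp16_grads + fp32_optimizer

print(model_state_bytes(7_000_000_000) / 1e9)  # → 112.0 (GB)
```

That 16 bytes per parameter is exactly why a 7B model overflows a single 80 GB accelerator before a single activation is stored.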
Understanding ZeRO: Zero Redundancy Optimizer
Introduced by Microsoft Research as part of the DeepSpeed library, ZeRO eliminates memory redundancy across data-parallel processes. In traditional Distributed Data Parallel (DDP), every GPU keeps a full copy of the model states. ZeRO breaks this redundancy in three stages:
ZeRO-1: Optimizer State Sharding
The optimizer states (which occupy the largest chunk of memory) are partitioned across all available GPUs. If you have 8 GPUs, each GPU only stores 1/8 of the optimizer states. This reduces memory usage significantly without increasing communication overhead.
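As a toy illustration of the partitioning (plain Python, not DeepSpeed's actual implementation), each rank keeps only a contiguous slice of the flat optimizer state:

```python
# Toy ZeRO-1 partitioning: shard a flat optimizer-state list across ranks.
# (Illustrative only; real implementations use padded flat tensors.)
def shard(flat_state, rank, world_size):
    n = len(flat_state)
    chunk = (n + world_size - 1) // world_size  # ceil-divide so shards cover everything
    return flat_state[rank * chunk : (rank + 1) * chunk]

state = list(range(10))
shards = [shard(state, r, 4) for r in range(4)]
# Each rank holds about 1/4 of the state; concatenated, the shards
# reconstruct the full state with no element stored twice.
```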
ZeRO-2: Gradient Sharding
Building on ZeRO-1, this stage also partitions the gradients. After the backward pass, gradients are reduced only on the GPUs that own the corresponding parameters. This further slashes the memory footprint.
ZeRO-3: Parameter Sharding
ZeRO-3 is the most aggressive form. It shards the actual model parameters across GPUs. During the forward and backward passes, the necessary parameters are fetched via all-gather operations and discarded immediately after use. This allows for training models with trillions of parameters, provided the aggregate memory of the cluster is sufficient.
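The gather-compute-discard cycle can be mimicked in-process. This is a toy sketch with a stand-in `all_gather` and a dummy layer computation; real code would use `torch.distributed` collectives:

```python
# Toy ZeRO-3 forward step: each "layer" is sharded across ranks; before
# computing, a rank all-gathers the full layer, then frees it immediately.
def all_gather(shards):
    # stand-in for a real all-gather: concatenate every rank's shard
    return [w for shard in shards for w in shard]

def forward(layer_shards, x):
    full_layer = all_gather(layer_shards)  # transiently materialize full weights
    y = sum(w * x for w in full_layer)     # dummy "layer computation"
    del full_layer                         # discard right after use
    return y

shards = [[1.0, 2.0], [3.0, 4.0]]  # one layer's weights, sharded across 2 ranks
print(forward(shards, 2.0))  # → 20.0
```

The key point is that the full layer exists only inside `forward`; outside that transient window, each rank stores nothing but its own shard.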
For developers who prefer to focus on application logic rather than managing GPU clusters, n1n.ai provides a unified interface to access models that have already been trained using these advanced techniques.
PyTorch FSDP: The Industry Standard
Fully Sharded Data Parallel (FSDP) is PyTorch's native implementation of the ZeRO-3 concept. It provides a more integrated experience for the PyTorch ecosystem. FSDP shards parameters, gradients, and optimizer states, while also offering features like offloading to CPU for even larger models.
Implementing FSDP in PyTorch
To use FSDP, you wrap your model with the FullyShardedDataParallel class. Below is a simplified implementation structure:
```python
import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

def setup_fsdp_model(model, device_id):
    # Define a policy to decide which layers to shard
    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=10_000
    )
    # Wrap the model; parameters, gradients, and optimizer states are sharded
    sharded_model = FSDP(
        model,
        auto_wrap_policy=my_auto_wrap_policy,
        cpu_offload=CPUOffload(offload_params=True),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=device_id,
    )
    return sharded_model
```
Implementation Logic from Scratch
If we were to conceptualize a ZeRO-style sharding from scratch, the logic follows a specific communication pattern:
- Partitioning: Divide the flat parameter tensor into N chunks, where N is the number of GPUs.
- Forward Pass: Each GPU broadcasts its shard to all other GPUs (All-Gather) so that every GPU can compute the forward pass for its specific data batch.
- Backward Pass: Similar to the forward pass, parameters are gathered to compute gradients.
- Reduction: Gradients are reduced (Reduce-Scatter) so that each GPU only receives the gradients for the parameters it owns.
- Update: Each GPU updates its local optimizer states and parameter shards.
This cycle ensures that at no point does a single GPU need to hold the entire state of the model and the optimizer simultaneously, except for the transient window of computation.
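The whole cycle can be simulated in-process with plain Python lists. This is a toy model of the collectives with made-up helper names and dummy gradients, not a distributed implementation:

```python
# Toy ZeRO-style training step across 2 simulated ranks.
WORLD = 2
params = [[0.5], [1.5]]  # each rank owns one parameter shard

def all_gather(shards):
    # concatenate every rank's shard into the full parameter list
    return [p for s in shards for p in s]

def reduce_scatter(grads_per_rank):
    # sum gradients across ranks, then hand each rank only its own shard
    summed = [sum(col) for col in zip(*grads_per_rank)]
    return [summed[r:r + 1] for r in range(WORLD)]

full = all_gather(params)  # every rank transiently sees [0.5, 1.5]
# dummy per-rank gradients over the full parameter set (rank index as value)
grads_per_rank = [[1.0 * r for _ in full] for r in range(WORLD)]
grad_shards = reduce_scatter(grads_per_rank)  # each rank keeps its shard only
lr = 0.1  # plain SGD update, applied locally to each rank's shard
params = [[p - lr * g for p, g in zip(ps, gs)]
          for ps, gs in zip(params, grad_shards)]
# params is now approximately [[0.4], [1.4]]
```

Notice that no simulated rank ever holds both the full parameters and the full gradients at once; the full list exists only during the gather.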
Pro Tip: Communication vs. Computation
While ZeRO-3/FSDP solves the memory problem, it introduces communication latency. To optimize performance, developers should enable BackwardPrefetch (e.g., backward_prefetch=BackwardPrefetch.BACKWARD_PRE in the FSDP constructor) to overlap the communication of the next layer's parameters with the current layer's gradient computation. High-speed interconnects like NVLink are vital here. If your local infrastructure cannot support this, using an API aggregator like n1n.ai allows you to leverage these optimizations indirectly through their high-performance inference endpoints.
Comparison Table: DDP vs. ZeRO vs. FSDP
| Feature | DDP | ZeRO-1/2 | ZeRO-3 / FSDP |
|---|---|---|---|
| Parameter Redundancy | High (Full Copy) | High (Full Copy) | Zero (Sharded) |
| Optimizer State Redundancy | High | Zero | Zero |
| Communication Overhead | Low | Medium | High |
| Max Model Size | Limited by 1 GPU | Medium | Limited by Cluster Total |
| Implementation Complexity | Low | Medium | High |
Conclusion
Mastering ZeRO and FSDP is essential for any engineer working on the frontier of AI. By sharding model states, we break the physical limits of individual GPUs, enabling the training of the next generation of LLMs. However, the complexity of managing such clusters is non-trivial. For many production use cases, it is more efficient to utilize the optimized infrastructure of n1n.ai to access state-of-the-art models without the overhead of hardware management.
Get a free API key at n1n.ai