PyTorch Distributed
PyTorch Distributed Data Parallel (DDP) lets you train models across multiple GPUs on a single node by running one process per GPU and synchronizing gradients across processes during each backward pass. On TIR, your Training Cluster node comes fully pre-configured — drivers, CUDA, NCCL, and PyTorch are ready to use.
Environment
Each TIR Training Cluster node comes pre-configured with:
- PyTorch, CUDA, and NCCL installed and optimized
- GPU drivers and high-bandwidth interconnects (NVLink or PCIe)
- Identical software environments across all GPUs
Connect to the Node
ssh $hostname
Shared Storage
All datasets, checkpoints, and logs should be written to the shared directory so they persist after the deployment ends:
/mnt/shared
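For example, a script can create the conventional subdirectories used throughout this guide at startup (the directory names below are this guide's conventions, not paths TIR creates for you):

import os

# Conventional subdirectories under the shared mount; created if missing.
for sub in ("datasets", "checkpoints", "logs"):
    os.makedirs(os.path.join("/mnt/shared", sub), exist_ok=True)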
Training Guide
Step 1: Install Dependencies
The TIR-provided image includes PyTorch. If you are using a custom environment, install PyTorch manually:
pip install torch torchvision torchaudio
Step 2: Write a Distributed Training Script
Save the following as train.py. It initializes NCCL, wraps the model in DDP, and runs a training loop across all available GPUs:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup():
    # torchrun sets RANK, WORLD_SIZE, and MASTER_ADDR/PORT for each process.
    dist.init_process_group(backend="nccl")


def cleanup():
    dist.destroy_process_group()


def main():
    setup()

    # Bind this process to the GPU matching its local rank.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Model, optimizer, loss
    model = nn.Linear(10, 1).cuda()
    ddp_model = DDP(model, device_ids=[local_rank])
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    for epoch in range(5):
        optimizer.zero_grad()
        inputs = torch.randn(32, 10).cuda()
        targets = torch.randn(32, 1).cuda()

        outputs = ddp_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()   # DDP averages gradients across GPUs during backward
        optimizer.step()

        if local_rank == 0:
            print(f"Epoch {epoch} | Loss: {loss.item():.4f}")

    cleanup()


if __name__ == "__main__":
    main()
Step 3: Launch Training
Use torchrun to spawn one process per GPU automatically:
torchrun --nproc_per_node=4 train.py
| Argument | Description |
|---|---|
| --nproc_per_node | Number of GPUs on the node (e.g., 4 or 8) |
| train.py | Your training script |
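torchrun exports the rendezvous variables each spawned process needs (RANK, LOCAL_RANK, WORLD_SIZE, among others), which is why the Step 2 script can read LOCAL_RANK from the environment. A minimal sketch of reading them directly:

import os

# Environment variables set by torchrun for every spawned process.
rank = int(os.environ["RANK"])              # global process rank
local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node
world_size = int(os.environ["WORLD_SIZE"])  # total number of processes
print(f"rank {rank}/{world_size} on GPU {local_rank}")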
Data Management
Import Datasets
Place datasets in shared storage before launching training:
/mnt/shared/datasets
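In a DDP job, each rank should read a disjoint shard of the dataset. A minimal sketch using DistributedSampler — the random tensors below stand in for whatever you load from /mnt/shared/datasets, and the process group is assumed to be initialized as in Step 2:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Placeholder dataset; replace with data loaded from /mnt/shared/datasets.
dataset = TensorDataset(torch.randn(1024, 10), torch.randn(1024, 1))
sampler = DistributedSampler(dataset)  # splits indices across ranks
loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=4)

for epoch in range(5):
    sampler.set_epoch(epoch)  # reshuffle so each rank's shard changes per epoch
    for inputs, targets in loader:
        ...  # forward/backward as in Step 2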
Save Checkpoints
Save checkpoints from the rank-0 process only to avoid write conflicts:
if local_rank == 0:
    torch.save(model.state_dict(), "/mnt/shared/checkpoints/model_epoch_5.pt")
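To resume on a later deployment, each rank can load the same checkpoint onto its own GPU. A minimal sketch, assuming the path saved above and the local_rank from Step 2:

# Load onto this rank's GPU instead of letting every rank default to cuda:0.
state = torch.load(
    "/mnt/shared/checkpoints/model_epoch_5.pt",
    map_location=f"cuda:{local_rank}",
)
model.load_state_dict(state)  # load before wrapping the model in DDP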
Monitoring
TensorBoard
Write logs to shared storage and access TensorBoard remotely via SSH port forwarding:
# On the worker node
tensorboard --logdir /mnt/shared/logs --port 6006
# On your local machine
ssh -L 6006:localhost:6006 $hostname
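To produce those logs, the training script can write scalars with SummaryWriter from the rank-0 process only. A minimal sketch that plugs into the Step 2 loop:

from torch.utils.tensorboard import SummaryWriter

# Only rank 0 writes event files; other ranks skip logging.
writer = SummaryWriter(log_dir="/mnt/shared/logs") if local_rank == 0 else None

# Inside the training loop:
if writer is not None:
    writer.add_scalar("train/loss", loss.item(), epoch)

# After training:
if writer is not None:
    writer.close()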
Weights & Biases
pip install wandb
import wandb
wandb.init(project="pytorch-ddp", mode="offline")
wandb.log({"loss": loss.item()})
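To avoid creating one duplicate run per GPU process, a common pattern is to initialize W&B from rank 0 only. A minimal sketch:

import os
import wandb

# Create a single run from rank 0; other ranks never call wandb.init().
if int(os.environ["LOCAL_RANK"]) == 0:
    wandb.init(project="pytorch-ddp", mode="offline")

# wandb.run is None on the ranks that did not initialize a run.
if wandb.run is not None:
    wandb.log({"loss": loss.item()})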
GPU & System Utilization
watch -n 2 nvidia-smi # GPU utilization
htop # CPU and memory
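For a per-rank view from inside the script, PyTorch exposes device memory counters. A minimal sketch using the local_rank from Step 2:

# Report how much GPU memory this rank's tensors currently occupy.
mem_gb = torch.cuda.memory_allocated(local_rank) / 1e9
print(f"rank {local_rank}: {mem_gb:.2f} GB allocated on GPU {local_rank}")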
Mixed Precision Training
Enable Automatic Mixed Precision (AMP) to reduce GPU memory usage and speed up training:
from torch.amp import autocast, GradScaler

scaler = GradScaler(device="cuda")

for epoch in range(5):
    optimizer.zero_grad()
    with autocast(device_type="cuda"):
        outputs = ddp_model(inputs)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
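On GPUs that support bfloat16 (an assumption about your hardware), you can autocast to bfloat16 instead; its exponent range matches float32, so no GradScaler is needed:

# bfloat16 autocast: no gradient scaling required.
with autocast(device_type="cuda", dtype=torch.bfloat16):
    outputs = ddp_model(inputs)
    loss = criterion(outputs, targets)
loss.backward()
optimizer.step()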
Troubleshooting
| Issue | Cause | Resolution |
|---|---|---|
| CUDA Out of Memory | Batch size too large for GPU memory | Reduce batch size or enable AMP |
| NCCL Timeout | GPU communication failure | Confirm all GPUs are visible via nvidia-smi |
| Disk Full | Checkpoints or logs filling /mnt/shared | Delete old files or increase storage quota |
| Process group init failure | LOCAL_RANK not set | Use torchrun instead of python to launch |
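If NCCL timeouts are triggered by long stalls (for example, slow data loading from shared storage) rather than a real communication failure, one option is to raise the process-group timeout at initialization. A minimal sketch:

from datetime import timedelta
import torch.distributed as dist

# Allow collectives to wait longer before NCCL aborts the job.
dist.init_process_group(backend="nccl", timeout=timedelta(minutes=30))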
FAQ
Q: Why use DDP instead of DataParallel?
DDP spawns one process per GPU and communicates via NCCL, giving near-linear speedups. DataParallel runs in a single process and is slower due to Python GIL overhead and less efficient gradient synchronization.
Q: How do I save a checkpoint safely with multiple workers?
Only the rank-0 process should write checkpoints. Wrap your save call with if dist.get_rank() == 0: to avoid write conflicts from other workers.
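A minimal sketch of that pattern, with a barrier so the other ranks wait until the file has been written:

import torch.distributed as dist

if dist.get_rank() == 0:
    torch.save(model.state_dict(), "/mnt/shared/checkpoints/model_epoch_5.pt")
dist.barrier()  # all ranks wait here until rank 0 has finished saving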
Q: Where are logs and checkpoints stored?
Under /mnt/shared/logs and /mnt/shared/checkpoints by convention. Use these paths to ensure data persists after the deployment ends.