PyTorch Distributed

PyTorch Distributed Data Parallel (DDP) lets you train models across multiple GPUs on a single node by running one process per GPU and synchronizing gradients after each backward pass. On TIR, your Training Cluster node comes fully pre-configured — drivers, CUDA, NCCL, and PyTorch are ready to use.


Environment

Each TIR Training Cluster node comes pre-configured with:

  • PyTorch, CUDA, and NCCL installed and optimized
  • GPU drivers and high-bandwidth interconnects (NVLink or PCIe)
  • Identical software environments across all GPUs
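
You can confirm the stack from a Python shell before launching anything; a minimal check (printed versions will vary by image):

# Sanity check of the pre-installed PyTorch / CUDA / NCCL stack
import torch
import torch.distributed as dist

print(torch.__version__)            # PyTorch version shipped with the image
print(torch.version.cuda)           # CUDA version PyTorch was built against
print(torch.cuda.is_available())    # True if the GPU drivers are working
print(torch.cuda.device_count())    # number of GPUs visible to this process
print(dist.is_nccl_available())     # True if the NCCL backend is compiled in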

Connect to the Node

ssh $hostname

Shared Storage

All datasets, checkpoints, and logs should be written to the shared directory so they persist after the deployment ends:

/mnt/shared

Training Guide

Step 1: Install Dependencies

The TIR-provided image includes PyTorch. If you are using a custom environment, install it manually:

pip install torch torchvision torchaudio

Step 2: Write a Distributed Training Script

Save the following as train.py. It initializes NCCL, wraps the model in DDP, and runs a training loop across all available GPUs:

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup():
    dist.init_process_group(backend="nccl")


def cleanup():
    dist.destroy_process_group()


def main():
    setup()

    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Model, optimizer, loss
    model = nn.Linear(10, 1).cuda()
    ddp_model = DDP(model, device_ids=[local_rank])

    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    for epoch in range(5):
        optimizer.zero_grad()
        inputs = torch.randn(32, 10).cuda()
        targets = torch.randn(32, 1).cuda()
        outputs = ddp_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        if local_rank == 0:
            print(f"Epoch {epoch} | Loss: {loss.item():.4f}")

    cleanup()


if __name__ == "__main__":
    main()

Step 3: Launch Training

Use torchrun to spawn one process per GPU automatically:

torchrun --nproc_per_node=4 train.py

Flag                Description
--nproc_per_node    Number of GPUs on the node (e.g., 4 or 8)
train.py            Your training script
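
torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE to every process it spawns, which is why train.py can read LOCAL_RANK from the environment. A quick way to see this is a throwaway script (probe.py is just an illustrative name):

# probe.py: print the per-process environment that torchrun provides
import os

print(
    f"rank={os.environ['RANK']} "
    f"local_rank={os.environ['LOCAL_RANK']} "
    f"world_size={os.environ['WORLD_SIZE']}"
)

Launching it with torchrun --nproc_per_node=4 probe.py prints one line per GPU process.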

Data Management

Import Datasets

Place datasets in shared storage before launching training:

/mnt/shared/datasets
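
With DDP, each process should train on a different shard of the data. A minimal sketch using torch.utils.data.DistributedSampler; the file name and tensor keys below are hypothetical:

# Shard a dataset stored under /mnt/shared/datasets across DDP processes
import torch
from torch.utils.data import TensorDataset, DataLoader, DistributedSampler

data = torch.load("/mnt/shared/datasets/train_tensors.pt")   # hypothetical preprocessed tensors
dataset = TensorDataset(data["inputs"], data["targets"])

sampler = DistributedSampler(dataset)                # splits indices across ranks
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(5):
    sampler.set_epoch(epoch)                         # reshuffle differently each epoch
    for inputs, targets in loader:
        inputs, targets = inputs.cuda(), targets.cuda()
        # forward / backward as in train.py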

Save Checkpoints

Save checkpoints from the rank-0 process only to avoid write conflicts:

if local_rank == 0:
    torch.save(model.state_dict(), "/mnt/shared/checkpoints/model_epoch_5.pt")
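
To resume, every rank can load the same file from shared storage; mapping it to the local GPU keeps all ranks from loading onto GPU 0. A minimal sketch using the checkpoint saved above:

# Load the checkpoint on every rank, mapped to that rank's GPU
checkpoint_path = "/mnt/shared/checkpoints/model_epoch_5.pt"
state_dict = torch.load(checkpoint_path, map_location=f"cuda:{local_rank}")
ddp_model.module.load_state_dict(state_dict)   # .module is the unwrapped model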

Monitoring

TensorBoard

Write logs to shared storage and access TensorBoard remotely via SSH port forwarding:

# On the worker node
tensorboard --logdir /mnt/shared/logs --port 6006

# On your local machine
ssh -L 6006:localhost:6006 $hostname
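
To produce those logs, write scalars from the rank-0 process only; a minimal sketch using torch.utils.tensorboard (the run subdirectory name is arbitrary):

# Write training metrics to shared storage so TensorBoard can read them
from torch.utils.tensorboard import SummaryWriter

if local_rank == 0:
    writer = SummaryWriter(log_dir="/mnt/shared/logs/ddp_run")

# inside the training loop
if local_rank == 0:
    writer.add_scalar("train/loss", loss.item(), epoch)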

Weights & Biases

pip install wandb

import wandb
wandb.init(project="pytorch-ddp", mode="offline")
wandb.log({"loss": loss.item()})
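
With multiple processes, initialize and log from rank 0 only so you get a single run instead of one per GPU; a minimal sketch:

import wandb

if local_rank == 0:
    wandb.init(project="pytorch-ddp", mode="offline")

# inside the training loop
if local_rank == 0:
    wandb.log({"loss": loss.item(), "epoch": epoch})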

GPU & System Utilization

watch -n 2 nvidia-smi   # GPU utilization
htop # CPU and memory
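
You can also read per-GPU memory counters from inside the training script, which helps correlate out-of-memory errors with specific steps; a minimal sketch:

# Per-device memory counters from PyTorch's caching allocator
allocated_gb = torch.cuda.memory_allocated() / 1e9   # memory held by live tensors
reserved_gb = torch.cuda.memory_reserved() / 1e9     # memory reserved by the allocator
if local_rank == 0:
    print(f"allocated={allocated_gb:.2f} GB | reserved={reserved_gb:.2f} GB")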

Mixed Precision Training

Enable Automatic Mixed Precision (AMP) to reduce GPU memory usage and speed up training:

from torch.amp import autocast, GradScaler

scaler = GradScaler(device="cuda")

for epoch in range(5):
    optimizer.zero_grad()
    with autocast(device_type="cuda"):
        outputs = ddp_model(inputs)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Troubleshooting

Issue                        Cause                                      Resolution
CUDA Out of Memory           Batch size too large for GPU memory        Reduce batch size or enable AMP
NCCL Timeout                 GPU communication failure                  Confirm all GPUs are visible via nvidia-smi
Disk Full                    Checkpoints or logs filling /mnt/shared    Delete old files or increase storage quota
Process group init failure   LOCAL_RANK not set                         Use torchrun instead of python to launch

FAQ

Q: Why use DDP instead of DataParallel?

DDP spawns one process per GPU and communicates via NCCL, giving near-linear speedups. DataParallel runs in a single process and is slower due to Python GIL overhead and less efficient gradient synchronization.

Q: How do I save a checkpoint safely with multiple workers?

Only the rank-0 process should write checkpoints. Wrap your save call with if dist.get_rank() == 0: to avoid write conflicts from other workers.
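
If other ranks need to read the checkpoint immediately afterwards (for example when resuming mid-run), add a barrier so they do not race ahead of the write; a minimal sketch with a hypothetical file name:

# Rank 0 writes, all ranks wait, then any rank can safely read the file
if dist.get_rank() == 0:
    torch.save(ddp_model.module.state_dict(), "/mnt/shared/checkpoints/latest.pt")
dist.barrier()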

Q: Where are logs and checkpoints stored?

Under /mnt/shared/logs and /mnt/shared/checkpoints by convention. Use these paths to ensure data persists after the deployment ends.