Coverage for dist_utils.py: 100%
60 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 23:06 +0000
1# Copyright 2026 venim1103
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15"""
16Distributed training utilities.
18Auto-detects multi-GPU setups and handles DDP initialization/teardown.
20Single-GPU:
21 python train.py # works as before
23Multi-GPU (2x T4 on Kaggle, etc.):
24 torchrun --nproc_per_node=2 train.py
25"""
27import os
28import sys
29import random
30import numpy as np
31import torch
32import torch.distributed as dist
33from torch.nn.parallel import DistributedDataParallel as DDP
def setup_distributed():
    """Set up process/device state for single- or multi-process training.

    A torchrun launch is detected via the ``RANK`` environment variable;
    in that case the NCCL process group is initialized and each process is
    pinned to its local GPU. Otherwise the function falls back to a single
    process on CUDA (if available) or CPU.

    Returns:
        rank: int — global rank (0 for single-GPU)
        local_rank: int — local GPU index (0 for single-GPU / CPU)
        world_size: int — total number of processes (1 for single-GPU)
        device: str — device string (e.g. 'cuda', 'cuda:1', 'cpu')
    """
    env = os.environ
    if "RANK" in env:
        # torchrun / torch.distributed.launch export RANK, LOCAL_RANK, WORLD_SIZE.
        rank = int(env["RANK"])
        local_rank = int(env["LOCAL_RANK"])
        world_size = int(env["WORLD_SIZE"])
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)
        device = f"cuda:{local_rank}"
        if rank == 0:
            print(f"[DDP] Initialized {world_size} processes (nccl)")
    else:
        # Single-process path (CUDA or CPU).
        rank = local_rank = 0
        world_size = 1
        if torch.cuda.is_available():
            device = "cuda"
            n_gpus = torch.cuda.device_count()
            if n_gpus > 1:
                # Multiple GPUs visible but only one process — hint at torchrun.
                script = os.path.basename(sys.argv[0]) if sys.argv else "train.py"
                print(f"[dist_utils] {n_gpus} GPUs detected but running single-process.")
                print(f"[dist_utils] For multi-GPU, launch with: "
                      f"torchrun --nproc_per_node={n_gpus} {script}")
        else:
            device = "cpu"

    # Seed differently per rank so each process samples different data
    seed = 42 + rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    return rank, local_rank, world_size, device
def cleanup_distributed():
    """Tear down the process group created by setup_distributed(), if any."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
def is_main_process():
    """True on rank 0; always True when running without distributed."""
    return dist.get_rank() == 0 if dist.is_initialized() else True
def wrap_model_ddp(model, local_rank):
    """Wrap *model* in DistributedDataParallel when a process group is active.

    static_graph=True keeps DDP compatible with torch.compile / CUDA graphs.
    With no initialized process group, *model* is returned untouched.
    """
    if not dist.is_initialized():
        return model
    return DDP(model, device_ids=[local_rank], static_graph=True)
def unwrap_model(model):
    """Strip a DistributedDataParallel wrapper, returning the bare module."""
    return model.module if isinstance(model, DDP) else model
def barrier():
    """Block until every rank reaches this point; no-op when not distributed."""
    if not dist.is_initialized():
        return
    dist.barrier()
def get_world_size():
    """Number of participating processes (1 when not distributed)."""
    return dist.get_world_size() if dist.is_initialized() else 1