Coverage for dist_utils.py: 100%

60 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

1# Copyright 2026 venim1103 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15""" 

16Distributed training utilities. 

17 

18Auto-detects multi-GPU setups and handles DDP initialization/teardown. 

19 

20Single-GPU: 

21 python train.py # works as before 

22 

23Multi-GPU (2x T4 on Kaggle, etc.): 

24 torchrun --nproc_per_node=2 train.py 

25""" 

26 

27import os 

28import sys 

29import random 

30import numpy as np 

31import torch 

32import torch.distributed as dist 

33from torch.nn.parallel import DistributedDataParallel as DDP 

34 

35 

def setup_distributed():
    """Initialize distributed training if launched via torchrun, else single-GPU.

    Detects a torchrun launch via the RANK environment variable. Picks the
    "nccl" backend when CUDA is available and falls back to "gloo" otherwise,
    so CPU-only multi-process runs don't crash on NCCL initialization.

    Returns:
        rank: int — global rank (0 for single-GPU)
        local_rank: int — local GPU index (0 for single-GPU / CPU)
        world_size: int — total number of processes (1 for single-GPU)
        device: str — device string (e.g. 'cuda', 'cuda:1', 'cpu')
    """
    if "RANK" in os.environ:
        # Launched via torchrun / torch.distributed.launch
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        # nccl requires GPUs; gloo works on CPU (and is the documented
        # fallback for CPU collectives).
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend)
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
            device = f"cuda:{local_rank}"
        else:
            device = "cpu"
        if rank == 0:
            print(f"[DDP] Initialized {world_size} processes ({backend})")
    elif torch.cuda.is_available():
        rank = 0
        local_rank = 0
        world_size = 1
        device = "cuda"
        n_gpus = torch.cuda.device_count()
        if n_gpus > 1:
            # Nudge the user toward torchrun when multiple GPUs sit idle.
            script = os.path.basename(sys.argv[0]) if sys.argv else "train.py"
            print(f"[dist_utils] {n_gpus} GPUs detected but running single-process.")
            print(f"[dist_utils] For multi-GPU, launch with: "
                  f"torchrun --nproc_per_node={n_gpus} {script}")
    else:
        rank = 0
        local_rank = 0
        world_size = 1
        device = "cpu"

    # Seed differently per rank so each process samples different data
    seed = 42 + rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    return rank, local_rank, world_size, device

81 

82 

def cleanup_distributed():
    """Destroy the process group if distributed training was initialized.

    Safe to call unconditionally: checks dist.is_available() first because
    torch builds compiled without distributed support do not expose
    is_initialized(), and checks is_initialized() so single-process runs
    are a no-op.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

87 

88 

def is_main_process():
    """Return True if this is rank 0 (or not distributed).

    Guards with dist.is_available() so the check also works on torch
    builds compiled without distributed support.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True

94 

95 

def wrap_model_ddp(model, local_rank):
    """Wrap a model in DistributedDataParallel when distributed is active.

    Uses static_graph=True for compatibility with torch.compile / CUDA graphs.
    Guards with dist.is_available() so torch builds without distributed
    support fall through to the unwrapped path.

    Args:
        model: the nn.Module to (potentially) wrap.
        local_rank: local GPU index this process owns.

    Returns:
        The DDP-wrapped model if a process group is initialized, otherwise
        the model unchanged.
    """
    if dist.is_available() and dist.is_initialized():
        return DDP(model, device_ids=[local_rank], static_graph=True)
    return model

105 

106 

def unwrap_model(model):
    """Return the underlying module, stripping DDP wrapper if present."""
    # A DDP wrapper keeps the real network on its .module attribute;
    # anything else is already the bare model.
    return model.module if isinstance(model, DDP) else model

112 

113 

def barrier():
    """Synchronize all processes. No-op when not distributed.

    Checks dist.is_available() first so torch builds without distributed
    support don't raise AttributeError on is_initialized().
    """
    if dist.is_available() and dist.is_initialized():
        dist.barrier()

118 

119 

def get_world_size():
    """Return the world size (1 when not distributed).

    Checks dist.is_available() first so torch builds without distributed
    support don't raise AttributeError on is_initialized().
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1