Coverage for train.py: 55% (198 statements), reported by coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

# Copyright 2026 venim1103
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import glob
import argparse
# Set CUDA allocator config to reduce fragmentation on 16GB GPUs
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.nn.functional as F
import time
import wandb
from context_config import CONTEXT_LENGTH
from model import BitMambaLLM, maybe_autocast
from data import create_dataloaders
from optim import setup_mamba_optimizers, FGWSD_Scheduler
from dist_utils import (
    setup_distributed, cleanup_distributed, is_main_process,
    wrap_model_ddp, unwrap_model, barrier,
)

MODE = "scout"

if MODE == "scout":
    MODEL_CONFIG = dict(vocab_size=64000, dim=512, n_layers=24, d_state=64, expand=2, use_checkpoint=True)
    TOTAL_STEPS = 100_000
    PEAK_LR = 4.5e-4
elif MODE == "upscaled":
    # Continued pre-training after SOLAR upscaling (MiniPuzzle-inspired)
    # Lower LR since model already has pretrained weights; shorter run
    MODEL_CONFIG = dict(vocab_size=64000, dim=1024, n_layers=64, d_state=128, expand=2,
                        use_checkpoint=True, use_attn=True, attn_pct=0.08)
    TOTAL_STEPS = 20_000  # Short continued pretrain (5-10B tokens equivalent)
    PEAK_LR = 1.0e-4      # Lower LR for continued training
else:
    MODEL_CONFIG = dict(vocab_size=64000, dim=1024, n_layers=40, d_state=128, expand=2, use_checkpoint=True)
    TOTAL_STEPS = 1_000_000
    PEAK_LR = 3.0e-4

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # overridden by setup_distributed()
BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 8
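# Each optimizer step consumes BATCH_SIZE * GRAD_ACCUM_STEPS sequences per rank
# (times world_size under DDP), i.e. BATCH_SIZE * GRAD_ACCUM_STEPS * ctx tokens.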

SAVE_EVERY = 100  # Checkpoint roughly every 3 hours on Kaggle

# G10: Use FP16 + GradScaler on Ampere (RTX 3090) for 2x throughput.
# Use torch.bfloat16 on Ada Lovelace (RTX 4090), where BF16 matches FP16 speed.
# This module-level default is overridden at runtime by --amp-dtype (see main()).
AMP_DTYPE = torch.float16
CHECKPOINT_DIR = f"checkpoints/bitmamba_{MODE}"
# Static divisor to keep loss_sum in a safe range for the FP16 GradScaler.
# Must be a fixed constant — do not use current_ctx here, as it changes between
# FG-WSD phases and would cause logged loss values to jump discontinuously.
SAFE_DIVISOR = float(CONTEXT_LENGTH) * BATCH_SIZE
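# The division by SAFE_DIVISOR before backward() is purely for FP16 range safety;
# run_training_steps later rescales gradients by (world_size * SAFE_DIVISOR) / valid_tokens,
# so the effective objective remains the mean loss per valid token.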

HALF_CONTEXT_LENGTH = max(1, CONTEXT_LENGTH // 2)

# FG-WSD: Data quality progression per phase (Nanbeige4-3B §2.2.2)
# Keep LR flat while progressively increasing data quality
# Warmup: Mixed | Stable 1: Web-heavy | Stable 2: Code/Logic-heavy | Decay: HQ reasoning only
# Synthetic data integrated in Stable 2 and Decay (Nemotron-H §2.3, Nanbeige4-3B §2.2)
#
# PRE-TRAINING: Generate synthetic data first using synth_data.py:
#   python synth_data.py --strategy diverse_qa --input local_data/train/web  --output local_data/synth/web_qa
#   python synth_data.py --strategy distill    --input local_data/train/code --output local_data/synth/code_distill
#   python synth_data.py --strategy extract    --input local_data/train/web  --output local_data/synth/knowledge_extract
#   python synth_data.py --strategy rephrase   --input local_data/train/web  --output local_data/synth/web_rephrased
#
TRAIN_CONFIGS = {
    "Phase_1": [  # Warmup - diverse mixed data
        {"name": "formal_logic", "path": "local_data/train/logic", "format": "parquet", "weight": 0.25},
        {"name": "code",         "path": "local_data/train/code",  "format": "parquet", "weight": 0.25},
        {"name": "web",          "path": "local_data/train/web",   "format": "parquet", "weight": 0.30},
        {"name": "tool_use",     "path": "local_data/train/tools", "format": "parquet", "weight": 0.20}
    ],
    "Phase_2": [  # Stable 1 - heavy on web/diversity
        {"name": "formal_logic", "path": "local_data/train/logic", "format": "parquet", "weight": 0.20},
        {"name": "code",         "path": "local_data/train/code",  "format": "parquet", "weight": 0.20},
        {"name": "web",          "path": "local_data/train/web",   "format": "parquet", "weight": 0.45},
        {"name": "tool_use",     "path": "local_data/train/tools", "format": "parquet", "weight": 0.15}
    ],
    "Phase_3": [  # Stable 2 - heavy on code/logic + synthetic CoT (Nanbeige4-3B)
        {"name": "formal_logic",  "path": "local_data/train/logic",        "format": "parquet", "weight": 0.30},
        {"name": "code",          "path": "local_data/train/code",         "format": "parquet", "weight": 0.25},
        {"name": "synth_qa",      "path": "local_data/synth/web_qa",       "format": "json",    "weight": 0.20},
        {"name": "synth_distill", "path": "local_data/synth/code_distill", "format": "json",    "weight": 0.15},
        {"name": "web",           "path": "local_data/train/web",          "format": "parquet", "weight": 0.05},
        {"name": "tool_use",      "path": "local_data/train/tools",        "format": "parquet", "weight": 0.05}
    ],
    "Phase_4": [  # Decay - 100% high-quality reasoning/synthetic (Nemotron-H)
        {"name": "formal_logic",   "path": "local_data/train/logic",             "format": "parquet", "weight": 0.25},
        {"name": "code",           "path": "local_data/train/code",              "format": "parquet", "weight": 0.20},
        {"name": "synth_qa",       "path": "local_data/synth/web_qa",            "format": "json",    "weight": 0.20},
        {"name": "synth_distill",  "path": "local_data/synth/code_distill",      "format": "json",    "weight": 0.15},
        {"name": "synth_extract",  "path": "local_data/synth/knowledge_extract", "format": "json",    "weight": 0.10},
        {"name": "synth_rephrase", "path": "local_data/synth/web_rephrased",     "format": "json",    "weight": 0.10}
    ],
}
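# The weights within each phase sum to 1.0, so they read directly as the intended
# sampling mix for that phase (the exact sampling behavior is defined by
# create_dataloaders in data.py, not shown here).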


# Legacy single config (used for backward compatibility if needed)
TRAIN_CONFIG = TRAIN_CONFIGS["Phase_2"]

# UPDATED: Use half-context for the warmup/stable phases, then expand to the shared
# project max context during decay.
# Per Nanbeige4-3B and Nemotron-H: expanding context mid-way through stable training
# destabilizes training dynamics, so the extension happens only in the decay phase.
CURRICULUM_CONFIG = {
    "peak_lr": PEAK_LR, "end_lr": 1.5e-6,
    "phases": [
        {"pct": 0.05, "ctx": HALF_CONTEXT_LENGTH},  # warmup
        {"pct": 0.40, "ctx": HALF_CONTEXT_LENGTH},  # stable 1
        {"pct": 0.35, "ctx": HALF_CONTEXT_LENGTH},  # stable 2
        {"pct": 0.20, "ctx": CONTEXT_LENGTH}        # decay (extend context here!)
    ]
}
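# Illustrative arithmetic for the scout run (TOTAL_STEPS = 100_000), assuming
# FGWSD_Scheduler treats "pct" as consecutive fractions of the run: warmup ends
# near step 5_000, stable 1 near 45_000, stable 2 near 80_000, and decay finishes
# at 100_000; only the final 20% trains at the full CONTEXT_LENGTH.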



def _parse_runtime_args():
    """Parse optional runtime flags without breaking torchrun/ddp launchers."""
    parser = argparse.ArgumentParser(add_help=True)
    parser.add_argument(
        "--amp-dtype",
        choices=["fp16", "bf16", "auto"],
        default="fp16",
        help="AMP dtype policy (default: fp16).",
    )
    args, _ = parser.parse_known_args()
    return args


def _resolve_amp_dtype(amp_dtype_arg):
    if amp_dtype_arg == "fp16":
        return torch.float16
    if amp_dtype_arg == "bf16":
        if torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
            return torch.float16
        return torch.bfloat16
    # auto
    return torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
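# Example invocation for the flags above (node size chosen only for illustration):
#   torchrun --nproc_per_node=2 train.py --amp-dtype auto
# "auto" picks bfloat16 when the GPU reports BF16 support and falls back to
# float16 otherwise; parse_known_args() ignores any launcher-injected flags.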


def create_seq_idx_batch(cu_seqlens_padded, n_segs, seqlen):
    """Create per-batch-element seq_idx from padded cu_seqlens.

    Args:
        cu_seqlens_padded: (batch_size, max_n_segs) — padded with -1 sentinels
        n_segs: (batch_size,) — real lengths per element
        seqlen: int — current context length

    Returns:
        seq_idx: (batch_size, seqlen) — int32 tensor on DEVICE
    """

    batch_size = cu_seqlens_padded.shape[0]
    seq_idx = torch.zeros(batch_size, seqlen, dtype=torch.int32, device=DEVICE)
    for b in range(batch_size):
        n = n_segs[b].item()
        cu = cu_seqlens_padded[b, :n]
        # Filter to boundaries within the current (possibly truncated) context
        valid = cu[cu <= seqlen]
        if len(valid) == 0 or valid[-1] != seqlen:
            valid = torch.cat([valid, torch.tensor([seqlen], dtype=torch.int32)])
        for i in range(len(valid) - 1):
            start, end = valid[i].item(), valid[i + 1].item()
            seq_idx[b, start:end] = i
    return seq_idx



def run_training_steps(model, raw_model, optimizers, scheduler,
                       train_loader, scaler, total_steps,
                       checkpoint_dir, device, world_size=1, start_step=0,
                       start_total_tokens=0):
    """Inner training loop, decoupled from DDP setup, wandb init, and torch.compile.

    Args:
        model: DDP-wrapped (or bare) model for forward passes
        raw_model: Unwrapped model for state_dict saves and output head access
        optimizers: Tuple of (muon_opt, adam_opt, mamba_core_opt)
        scheduler: FGWSD_Scheduler instance
        train_loader: Initial DataLoader (recreated internally on phase change)
        scaler: torch.amp.GradScaler (disabled on CPU)
        total_steps: Number of optimizer steps to run
        checkpoint_dir: Directory to write step checkpoints into
        device: String device identifier ('cpu' or 'cuda:N')
        world_size: Number of distributed ranks (1 for single-GPU)
        start_step: Step to resume training from
        start_total_tokens: Token counter value to resume from
    """
    muon_opt, adam_opt, mamba_core_opt = optimizers
    data_iter = iter(train_loader)
    total_tokens = start_total_tokens
    tokens_since_log = 0  # tokens processed on this rank since the last throughput log
    t0 = time.time()

    # Use read-only get_lr_and_ctx to prevent redundant state mutations on resume
    if start_step > 0:
        _, previous_ctx, previous_phase = scheduler.get_lr_and_ctx(start_step)
    else:
        previous_ctx, previous_phase = CURRICULUM_CONFIG['phases'][0]['ctx'], "Phase_1"


    for step in range(start_step, total_steps):
        current_lr, current_ctx, phase_name = scheduler.step(step)

        need_reload = False
        if phase_name != previous_phase and phase_name != "Complete" and previous_phase is not None:
            need_reload = True
            if is_main_process():
                print(f" [FG-WSD] Phase changed to {phase_name}: data quality updated")
                wandb.log({"System/FG_WSD_Phase": phase_name}, step=step)
        if current_ctx != previous_ctx:
            need_reload = True

        if need_reload:
            train_loader, _ = create_dataloaders(
                TRAIN_CONFIGS.get(phase_name, TRAIN_CONFIG), tokenizer_path="custom_agentic_tokenizer",
                max_seq_len=current_ctx, batch_size=BATCH_SIZE
            )
            data_iter = iter(train_loader)
            previous_ctx = current_ctx

        previous_phase = phase_name


        for opt in [muon_opt, adam_opt, mamba_core_opt]: opt.zero_grad()
        accumulated_loss_sum = 0.0
        accumulated_valid_tokens = 0

        for _ in range(GRAD_ACCUM_STEPS):
            # Tell CUDAGraphs a new forward/backward cycle is starting.
            # Prevents RuntimeError about overwriting static graph memory during grad accumulation.
            torch.compiler.cudagraph_mark_step_begin()

            x, y, cu_seqlens, n_segs = next(data_iter)

            x, y = x.to(device)[:, :current_ctx], y.to(device)[:, :current_ctx]
            seq_idx = create_seq_idx_batch(cu_seqlens, n_segs, current_ctx)

            with maybe_autocast(device, amp_dtype=AMP_DTYPE):
                loss_sum, valid_tokens = model(x, seq_idx=seq_idx, targets=y)
            safe_loss = loss_sum / SAFE_DIVISOR
            scaler.scale(safe_loss).backward()
            accumulated_loss_sum += loss_sum.detach().item()
            accumulated_valid_tokens += valid_tokens.detach().item()

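        # Gradients must be unscaled once per optimizer before the in-place
        # rescaling and clipping below; scaler.step() later skips the update if
        # any inf/NaN gradients were produced under FP16.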

        for opt in [muon_opt, adam_opt, mamba_core_opt]: scaler.unscale_(opt)
        global_valid_tokens = accumulated_valid_tokens
        global_loss_sum = accumulated_loss_sum
        if world_size > 1:
            valid_token_tensor = torch.tensor(accumulated_valid_tokens, device=device, dtype=torch.float32)
            torch.distributed.all_reduce(valid_token_tensor, op=torch.distributed.ReduceOp.SUM)
            global_valid_tokens = valid_token_tensor.item()
            loss_sum_tensor = torch.tensor(accumulated_loss_sum, device=device, dtype=torch.float32)
            torch.distributed.all_reduce(loss_sum_tensor, op=torch.distributed.ReduceOp.SUM)
            global_loss_sum = loss_sum_tensor.item()
        grad_scale = (world_size * SAFE_DIVISOR) / max(global_valid_tokens, 1.0)
        for param in model.parameters():
            if param.grad is not None:
                param.grad.mul_(grad_scale)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(muon_opt)
        scaler.step(adam_opt)
        scaler.step(mamba_core_opt)
        scaler.update()


        for opt in [muon_opt, adam_opt, mamba_core_opt]: opt.zero_grad()

        accumulated_loss = global_loss_sum / max(global_valid_tokens, 1.0)

        tokens_this_step = BATCH_SIZE * GRAD_ACCUM_STEPS * current_ctx
        total_tokens += tokens_this_step
        tokens_since_log += tokens_this_step

        if step % 10 == 0 and is_main_process():
            t1 = time.time()
            dt, t0 = t1 - t0, t1
            # Throughput over the whole logging interval, not just the last step
            toks = tokens_since_log * world_size
            tokens_since_log = 0
            print(f"Step {step:06d} | {phase_name:<15} | Loss: {accumulated_loss:.4f} | Tok/s: {toks/dt:.0f}")
            wandb.log({"Train/Loss": accumulated_loss, "System/Total_Tokens": total_tokens * world_size}, step=step)


        if step > 0 and step % SAVE_EVERY == 0 and is_main_process():
            ckpt_path = os.path.join(checkpoint_dir, f"step_{step:06d}.pt")
            tmp_ckpt_path = ckpt_path + ".tmp"
            payload = {
                'step': step,
                'total_tokens': total_tokens,
                'model_state_dict': raw_model.state_dict(),
                'muon_opt_state': muon_opt.state_dict(),
                'adam_opt_state': adam_opt.state_dict(),
                'mamba_core_opt_state': mamba_core_opt.state_dict(),
                'scaler_state': scaler.state_dict(),
                'wandb_run_id': wandb.run.id if wandb.run else None
            }
            try:
                with open(tmp_ckpt_path, 'wb') as handle:
                    torch.save(payload, handle)
                    handle.flush()
                    os.fsync(handle.fileno())
                tmp_size_mb = os.path.getsize(tmp_ckpt_path) / (1024**2)
                print(f" [Save] Temp checkpoint {tmp_size_mb:.2f} MB ready, renaming to final location...")
                os.replace(tmp_ckpt_path, ckpt_path)
                print(f" [Save] Checkpoint written atomically to {ckpt_path}")
            finally:
                if os.path.exists(tmp_ckpt_path):
                    os.remove(tmp_ckpt_path)
        barrier()



def load_latest_checkpoint(checkpoint_dir, raw_model, optimizers, scaler, device):
    """Finds the most recent checkpoint, loads weights & optimizer states safely."""
    muon_opt, adam_opt, mamba_core_opt = optimizers
    pt_files = glob.glob(os.path.join(checkpoint_dir, "step_*.pt"))
    if not pt_files:
        return 0, None, 0

    # Sort by step number extracted from filename
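    # e.g. "step_000100.pt" -> "000100" -> 100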

    latest_ckpt = max(pt_files, key=lambda f: int(os.path.basename(f).split('_')[1].split('.')[0]))
    print(f"Resuming training from latest checkpoint: {latest_ckpt}")

    # weights_only=False is required because the optimizer states contain arbitrary
    # Python objects; only load checkpoints from a trusted source.
    ckpt = torch.load(latest_ckpt, map_location=device, weights_only=False)
    raw_model.load_state_dict(ckpt['model_state_dict'])
    if 'muon_opt_state' in ckpt: muon_opt.load_state_dict(ckpt['muon_opt_state'])
    if 'adam_opt_state' in ckpt: adam_opt.load_state_dict(ckpt['adam_opt_state'])
    if 'mamba_core_opt_state' in ckpt: mamba_core_opt.load_state_dict(ckpt['mamba_core_opt_state'])
    if 'scaler_state' in ckpt: scaler.load_state_dict(ckpt['scaler_state'])

    # Backward compatible with older checkpoints that did not persist this field.
    total_tokens = int(ckpt.get('total_tokens', 0))
    return ckpt['step'] + 1, ckpt.get('wandb_run_id'), total_tokens


def main():
    global DEVICE, AMP_DTYPE
    args = _parse_runtime_args()
    AMP_DTYPE = _resolve_amp_dtype(args.amp_dtype)

    rank, local_rank, world_size, device = setup_distributed()
    DEVICE = device  # update module-level DEVICE for create_seq_idx_batch

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    if is_main_process():
        print(f"Initializing {MODE.upper()} BitMamba Model...")
        print(f"AMP dtype: {AMP_DTYPE}")
    model = BitMambaLLM(**MODEL_CONFIG).to(DEVICE)

    # Load upscaled checkpoint if in continued pretraining mode
    if MODE == "upscaled":
        print("Loading upscaled checkpoint for continued pre-training...")
        # Look for upscaled checkpoint in default location
        upscale_ckpt = "checkpoints/upscaled/step_000000_1B_mamba.pt"
        if os.path.exists(upscale_ckpt):
            ckpt = torch.load(upscale_ckpt, map_location=DEVICE)
            model.load_state_dict(ckpt['model_state_dict'])
            print(f"Loaded upscaled weights from {upscale_ckpt}")
        else:
            print(f"Warning: Upscaled checkpoint not found at {upscale_ckpt}")
            print("Please run: python upscale.py first")


    # IMPORTANT: torch.compile is compatible with gradient checkpointing ONLY when
    # use_reentrant=False (set in BitMambaLLM._backbone). Do NOT change checkpointing
    # to use_reentrant=True — it will silently corrupt gradients under compile.
    # Note: Using mode="default" instead of "reduce-overhead" to disable CUDA Graphs,
    # which conflict with gradient checkpointing + accumulation.
    model._backbone = torch.compile(model._backbone, mode="default")

    # Wrap in DDP after compile but before optimizer creation
    model = wrap_model_ddp(model, local_rank)
    raw_model = unwrap_model(model)  # for state_dict saves and output head access

    muon_opt, adam_opt, mamba_core_opt = setup_mamba_optimizers(raw_model, CURRICULUM_CONFIG)
    scheduler = FGWSD_Scheduler(muon_opt, adam_opt, mamba_core_opt, TOTAL_STEPS, CURRICULUM_CONFIG)

    # G10: GradScaler for FP16 (no-op when using BF16)
    scaler = torch.amp.GradScaler(enabled=(DEVICE.startswith("cuda") and AMP_DTYPE == torch.float16))


    start_step, wandb_run_id, start_total_tokens = 0, None, 0
    if MODE != "upscaled":
        start_step, wandb_run_id, start_total_tokens = load_latest_checkpoint(
            CHECKPOINT_DIR, raw_model, (muon_opt, adam_opt, mamba_core_opt), scaler, DEVICE
        )

    if is_main_process():
        wandb.init(
            project="Agentic-1.58b-Model",
            name=f"run-bitmamba-{MODE}",
            id=wandb_run_id,  # Resumes the same loss curve if id exists!
            resume="allow",   # Crucial for preventing fragmented runs
            config=CURRICULUM_CONFIG
        )

    # Initialize correct data phase and context based on starting step using read-only method
    if start_step > 0:
        _, start_ctx, start_phase = scheduler.get_lr_and_ctx(start_step)
    else:
        start_ctx, start_phase = CURRICULUM_CONFIG['phases'][0]['ctx'], "Phase_1"

    train_loader, tokenizer = create_dataloaders(
        TRAIN_CONFIGS.get(start_phase, TRAIN_CONFIGS["Phase_1"]),
        tokenizer_path="custom_agentic_tokenizer",
        max_seq_len=start_ctx, batch_size=BATCH_SIZE
    )
    model.train()

    run_training_steps(
        model=model,
        raw_model=raw_model,
        optimizers=(muon_opt, adam_opt, mamba_core_opt),
        scheduler=scheduler,
        train_loader=train_loader,
        scaler=scaler,
        total_steps=TOTAL_STEPS,
        checkpoint_dir=CHECKPOINT_DIR,
        device=DEVICE,
        world_size=world_size,
        start_step=start_step,
        start_total_tokens=start_total_tokens,
    )

    if is_main_process():
        wandb.finish()
    cleanup_distributed()

if __name__ == "__main__":
    main()