Coverage for train.py: 55% (198 statements). coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

# Copyright 2026 venim1103
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import glob
import argparse

# Set CUDA allocator config to reduce fragmentation on 16GB GPUs
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.nn.functional as F
import time
import wandb
from context_config import CONTEXT_LENGTH
from model import BitMambaLLM, maybe_autocast
from data import create_dataloaders
from optim import setup_mamba_optimizers, FGWSD_Scheduler
from dist_utils import (
    setup_distributed, cleanup_distributed, is_main_process,
    wrap_model_ddp, unwrap_model, barrier,
)

MODE = "scout"

if MODE == "scout":
    MODEL_CONFIG = dict(vocab_size=64000, dim=512, n_layers=24, d_state=64, expand=2, use_checkpoint=True)
    TOTAL_STEPS = 100_000
    PEAK_LR = 4.5e-4
elif MODE == "upscaled":
    # Continued pre-training after SOLAR upscaling (MiniPuzzle-inspired).
    # Lower LR since the model already has pretrained weights; shorter run.
    MODEL_CONFIG = dict(vocab_size=64000, dim=1024, n_layers=64, d_state=128, expand=2,
                        use_checkpoint=True, use_attn=True, attn_pct=0.08)
    TOTAL_STEPS = 20_000  # Short continued pretrain (5-10B tokens equivalent)
    PEAK_LR = 1.0e-4      # Lower LR for continued training
else:
    MODEL_CONFIG = dict(vocab_size=64000, dim=1024, n_layers=40, d_state=128, expand=2, use_checkpoint=True)
    TOTAL_STEPS = 1_000_000
    PEAK_LR = 3.0e-4

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # overridden by setup_distributed()
BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 8
SAVE_EVERY = 100  # Checkpoint roughly every 3 hours on Kaggle

# G10: Use FP16 + GradScaler on Ampere (RTX 3090) for 2x throughput.
# Set to torch.bfloat16 on Ada Lovelace (RTX 4090), where BF16 matches FP16 speed.
AMP_DTYPE = torch.float16
CHECKPOINT_DIR = f"checkpoints/bitmamba_{MODE}"

# Static divisor to keep loss_sum in a safe range for the FP16 GradScaler.
# Must be a fixed constant: do not use current_ctx here, as it changes between
# FG-WSD phases and would cause logged loss values to jump discontinuously.
SAFE_DIVISOR = float(CONTEXT_LENGTH) * BATCH_SIZE
HALF_CONTEXT_LENGTH = max(1, CONTEXT_LENGTH // 2)
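
# Worked example of the scaling (illustrative only; assumes CONTEXT_LENGTH = 2048):
# SAFE_DIVISOR = 2048 * 2 = 4096. A micro-batch's loss_sum is roughly
# (valid tokens) * (per-token loss), i.e. about 4096 * ln(64000) ~ 4096 * 11
# at init, so safe_loss = loss_sum / SAFE_DIVISOR stays O(10), a comfortable
# magnitude for GradScaler's dynamic loss scaling.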

# FG-WSD: Data quality progression per phase (Nanbeige4-3B §2.2.2).
# Keep LR flat while progressively increasing data quality.
# Warmup: Mixed | Stable 1: Web-heavy | Stable 2: Code/Logic-heavy | Decay: HQ reasoning only.
# Synthetic data is integrated in Stable 2 and Decay (Nemotron-H §2.3, Nanbeige4-3B §2.2).
#
# PRE-TRAINING: Generate synthetic data first using synth_data.py:
#   python synth_data.py --strategy diverse_qa --input local_data/train/web  --output local_data/synth/web_qa
#   python synth_data.py --strategy distill    --input local_data/train/code --output local_data/synth/code_distill
#   python synth_data.py --strategy extract    --input local_data/train/web  --output local_data/synth/knowledge_extract
#   python synth_data.py --strategy rephrase   --input local_data/train/web  --output local_data/synth/web_rephrased
#
TRAIN_CONFIGS = {
    "Phase_1": [  # Warmup - diverse mixed data
        {"name": "formal_logic", "path": "local_data/train/logic", "format": "parquet", "weight": 0.25},
        {"name": "code",         "path": "local_data/train/code",  "format": "parquet", "weight": 0.25},
        {"name": "web",          "path": "local_data/train/web",   "format": "parquet", "weight": 0.30},
        {"name": "tool_use",     "path": "local_data/train/tools", "format": "parquet", "weight": 0.20},
    ],
    "Phase_2": [  # Stable 1 - heavy on web/diversity
        {"name": "formal_logic", "path": "local_data/train/logic", "format": "parquet", "weight": 0.20},
        {"name": "code",         "path": "local_data/train/code",  "format": "parquet", "weight": 0.20},
        {"name": "web",          "path": "local_data/train/web",   "format": "parquet", "weight": 0.45},
        {"name": "tool_use",     "path": "local_data/train/tools", "format": "parquet", "weight": 0.15},
    ],
    "Phase_3": [  # Stable 2 - heavy on code/logic + synthetic CoT (Nanbeige4-3B)
        {"name": "formal_logic",  "path": "local_data/train/logic",        "format": "parquet", "weight": 0.30},
        {"name": "code",          "path": "local_data/train/code",         "format": "parquet", "weight": 0.25},
        {"name": "synth_qa",      "path": "local_data/synth/web_qa",       "format": "json",    "weight": 0.20},
        {"name": "synth_distill", "path": "local_data/synth/code_distill", "format": "json",    "weight": 0.15},
        {"name": "web",           "path": "local_data/train/web",          "format": "parquet", "weight": 0.05},
        {"name": "tool_use",      "path": "local_data/train/tools",        "format": "parquet", "weight": 0.05},
    ],
    "Phase_4": [  # Decay - 100% high-quality reasoning/synthetic (Nemotron-H)
        {"name": "formal_logic",   "path": "local_data/train/logic",             "format": "parquet", "weight": 0.25},
        {"name": "code",           "path": "local_data/train/code",              "format": "parquet", "weight": 0.20},
        {"name": "synth_qa",       "path": "local_data/synth/web_qa",            "format": "json",    "weight": 0.20},
        {"name": "synth_distill",  "path": "local_data/synth/code_distill",      "format": "json",    "weight": 0.15},
        {"name": "synth_extract",  "path": "local_data/synth/knowledge_extract", "format": "json",    "weight": 0.10},
        {"name": "synth_rephrase", "path": "local_data/synth/web_rephrased",     "format": "json",    "weight": 0.10},
    ],
}
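
# Each phase's source weights sum to 1.0. They are handed to create_dataloaders,
# which (an assumption here, inferred from usage) mixes the listed sources in
# those relative proportions.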

# Legacy single config (kept for backward compatibility if needed)
TRAIN_CONFIG = TRAIN_CONFIGS["Phase_2"]

# UPDATED: Use half context for the warmup/stable phases, then expand to the
# shared project max context during decay. Per Nanbeige4-3B and Nemotron-H,
# expanding context mid-way through stable training destabilizes dynamics.
# The scout-mode step boundaries are worked out below.
CURRICULUM_CONFIG = {
    "peak_lr": PEAK_LR, "end_lr": 1.5e-6,
    "phases": [
        {"pct": 0.05, "ctx": HALF_CONTEXT_LENGTH},  # warmup
        {"pct": 0.40, "ctx": HALF_CONTEXT_LENGTH},  # stable 1
        {"pct": 0.35, "ctx": HALF_CONTEXT_LENGTH},  # stable 2
        {"pct": 0.20, "ctx": CONTEXT_LENGTH},       # decay (extend context here!)
    ],
}
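
# With the scout schedule (TOTAL_STEPS = 100_000) these fractions resolve to:
#   warmup    0.05 -> steps      0 -   5,000  (HALF_CONTEXT_LENGTH)
#   stable 1  0.40 -> steps  5,000 -  45,000  (HALF_CONTEXT_LENGTH)
#   stable 2  0.35 -> steps 45,000 -  80,000  (HALF_CONTEXT_LENGTH)
#   decay     0.20 -> steps 80,000 - 100,000  (CONTEXT_LENGTH)
# (Exact boundary handling lives in FGWSD_Scheduler; this is just the arithmetic.)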


def _parse_runtime_args():
    """Parse optional runtime flags without breaking torchrun/DDP launchers."""
    parser = argparse.ArgumentParser(add_help=True)
    parser.add_argument(
        "--amp-dtype",
        choices=["fp16", "bf16", "auto"],
        default="fp16",
        help="AMP dtype policy (default: fp16).",
    )
    # parse_known_args ignores launcher-injected flags instead of erroring out
    args, _ = parser.parse_known_args()
    return args


def _resolve_amp_dtype(amp_dtype_arg):
    if amp_dtype_arg == "fp16":
        return torch.float16
    if amp_dtype_arg == "bf16":
        # Fall back to FP16 if the GPU lacks BF16 support
        if torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
            return torch.float16
        return torch.bfloat16
    # "auto": prefer BF16 whenever the hardware supports it
    return torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
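
# Example invocations (flags unknown to this parser pass through untouched,
# which is what keeps launcher-injected arguments from crashing the run):
#   python train.py --amp-dtype bf16    # BF16 if the GPU supports it, else FP16
#   torchrun --nproc_per_node=2 train.py --amp-dtype auto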


def create_seq_idx_batch(cu_seqlens_padded, n_segs, seqlen):
    """Create per-batch-element seq_idx from padded cu_seqlens.

    Args:
        cu_seqlens_padded: (batch_size, max_n_segs) - padded with -1 sentinels
        n_segs: (batch_size,) - real lengths per element
        seqlen: int - current context length

    Returns:
        seq_idx: (batch_size, seqlen) - int32 tensor on DEVICE
    """
    batch_size = cu_seqlens_padded.shape[0]
    seq_idx = torch.zeros(batch_size, seqlen, dtype=torch.int32, device=DEVICE)
    for b in range(batch_size):
        n = n_segs[b].item()
        cu = cu_seqlens_padded[b, :n]
        # Filter to boundaries within the current (possibly truncated) context
        valid = cu[cu <= seqlen]
        # Ensure the final boundary closes exactly at seqlen
        if len(valid) == 0 or valid[-1] != seqlen:
            valid = torch.cat([valid, torch.tensor([seqlen], dtype=torch.int32)])
        for i in range(len(valid) - 1):
            start, end = valid[i].item(), valid[i + 1].item()
            seq_idx[b, start:end] = i
    return seq_idx
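
# A minimal worked example (hypothetical values): one batch element holding two
# packed documents with boundaries 0 | 3 | 8, padded with a -1 sentinel:
#
#   cu = torch.tensor([[0, 3, 8, -1]], dtype=torch.int32)
#   create_seq_idx_batch(cu, n_segs=torch.tensor([3]), seqlen=8)
#   # -> tensor([[0, 0, 0, 1, 1, 1, 1, 1]], dtype=torch.int32)
#
# Tokens 0-2 map to segment 0 and tokens 3-7 to segment 1, which is presumably
# how the model keeps packed documents in one row from bleeding into each other.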


def run_training_steps(model, raw_model, optimizers, scheduler,
                       train_loader, scaler, total_steps,
                       checkpoint_dir, device, world_size=1, start_step=0,
                       start_total_tokens=0):
    """Inner training loop, decoupled from DDP setup, wandb init, and torch.compile.

    Args:
        model: DDP-wrapped (or bare) model for forward passes
        raw_model: Unwrapped model for state_dict saves and output head access
        optimizers: Tuple of (muon_opt, adam_opt, mamba_core_opt)
        scheduler: FGWSD_Scheduler instance
        train_loader: Initial DataLoader (recreated internally on phase change)
        scaler: torch.amp.GradScaler (disabled on CPU)
        total_steps: Number of optimizer steps to run
        checkpoint_dir: Directory to write step checkpoints into
        device: String device identifier ('cpu' or 'cuda:N')
        world_size: Number of distributed ranks (1 for single-GPU)
        start_step: Step to resume training from
        start_total_tokens: Token counter value to resume from
    """
    muon_opt, adam_opt, mamba_core_opt = optimizers
    data_iter = iter(train_loader)
    total_tokens = start_total_tokens
    tokens_since_log = 0  # tokens processed since the last throughput log
    t0 = time.time()

    # Use read-only get_lr_and_ctx to prevent redundant state mutations on resume
    if start_step > 0:
        _, previous_ctx, previous_phase = scheduler.get_lr_and_ctx(start_step)
    else:
        previous_ctx, previous_phase = CURRICULUM_CONFIG['phases'][0]['ctx'], "Phase_1"

    for step in range(start_step, total_steps):
        current_lr, current_ctx, phase_name = scheduler.step(step)

        need_reload = False
        if phase_name != previous_phase and phase_name != "Complete" and previous_phase is not None:
            need_reload = True
            if is_main_process():
                print(f" [FG-WSD] Phase changed to {phase_name}: data quality updated")
                wandb.log({"System/FG_WSD_Phase": phase_name}, step=step)
        if current_ctx != previous_ctx:
            need_reload = True

        if need_reload:
            train_loader, _ = create_dataloaders(
                TRAIN_CONFIGS.get(phase_name, TRAIN_CONFIG), tokenizer_path="custom_agentic_tokenizer",
                max_seq_len=current_ctx, batch_size=BATCH_SIZE
            )
            data_iter = iter(train_loader)
            previous_ctx = current_ctx

        previous_phase = phase_name

        for opt in [muon_opt, adam_opt, mamba_core_opt]:
            opt.zero_grad()
        accumulated_loss_sum = 0.0
        accumulated_valid_tokens = 0

        for _ in range(GRAD_ACCUM_STEPS):
            # Tell CUDA Graphs a new forward/backward cycle is starting.
            # Prevents RuntimeError about overwriting static graph memory during grad accumulation.
            torch.compiler.cudagraph_mark_step_begin()

            x, y, cu_seqlens, n_segs = next(data_iter)

            x, y = x.to(device)[:, :current_ctx], y.to(device)[:, :current_ctx]
            seq_idx = create_seq_idx_batch(cu_seqlens, n_segs, current_ctx)

            with maybe_autocast(device, amp_dtype=AMP_DTYPE):
                loss_sum, valid_tokens = model(x, seq_idx=seq_idx, targets=y)
            safe_loss = loss_sum / SAFE_DIVISOR
            scaler.scale(safe_loss).backward()
            accumulated_loss_sum += loss_sum.detach().item()
            accumulated_valid_tokens += valid_tokens.detach().item()

        for opt in [muon_opt, adam_opt, mamba_core_opt]:
            scaler.unscale_(opt)
        global_valid_tokens = accumulated_valid_tokens
        global_loss_sum = accumulated_loss_sum
        if world_size > 1:
            valid_token_tensor = torch.tensor(accumulated_valid_tokens, device=device, dtype=torch.float32)
            torch.distributed.all_reduce(valid_token_tensor, op=torch.distributed.ReduceOp.SUM)
            global_valid_tokens = valid_token_tensor.item()
            loss_sum_tensor = torch.tensor(accumulated_loss_sum, device=device, dtype=torch.float32)
            torch.distributed.all_reduce(loss_sum_tensor, op=torch.distributed.ReduceOp.SUM)
            global_loss_sum = loss_sum_tensor.item()
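
        # Gradient normalization sketch (assuming default DDP gradient averaging):
        # every micro-batch backward contributed grad(loss_sum / SAFE_DIVISOR) and
        # DDP divides by world_size, so the accumulated gradient equals
        # grad(total loss_sum) / (world_size * SAFE_DIVISOR). Multiplying by
        # grad_scale below recovers the gradient of the mean loss per valid token.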
        grad_scale = (world_size * SAFE_DIVISOR) / max(global_valid_tokens, 1.0)
        for param in model.parameters():
            if param.grad is not None:
                param.grad.mul_(grad_scale)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(muon_opt)
        scaler.step(adam_opt)
        scaler.step(mamba_core_opt)
        scaler.update()
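        # Per the PyTorch AMP examples for multiple optimizers: call scaler.step()
        # once per optimizer, then scaler.update() exactly once per iteration.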

        for opt in [muon_opt, adam_opt, mamba_core_opt]:
            opt.zero_grad()

        accumulated_loss = global_loss_sum / max(global_valid_tokens, 1.0)

        tokens_this_step = BATCH_SIZE * GRAD_ACCUM_STEPS * current_ctx
        total_tokens += tokens_this_step
        tokens_since_log += tokens_this_step

        if step % 10 == 0 and is_main_process():
            t1 = time.time()
            dt, t0 = t1 - t0, t1
            # Throughput over the full logging interval (t0 was last reset at the
            # previous log), not just the most recent step
            toks = tokens_since_log * world_size
            tokens_since_log = 0
            print(f"Step {step:06d} | {phase_name:<15} | Loss: {accumulated_loss:.4f} | Tok/s: {toks/dt:.0f}")
            wandb.log({"Train/Loss": accumulated_loss, "System/Total_Tokens": total_tokens * world_size}, step=step)

        if step > 0 and step % SAVE_EVERY == 0 and is_main_process():
            ckpt_path = os.path.join(checkpoint_dir, f"step_{step:06d}.pt")
            tmp_ckpt_path = ckpt_path + ".tmp"
            payload = {
                'step': step,
                'total_tokens': total_tokens,
                'model_state_dict': raw_model.state_dict(),
                'muon_opt_state': muon_opt.state_dict(),
                'adam_opt_state': adam_opt.state_dict(),
                'mamba_core_opt_state': mamba_core_opt.state_dict(),
                'scaler_state': scaler.state_dict(),
                'wandb_run_id': wandb.run.id if wandb.run else None,
            }
            try:
                # Write-then-rename keeps the save atomic: a crash mid-write
                # never clobbers an existing good checkpoint.
                with open(tmp_ckpt_path, 'wb') as handle:
                    torch.save(payload, handle)
                    handle.flush()
                    os.fsync(handle.fileno())
                tmp_size_mb = os.path.getsize(tmp_ckpt_path) / (1024**2)
                print(f" [Save] Temp checkpoint {tmp_size_mb:.2f} MB ready, renaming to final location...")
                os.replace(tmp_ckpt_path, ckpt_path)
                print(f" [Save] Checkpoint written atomically to {ckpt_path}")
            finally:
                # Clean up a leftover temp file if the save failed before the rename
                if os.path.exists(tmp_ckpt_path):
                    os.remove(tmp_ckpt_path)
        barrier()


def load_latest_checkpoint(checkpoint_dir, raw_model, optimizers, scaler, device):
    """Find the most recent checkpoint and load weights and optimizer states."""
    muon_opt, adam_opt, mamba_core_opt = optimizers
    pt_files = glob.glob(os.path.join(checkpoint_dir, "step_*.pt"))
    if not pt_files:
        return 0, None, 0

    # Sort by the step number embedded in the filename, e.g. "step_001200.pt" -> 1200
    latest_ckpt = max(pt_files, key=lambda f: int(os.path.basename(f).split('_')[1].split('.')[0]))
    print(f"Resuming training from latest checkpoint: {latest_ckpt}")

    # weights_only=False is required to unpickle the Python objects inside the
    # optimizer states; only load checkpoints from trusted sources.
    ckpt = torch.load(latest_ckpt, map_location=device, weights_only=False)
    raw_model.load_state_dict(ckpt['model_state_dict'])
    if 'muon_opt_state' in ckpt:
        muon_opt.load_state_dict(ckpt['muon_opt_state'])
    if 'adam_opt_state' in ckpt:
        adam_opt.load_state_dict(ckpt['adam_opt_state'])
    if 'mamba_core_opt_state' in ckpt:
        mamba_core_opt.load_state_dict(ckpt['mamba_core_opt_state'])
    if 'scaler_state' in ckpt:
        scaler.load_state_dict(ckpt['scaler_state'])

    # Backward compatible with older checkpoints that did not persist this field.
    total_tokens = int(ckpt.get('total_tokens', 0))
    return ckpt['step'] + 1, ckpt.get('wandb_run_id'), total_tokens


def main():
    global DEVICE, AMP_DTYPE
    args = _parse_runtime_args()
    AMP_DTYPE = _resolve_amp_dtype(args.amp_dtype)

    rank, local_rank, world_size, device = setup_distributed()
    DEVICE = device  # update module-level DEVICE for create_seq_idx_batch

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    if is_main_process():
        print(f"Initializing {MODE.upper()} BitMamba Model...")
        print(f"AMP dtype: {AMP_DTYPE}")
    model = BitMambaLLM(**MODEL_CONFIG).to(DEVICE)

    # Load upscaled checkpoint if in continued pretraining mode
    if MODE == "upscaled":
        print("Loading upscaled checkpoint for continued pre-training...")
        # Look for the upscaled checkpoint in its default location
        upscale_ckpt = "checkpoints/upscaled/step_000000_1B_mamba.pt"
        if os.path.exists(upscale_ckpt):
            ckpt = torch.load(upscale_ckpt, map_location=DEVICE)
            model.load_state_dict(ckpt['model_state_dict'])
            print(f"Loaded upscaled weights from {upscale_ckpt}")
        else:
            print(f"Warning: Upscaled checkpoint not found at {upscale_ckpt}")
            print("Please run `python upscale.py` first.")

    # IMPORTANT: torch.compile is compatible with gradient checkpointing ONLY when
    # use_reentrant=False (set in BitMambaLLM._backbone). Do NOT change checkpointing
    # to use_reentrant=True; it will silently corrupt gradients under compile.
    # Note: mode="default" (instead of "reduce-overhead") disables CUDA Graphs,
    # which conflict with gradient checkpointing + accumulation.
    model._backbone = torch.compile(model._backbone, mode="default")

    # Wrap in DDP after compile but before optimizer creation
    model = wrap_model_ddp(model, local_rank)
    raw_model = unwrap_model(model)  # for state_dict saves and output head access

    muon_opt, adam_opt, mamba_core_opt = setup_mamba_optimizers(raw_model, CURRICULUM_CONFIG)
    scheduler = FGWSD_Scheduler(muon_opt, adam_opt, mamba_core_opt, TOTAL_STEPS, CURRICULUM_CONFIG)

    # G10: GradScaler for FP16 (no-op when using BF16)
    scaler = torch.amp.GradScaler(enabled=(DEVICE.startswith("cuda") and AMP_DTYPE == torch.float16))

    start_step, wandb_run_id, start_total_tokens = 0, None, 0
    if MODE != "upscaled":
        start_step, wandb_run_id, start_total_tokens = load_latest_checkpoint(
            CHECKPOINT_DIR, raw_model, (muon_opt, adam_opt, mamba_core_opt), scaler, DEVICE
        )

    if is_main_process():
        wandb.init(
            project="Agentic-1.58b-Model",
            name=f"run-bitmamba-{MODE}",
            id=wandb_run_id,  # Resumes the same loss curve if the id exists
            resume="allow",   # Crucial for preventing fragmented runs
            config=CURRICULUM_CONFIG,
        )

    # Initialize the correct data phase and context for the starting step (read-only query)
    if start_step > 0:
        _, start_ctx, start_phase = scheduler.get_lr_and_ctx(start_step)
    else:
        start_ctx, start_phase = CURRICULUM_CONFIG['phases'][0]['ctx'], "Phase_1"

    train_loader, tokenizer = create_dataloaders(
        TRAIN_CONFIGS.get(start_phase, TRAIN_CONFIGS["Phase_1"]),
        tokenizer_path="custom_agentic_tokenizer",
        max_seq_len=start_ctx, batch_size=BATCH_SIZE
    )
    model.train()

    run_training_steps(
        model=model,
        raw_model=raw_model,
        optimizers=(muon_opt, adam_opt, mamba_core_opt),
        scheduler=scheduler,
        train_loader=train_loader,
        scaler=scaler,
        total_steps=TOTAL_STEPS,
        checkpoint_dir=CHECKPOINT_DIR,
        device=DEVICE,
        world_size=world_size,
        start_step=start_step,
        start_total_tokens=start_total_tokens,
    )

    if is_main_process():
        wandb.finish()
    cleanup_distributed()


if __name__ == "__main__":
    main()