Coverage for validate.py: 31%
214 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 23:06 +0000
# Copyright 2026 venim1103
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run validation on held-out pillars and log metrics to Weights & Biases.

Example:
    python validate.py --mode scout --manifest-path val_data/validation_manifest.json
    python validate.py --mode scout --checkpoint-path checkpoints/bitmamba_scout/step_010000.pt
"""
22from __future__ import annotations
24import argparse
25import glob
26import json
27import math
28import os
29import pickle
30import re
31from pathlib import Path
32from typing import Dict, Iterable, List, Sequence, Tuple
34import torch
35import wandb
36from datasets import load_dataset
37from transformers import AutoTokenizer
39from context_config import CONTEXT_LENGTH
40from data import extract_text_from_row
41from model import BitMambaLLM, maybe_autocast
# Prefer CUDA when available; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Mean loss is clamped to this value before math.exp() so reported
# perplexity can never overflow to inf on a diverged checkpoint.
MAX_SAFE_EXP = 20.0
48def resolve_model_config(mode: str) -> Dict[str, int | bool | float]:
49 mode = mode.lower()
50 if mode == "scout":
51 return dict(vocab_size=64000, dim=512, n_layers=24, d_state=64, expand=2, use_checkpoint=True)
52 if mode == "parent":
53 return dict(vocab_size=64000, dim=1024, n_layers=40, d_state=128, expand=2, use_checkpoint=True)
54 if mode == "upscaled":
55 return dict(
56 vocab_size=64000,
57 dim=1024,
58 n_layers=64,
59 d_state=128,
60 expand=2,
61 use_checkpoint=True,
62 use_attn=True,
63 attn_pct=0.08,
64 )
65 raise ValueError(f"Unsupported --mode {mode!r}. Expected scout, parent, or upscaled.")
68def parse_checkpoint_step(path: str | Path) -> int:
69 name = Path(path).name
70 match = re.search(r"step_(\d+)\.pt$", name)
71 if match:
72 return int(match.group(1))
73 return -1
def resolve_default_checkpoint_dir(mode: str) -> str:
    """Return the conventional checkpoint directory for *mode* (e.g. checkpoints/bitmamba_scout)."""
    return str(Path("checkpoints", f"bitmamba_{mode.lower()}"))
def discover_checkpoints(
    checkpoint_path: str | None,
    checkpoint_dir: str,
    checkpoint_glob: str,
    max_checkpoints: int | None,
) -> List[str]:
    """Resolve the list of checkpoint files to evaluate.

    An explicit *checkpoint_path* takes precedence. Otherwise *checkpoint_glob*
    is expanded inside *checkpoint_dir*, results are sorted by (step, path),
    and only the newest *max_checkpoints* entries are kept when that limit
    is a positive integer.

    Raises:
        FileNotFoundError: If the explicit path is missing, or the glob
            matches no regular files.
    """
    if checkpoint_path:
        explicit = Path(checkpoint_path)
        if not explicit.exists():
            raise FileNotFoundError(f"Checkpoint path does not exist: {explicit}")
        return [str(explicit)]

    pattern = str(Path(checkpoint_dir) / checkpoint_glob)
    candidates = sorted(
        (p for p in glob.glob(pattern) if Path(p).is_file()),
        key=lambda p: (parse_checkpoint_step(p), p),
    )
    if not candidates:
        raise FileNotFoundError(
            f"No checkpoints found with glob {pattern!r}. "
            "Pass --checkpoint-path or adjust --checkpoint-dir/--checkpoint-glob."
        )

    if max_checkpoints is not None and max_checkpoints > 0:
        candidates = candidates[-max_checkpoints:]
    return candidates
106def _resolve_existing_path(path_str: str, manifest_dir: Path) -> Path:
107 candidate = Path(path_str)
108 if candidate.exists(): 108 ↛ 109line 108 didn't jump to line 109 because the condition on line 108 was never true
109 return candidate
111 by_manifest = (manifest_dir / path_str).resolve()
112 if by_manifest.exists():
113 return by_manifest
115 raise FileNotFoundError(
116 f"Could not resolve pillar parquet path {path_str!r}. "
117 f"Checked {candidate} and {by_manifest}."
118 )
def load_manifest_pillars(manifest_path: str | Path) -> List[Dict[str, str]]:
    """Load and validate the pillar entries from a validation manifest JSON.

    Each pillar entry must carry a 'name' and a 'path'; paths that do not
    exist as given are re-resolved relative to the manifest's own directory.

    Raises:
        ValueError: On an empty pillar list or a malformed entry.
        FileNotFoundError: When a pillar path cannot be resolved.
    """
    manifest_file = Path(manifest_path)
    manifest = json.loads(manifest_file.read_text(encoding="utf-8"))

    entries = manifest.get("pillars", [])
    if not entries:
        raise ValueError("Manifest has no pillars. Expected a non-empty 'pillars' list.")

    base_dir = manifest_file.resolve().parent
    pillars: List[Dict[str, str]] = []
    for entry in entries:
        name = entry.get("name")
        path = entry.get("path")
        if not name or not path:
            raise ValueError("Each pillar in manifest must include 'name' and 'path'.")
        pillars.append({"name": str(name), "path": str(_resolve_existing_path(path, base_dir))})
    return pillars
def iter_token_windows(token_ids: Sequence[int], max_seq_len: int) -> Iterable[Tuple[List[int], List[int]]]:
    """Yield (input, target) windows for next-token prediction.

    Each window is up to *max_seq_len* tokens long and its target is the same
    slice shifted right by one. Sequences shorter than two tokens yield nothing.

    Raises:
        ValueError: If *max_seq_len* is not positive.
    """
    if max_seq_len <= 0:
        raise ValueError("max_seq_len must be > 0")
    total = len(token_ids)
    if total < 2:
        return

    # Step by max_seq_len so consecutive windows overlap by exactly one token
    # (the last input of one window is the first input of the next).
    for start in range(0, total - 1, max_seq_len):
        chunk = token_ids[start : start + max_seq_len + 1]
        if len(chunk) >= 2:
            yield chunk[:-1], chunk[1:]
def _run_loss_batch(model, input_batch, target_batch, device: str):
    """Run one no-grad forward pass; return (summed loss, valid-token count).

    Sequences are right-padded to the batch's longest length: inputs with 0
    and targets with -100 (the conventional ignore index — assumes the
    model's loss honors it; confirm against BitMambaLLM.forward).
    """
    rows = len(input_batch)
    width = max(len(seq) for seq in input_batch)

    input_ids = torch.zeros((rows, width), dtype=torch.long, device=device)
    targets = torch.full((rows, width), -100, dtype=torch.long, device=device)
    for row, (x_ids, y_ids) in enumerate(zip(input_batch, target_batch)):
        length = len(x_ids)
        input_ids[row, :length] = torch.tensor(x_ids, dtype=torch.long, device=device)
        targets[row, :length] = torch.tensor(y_ids, dtype=torch.long, device=device)

    with torch.no_grad(), maybe_autocast(device):
        loss_sum, valid_tokens = model(input_ids, targets=targets)

    return float(loss_sum.item()), int(valid_tokens.item())
def evaluate_pillar(
    model,
    tokenizer,
    parquet_path: str,
    *,
    device: str,
    max_seq_len: int,
    batch_size: int,
    max_examples: int | None,
):
    """Compute token-weighted loss and perplexity over one validation pillar.

    Each document is tokenized (EOS appended when the tokenizer defines one),
    split into next-token windows, and batched; batches may span document
    boundaries and any partial batch is flushed once at the end.

    Returns:
        Dict with keys: loss, perplexity, tokens, documents, windows.
    """
    dataset = load_dataset("parquet", data_files=parquet_path, split="train")

    total_loss = 0.0
    total_tokens = 0
    doc_count = 0
    window_count = 0

    batch_inputs: List[List[int]] = []
    batch_targets: List[List[int]] = []

    def _flush() -> None:
        # Run the pending windows through the model and drain the buffers.
        nonlocal total_loss, total_tokens, window_count
        loss_sum, n_tokens = _run_loss_batch(model, batch_inputs, batch_targets, device)
        total_loss += loss_sum
        total_tokens += n_tokens
        window_count += len(batch_inputs)
        batch_inputs.clear()
        batch_targets.clear()

    for row_idx, row in enumerate(dataset):
        if max_examples is not None and row_idx >= max_examples:
            break

        text = extract_text_from_row(row)
        if not text:
            continue

        token_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        if tokenizer.eos_token_id is not None:
            token_ids = token_ids + [int(tokenizer.eos_token_id)]

        produced = False
        for x_ids, y_ids in iter_token_windows(token_ids, max_seq_len=max_seq_len):
            produced = True
            batch_inputs.append(x_ids)
            batch_targets.append(y_ids)
            if len(batch_inputs) >= batch_size:
                _flush()

        # Only count documents that contributed at least one window.
        if produced:
            doc_count += 1

    if batch_inputs:
        _flush()

    # max(..., 1) guards against division by zero on an empty/degenerate pillar.
    mean_loss = total_loss / max(total_tokens, 1)
    perplexity = float(math.exp(min(mean_loss, MAX_SAFE_EXP)))

    return {
        "loss": float(mean_loss),
        "perplexity": perplexity,
        "tokens": int(total_tokens),
        "documents": int(doc_count),
        "windows": int(window_count),
    }
def evaluate_checkpoint(model, tokenizer, checkpoint_path: str, pillars, args):
    """Load one checkpoint, evaluate every pillar, and log metrics to WandB.

    Returns:
        True on success; False when the checkpoint is unreadable, missing its
        state dict, or incompatible (the checkpoint is skipped, not fatal).
    """
    # NOTE(review): weights_only=False runs pickle on the checkpoint — only
    # load checkpoints from trusted sources.
    try:
        ckpt = torch.load(checkpoint_path, map_location=args.device, weights_only=False)
    except (RuntimeError, EOFError, OSError, ValueError, pickle.UnpicklingError) as e:
        print(f"⚠️ Skipping corrupted checkpoint {checkpoint_path}: {e}")
        return False

    if "model_state_dict" not in ckpt:
        print(f"⚠️ Skipping checkpoint {checkpoint_path}: missing 'model_state_dict'")
        return False

    try:
        model.load_state_dict(ckpt["model_state_dict"])
        model.prepare_for_inference()
    except Exception as e:
        print(f"⚠️ Skipping checkpoint {checkpoint_path}: failed to load state dict: {e}")
        return False

    # Fall back to the filename-encoded step when the checkpoint lacks one.
    step = int(ckpt.get("step", parse_checkpoint_step(checkpoint_path)))
    tokens_seen = int(ckpt.get("total_tokens", 0))

    weighted_loss = 0.0
    token_total = 0

    print(f"Evaluating checkpoint: {checkpoint_path} (step={step})")
    for pillar in pillars:
        name = pillar["name"]
        metrics = evaluate_pillar(
            model,
            tokenizer,
            pillar["path"],
            device=args.device,
            max_seq_len=args.max_seq_len,
            batch_size=args.batch_size,
            max_examples=args.max_examples_per_pillar,
        )

        # Weight each pillar's mean loss by its token count so the overall
        # average is per-token, not per-pillar.
        weighted_loss += metrics["loss"] * metrics["tokens"]
        token_total += metrics["tokens"]

        print(
            f" {name:<10} loss={metrics['loss']:.4f} "
            f"ppl={metrics['perplexity']:.2f} tokens={metrics['tokens']}"
        )

        prefix = f"Validation/{name}"
        wandb.log(
            {
                f"{prefix}/Loss": metrics["loss"],
                f"{prefix}/Perplexity": metrics["perplexity"],
                f"{prefix}/Tokens": metrics["tokens"],
                f"{prefix}/Documents": metrics["documents"],
                f"{prefix}/Windows": metrics["windows"],
            },
            step=step,
            commit=False,  # defer the commit until the overall metrics below
        )

    overall_loss = weighted_loss / max(token_total, 1)
    overall_ppl = float(math.exp(min(overall_loss, MAX_SAFE_EXP)))

    print(f" overall loss={overall_loss:.4f} ppl={overall_ppl:.2f} tokens={token_total}")
    wandb.log(
        {
            "Validation/System/CheckpointStep": step,
            "Validation/System/TotalTokens": tokens_seen,
            "Validation/Overall/Loss": overall_loss,
            "Validation/Overall/Perplexity": overall_ppl,
            "Validation/Overall/Tokens": token_total,
        },
        step=step,
    )
    return True
def parse_args():
    """Build and parse the CLI arguments for a validation run."""
    ap = argparse.ArgumentParser(description="Evaluate checkpoints on validation pillars and log to WandB.")

    # Model / data selection.
    ap.add_argument("--mode", choices=["scout", "parent", "upscaled"], default="scout")
    ap.add_argument("--manifest-path", default="val_data/validation_manifest.json")
    ap.add_argument("--tokenizer-path", default="custom_agentic_tokenizer")

    # Checkpoint discovery.
    ap.add_argument("--checkpoint-path", default=None, help="Evaluate a single checkpoint file.")
    ap.add_argument(
        "--checkpoint-dir",
        default=None,
        help="Directory containing step_*.pt checkpoints. Used when --checkpoint-path is not set.",
    )
    ap.add_argument(
        "--checkpoint-glob",
        default="step_*.pt",
        help="Glob pattern inside --checkpoint-dir for checkpoint discovery.",
    )
    ap.add_argument(
        "--max-checkpoints",
        type=int,
        default=None,
        help="If set, evaluate only the latest N discovered checkpoints.",
    )

    # Evaluation sizing.
    ap.add_argument("--batch-size", type=int, default=4)
    ap.add_argument("--max-seq-len", type=int, default=CONTEXT_LENGTH)
    ap.add_argument(
        "--max-examples-per-pillar",
        type=int,
        default=None,
        help="Optional cap for faster smoke runs.",
    )

    # Weights & Biases configuration.
    ap.add_argument("--wandb-project", default="Agentic-1.58b-Validation")
    ap.add_argument("--wandb-name", default=None)
    ap.add_argument("--wandb-run-id", default=None)
    ap.add_argument("--wandb-resume", default="allow", choices=["allow", "must", "never", "auto"])
    ap.add_argument(
        "--wandb-mode",
        default=None,
        choices=["online", "offline", "disabled"],
        help="Optional WANDB_MODE override. If omitted, keeps current environment setting.",
    )

    ap.add_argument("--device", default=DEVICE)
    return ap.parse_args()
def main() -> int:
    """Entry point: build the model, evaluate each checkpoint, log to WandB.

    Unreadable or incompatible checkpoints are skipped rather than aborting
    the run; the final summary reports how many were actually evaluated.

    Returns:
        0 on completion (process exit code).
    """
    args = parse_args()
    if args.wandb_mode is not None:
        os.environ["WANDB_MODE"] = args.wandb_mode

    model_config = resolve_model_config(args.mode)
    checkpoint_dir = args.checkpoint_dir or resolve_default_checkpoint_dir(args.mode)
    checkpoints = discover_checkpoints(
        args.checkpoint_path,
        checkpoint_dir,
        args.checkpoint_glob,
        args.max_checkpoints,
    )
    pillars = load_manifest_pillars(args.manifest_path)

    print(f"Loading tokenizer from {args.tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
    # Effectively disable the tokenizer's max-length limit; windows are cut
    # manually in evaluate_pillar.
    tokenizer.model_max_length = int(1e9)

    print(f"Building model for mode={args.mode} on device={args.device}")
    model = BitMambaLLM(**model_config).to(args.device)
    model.eval()

    run_name = args.wandb_name or f"validation-{args.mode}"
    wandb.init(
        project=args.wandb_project,
        name=run_name,
        id=args.wandb_run_id,
        resume=args.wandb_resume,
        settings=wandb.Settings(start_method="thread"),
        config={
            "mode": args.mode,
            "manifest_path": args.manifest_path,
            "tokenizer_path": args.tokenizer_path,
            "checkpoint_dir": checkpoint_dir,
            "checkpoint_glob": args.checkpoint_glob,
            "checkpoint_count": len(checkpoints),
            "batch_size": args.batch_size,
            "max_seq_len": args.max_seq_len,
            "max_examples_per_pillar": args.max_examples_per_pillar,
        },
    )

    # Fix: the previous loop ignored evaluate_checkpoint's return value (a
    # trailing no-op `continue`), so the summary counted skipped checkpoints
    # as evaluated. Track successes explicitly.
    evaluated = 0
    try:
        for checkpoint_path in checkpoints:
            if evaluate_checkpoint(model, tokenizer, checkpoint_path, pillars, args):
                evaluated += 1
    finally:
        # Always flush the WandB run, even if evaluation raises.
        wandb.finish()

    print(f"Validation complete for {evaluated} of {len(checkpoints)} checkpoint(s).")
    return 0
if __name__ == "__main__":
    # SystemExit propagates main()'s integer return value as the exit status.
    raise SystemExit(main())