Coverage for validate.py: 31%

214 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

# Copyright 2026 venim1103
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run validation on held-out pillars and log metrics to Weights & Biases.

Example:
    python validate.py --mode scout --manifest-path val_data/validation_manifest.json
    python validate.py --mode scout --checkpoint-path checkpoints/bitmamba_scout/step_010000.pt
"""

21 

22from __future__ import annotations 

23 

24import argparse 

25import glob 

26import json 

27import math 

28import os 

29import pickle 

30import re 

31from pathlib import Path 

32from typing import Dict, Iterable, List, Sequence, Tuple 

33 

34import torch 

35import wandb 

36from datasets import load_dataset 

37from transformers import AutoTokenizer 

38 

39from context_config import CONTEXT_LENGTH 

40from data import extract_text_from_row 

41from model import BitMambaLLM, maybe_autocast 

42 

43 

# Default compute device for CLI --device: prefer CUDA when available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Cap on the exponent passed to math.exp when converting mean loss to
# perplexity, so a degenerate loss cannot overflow to inf.
MAX_SAFE_EXP = 20.0

46 

47 

48def resolve_model_config(mode: str) -> Dict[str, int | bool | float]: 

49 mode = mode.lower() 

50 if mode == "scout": 

51 return dict(vocab_size=64000, dim=512, n_layers=24, d_state=64, expand=2, use_checkpoint=True) 

52 if mode == "parent": 

53 return dict(vocab_size=64000, dim=1024, n_layers=40, d_state=128, expand=2, use_checkpoint=True) 

54 if mode == "upscaled": 

55 return dict( 

56 vocab_size=64000, 

57 dim=1024, 

58 n_layers=64, 

59 d_state=128, 

60 expand=2, 

61 use_checkpoint=True, 

62 use_attn=True, 

63 attn_pct=0.08, 

64 ) 

65 raise ValueError(f"Unsupported --mode {mode!r}. Expected scout, parent, or upscaled.") 

66 

67 

68def parse_checkpoint_step(path: str | Path) -> int: 

69 name = Path(path).name 

70 match = re.search(r"step_(\d+)\.pt$", name) 

71 if match: 

72 return int(match.group(1)) 

73 return -1 

74 

75 

def resolve_default_checkpoint_dir(mode: str) -> str:
    """Return the conventional checkpoint directory for a training mode."""
    return str(Path("checkpoints", f"bitmamba_{mode.lower()}"))

78 

79 

def discover_checkpoints(
    checkpoint_path: str | None,
    checkpoint_dir: str,
    checkpoint_glob: str,
    max_checkpoints: int | None,
) -> List[str]:
    """Resolve which checkpoint files to evaluate.

    An explicit ``checkpoint_path`` takes precedence and must exist.
    Otherwise ``checkpoint_dir``/``checkpoint_glob`` is globbed, the matches
    are sorted by (step, path), and optionally truncated to the newest
    ``max_checkpoints`` entries.

    Raises:
        FileNotFoundError: If the explicit path is missing, or the glob
            matches no files.
    """
    if checkpoint_path:
        explicit = Path(checkpoint_path)
        if not explicit.exists():
            raise FileNotFoundError(f"Checkpoint path does not exist: {explicit}")
        return [str(explicit)]

    pattern = str(Path(checkpoint_dir) / checkpoint_glob)
    candidates = sorted(
        (p for p in glob.glob(pattern) if Path(p).is_file()),
        key=lambda p: (parse_checkpoint_step(p), p),
    )
    if not candidates:
        raise FileNotFoundError(
            f"No checkpoints found with glob {pattern!r}. "
            "Pass --checkpoint-path or adjust --checkpoint-dir/--checkpoint-glob."
        )

    if max_checkpoints is not None and max_checkpoints > 0:
        # Keep only the newest N (list is sorted oldest -> newest).
        candidates = candidates[-max_checkpoints:]
    return candidates

104 

105 

def _resolve_existing_path(path_str: str, manifest_dir: Path) -> Path:
    """Resolve a pillar parquet path, as-is first, then relative to the manifest.

    Raises:
        FileNotFoundError: If neither interpretation points at an existing file.
    """
    direct = Path(path_str)
    relative_to_manifest = (manifest_dir / path_str).resolve()
    for candidate in (direct, relative_to_manifest):
        if candidate.exists():
            return candidate

    raise FileNotFoundError(
        f"Could not resolve pillar parquet path {path_str!r}. "
        f"Checked {direct} and {relative_to_manifest}."
    )

119 

120 

def load_manifest_pillars(manifest_path: str | Path) -> List[Dict[str, str]]:
    """Load and validate the pillar list from a validation-manifest JSON file.

    Every pillar entry must carry 'name' and 'path'; each path is resolved
    either as given or relative to the manifest's directory.

    Raises:
        ValueError: On an empty pillar list or a malformed entry.
        FileNotFoundError: When a pillar path cannot be resolved.
    """
    manifest_file = Path(manifest_path)
    with manifest_file.open("r", encoding="utf-8") as handle:
        manifest = json.load(handle)

    pillars = manifest.get("pillars", [])
    if not pillars:
        raise ValueError("Manifest has no pillars. Expected a non-empty 'pillars' list.")

    manifest_dir = manifest_file.resolve().parent
    resolved: List[Dict[str, str]] = []
    for entry in pillars:
        name, path = entry.get("name"), entry.get("path")
        if not name or not path:
            raise ValueError("Each pillar in manifest must include 'name' and 'path'.")
        resolved.append({"name": str(name), "path": str(_resolve_existing_path(path, manifest_dir))})
    return resolved

140 

141 

def iter_token_windows(token_ids: Sequence[int], max_seq_len: int) -> Iterable[Tuple[List[int], List[int]]]:
    """Yield (input, target) windows shifted by one token for next-token loss.

    Windows hold up to ``max_seq_len`` input tokens and tile the sequence with
    a one-token overlap, so every position after the first appears as a
    target exactly once. Sequences shorter than two tokens yield nothing.

    Raises:
        ValueError: If ``max_seq_len`` is not positive.
    """
    if max_seq_len <= 0:
        raise ValueError("max_seq_len must be > 0")
    total = len(token_ids)
    if total < 2:
        return

    start = 0
    while start < total - 1:
        # One extra token so the shifted slices line up input->target.
        chunk = token_ids[start : start + max_seq_len + 1]
        if len(chunk) >= 2:
            yield chunk[:-1], chunk[1:]
        start += max_seq_len

154 

155 

def _run_loss_batch(model, input_batch, target_batch, device: str):
    """Pad one batch of token windows, run the model, return (loss_sum, n_tokens).

    Inputs are right-padded with 0 and targets with -100 (presumably the
    model's loss ignore-index — padded positions then contribute nothing).
    """
    batch_size = len(input_batch)
    max_len = max(len(seq) for seq in input_batch)

    input_ids = torch.zeros((batch_size, max_len), dtype=torch.long, device=device)
    targets = torch.full((batch_size, max_len), -100, dtype=torch.long, device=device)
    for row, (x_ids, y_ids) in enumerate(zip(input_batch, target_batch)):
        length = len(x_ids)
        input_ids[row, :length] = torch.tensor(x_ids, dtype=torch.long, device=device)
        targets[row, :length] = torch.tensor(y_ids, dtype=torch.long, device=device)

    with torch.no_grad(), maybe_autocast(device):
        loss_sum, valid_tokens = model(input_ids, targets=targets)

    return float(loss_sum.item()), int(valid_tokens.item())

172 

173 

def evaluate_pillar(
    model,
    tokenizer,
    parquet_path: str,
    *,
    device: str,
    max_seq_len: int,
    batch_size: int,
    max_examples: int | None,
):
    """Compute token-level loss and perplexity over one validation pillar.

    Each document is tokenized (with EOS appended when the tokenizer has
    one), split into next-token windows, batched, and scored. Losses are
    token-weighted, so the final mean is exact across variable-length
    batches.

    Args:
        model: Language model returning (loss_sum, valid_tokens).
        tokenizer: HuggingFace-style tokenizer callable.
        parquet_path: Path to the pillar's parquet file.
        device: Torch device string.
        max_seq_len: Window length for evaluation.
        batch_size: Number of windows per forward pass.
        max_examples: Optional cap on rows read, for smoke runs.

    Returns:
        Dict with 'loss', 'perplexity', 'tokens', 'documents', 'windows'.
    """
    dataset = load_dataset("parquet", data_files=parquet_path, split="train")

    loss_sum_total = 0.0
    token_count_total = 0
    docs_processed = 0
    windows_processed = 0

    pending_inputs: List[List[int]] = []
    pending_targets: List[List[int]] = []

    def _flush() -> None:
        # Run the queued windows through the model and fold the result into
        # the pillar-level accumulators. (Previously duplicated inline.)
        nonlocal loss_sum_total, token_count_total, windows_processed
        batch_loss_sum, batch_tokens = _run_loss_batch(model, pending_inputs, pending_targets, device)
        loss_sum_total += batch_loss_sum
        token_count_total += batch_tokens
        windows_processed += len(pending_inputs)
        pending_inputs.clear()
        pending_targets.clear()

    for row_idx, row in enumerate(dataset):
        if max_examples is not None and row_idx >= max_examples:
            break

        text = extract_text_from_row(row)
        if not text:
            continue

        token_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        if tokenizer.eos_token_id is not None:
            # Mark the document boundary explicitly.
            token_ids = token_ids + [int(tokenizer.eos_token_id)]

        has_windows = False
        for x_ids, y_ids in iter_token_windows(token_ids, max_seq_len=max_seq_len):
            has_windows = True
            pending_inputs.append(x_ids)
            pending_targets.append(y_ids)
            if len(pending_inputs) >= batch_size:
                _flush()

        if has_windows:
            docs_processed += 1

    # Score any windows left over after the last document.
    if pending_inputs:
        _flush()

    # Guard the division against zero tokens; cap the exponent so the
    # perplexity can never overflow to inf.
    mean_loss = loss_sum_total / max(token_count_total, 1)
    perplexity = float(math.exp(min(mean_loss, MAX_SAFE_EXP)))

    return {
        "loss": float(mean_loss),
        "perplexity": perplexity,
        "tokens": int(token_count_total),
        "documents": int(docs_processed),
        "windows": int(windows_processed),
    }

239 

240 

def evaluate_checkpoint(model, tokenizer, checkpoint_path: str, pillars, args):
    """Load one checkpoint into ``model`` and evaluate every validation pillar.

    Per-pillar and token-weighted overall metrics are printed and logged to
    WandB at the checkpoint's training step.

    Returns:
        True on success; False (after printing a warning) when the checkpoint
        is unreadable, lacks a state dict, or its weights fail to load.
    """
    try:
        # NOTE(review): weights_only=False performs full unpickling —
        # assumes checkpoints come from our own trusted training runs.
        ckpt = torch.load(checkpoint_path, map_location=args.device, weights_only=False)
    except (RuntimeError, EOFError, OSError, ValueError, pickle.UnpicklingError) as e:
        print(f"⚠️ Skipping corrupted checkpoint {checkpoint_path}: {e}")
        return False

    if "model_state_dict" not in ckpt:
        print(f"⚠️ Skipping checkpoint {checkpoint_path}: missing 'model_state_dict'")
        return False

    try:
        model.load_state_dict(ckpt["model_state_dict"])
        model.prepare_for_inference()
    except Exception as e:
        print(f"⚠️ Skipping checkpoint {checkpoint_path}: failed to load state dict: {e}")
        return False

    # Fall back to the step encoded in the filename when metadata is absent.
    checkpoint_step = int(ckpt.get("step", parse_checkpoint_step(checkpoint_path)))
    checkpoint_tokens = int(ckpt.get("total_tokens", 0))

    weighted_loss_sum = 0.0
    total_tokens = 0

    print(f"Evaluating checkpoint: {checkpoint_path} (step={checkpoint_step})")
    for pillar in pillars:
        name = pillar["name"]
        metrics = evaluate_pillar(
            model,
            tokenizer,
            pillar["path"],
            device=args.device,
            max_seq_len=args.max_seq_len,
            batch_size=args.batch_size,
            max_examples=args.max_examples_per_pillar,
        )

        # Token-weight each pillar's mean loss so the overall mean is exact.
        weighted_loss_sum += metrics["loss"] * metrics["tokens"]
        total_tokens += metrics["tokens"]

        print(
            f" {name:<10} loss={metrics['loss']:.4f} "
            f"ppl={metrics['perplexity']:.2f} tokens={metrics['tokens']}"
        )

        prefix = f"Validation/{name}"
        pillar_payload = {
            f"{prefix}/Loss": metrics["loss"],
            f"{prefix}/Perplexity": metrics["perplexity"],
            f"{prefix}/Tokens": metrics["tokens"],
            f"{prefix}/Documents": metrics["documents"],
            f"{prefix}/Windows": metrics["windows"],
        }
        # commit=False: accumulate pillar metrics; the overall log commits.
        wandb.log(pillar_payload, step=checkpoint_step, commit=False)

    overall_loss = weighted_loss_sum / max(total_tokens, 1)
    overall_ppl = float(math.exp(min(overall_loss, MAX_SAFE_EXP)))

    print(f" overall loss={overall_loss:.4f} ppl={overall_ppl:.2f} tokens={total_tokens}")
    wandb.log(
        {
            "Validation/System/CheckpointStep": checkpoint_step,
            "Validation/System/TotalTokens": checkpoint_tokens,
            "Validation/Overall/Loss": overall_loss,
            "Validation/Overall/Perplexity": overall_ppl,
            "Validation/Overall/Tokens": total_tokens,
        },
        step=checkpoint_step,
    )
    return True

314 

315 

def parse_args():
    """Build and parse the command-line interface for the validation script."""
    cli = argparse.ArgumentParser(description="Evaluate checkpoints on validation pillars and log to WandB.")

    # Model / data selection.
    cli.add_argument("--mode", choices=["scout", "parent", "upscaled"], default="scout")
    cli.add_argument("--manifest-path", default="val_data/validation_manifest.json")
    cli.add_argument("--tokenizer-path", default="custom_agentic_tokenizer")

    # Checkpoint discovery.
    cli.add_argument("--checkpoint-path", default=None, help="Evaluate a single checkpoint file.")
    cli.add_argument(
        "--checkpoint-dir",
        default=None,
        help="Directory containing step_*.pt checkpoints. Used when --checkpoint-path is not set.",
    )
    cli.add_argument(
        "--checkpoint-glob",
        default="step_*.pt",
        help="Glob pattern inside --checkpoint-dir for checkpoint discovery.",
    )
    cli.add_argument(
        "--max-checkpoints",
        type=int,
        default=None,
        help="If set, evaluate only the latest N discovered checkpoints.",
    )

    # Evaluation sizing.
    cli.add_argument("--batch-size", type=int, default=4)
    cli.add_argument("--max-seq-len", type=int, default=CONTEXT_LENGTH)
    cli.add_argument(
        "--max-examples-per-pillar",
        type=int,
        default=None,
        help="Optional cap for faster smoke runs.",
    )

    # Weights & Biases configuration.
    cli.add_argument("--wandb-project", default="Agentic-1.58b-Validation")
    cli.add_argument("--wandb-name", default=None)
    cli.add_argument("--wandb-run-id", default=None)
    cli.add_argument("--wandb-resume", default="allow", choices=["allow", "must", "never", "auto"])
    cli.add_argument(
        "--wandb-mode",
        default=None,
        choices=["online", "offline", "disabled"],
        help="Optional WANDB_MODE override. If omitted, keeps current environment setting.",
    )
    cli.add_argument("--device", default=DEVICE)

    return cli.parse_args()

361 

362 

def main() -> int:
    """Entry point: discover checkpoints, build the model, evaluate, and log.

    Returns:
        Process exit code (0 on completion).
    """
    args = parse_args()
    if args.wandb_mode is not None:
        os.environ["WANDB_MODE"] = args.wandb_mode

    model_config = resolve_model_config(args.mode)
    checkpoint_dir = args.checkpoint_dir or resolve_default_checkpoint_dir(args.mode)
    checkpoints = discover_checkpoints(
        args.checkpoint_path,
        checkpoint_dir,
        args.checkpoint_glob,
        args.max_checkpoints,
    )
    pillars = load_manifest_pillars(args.manifest_path)

    print(f"Loading tokenizer from {args.tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
    # Effectively disable the tokenizer's length limit; windowing is done here.
    tokenizer.model_max_length = int(1e9)

    print(f"Building model for mode={args.mode} on device={args.device}")
    model = BitMambaLLM(**model_config).to(args.device)
    model.eval()

    run_name = args.wandb_name or f"validation-{args.mode}"
    wandb.init(
        project=args.wandb_project,
        name=run_name,
        id=args.wandb_run_id,
        resume=args.wandb_resume,
        settings=wandb.Settings(start_method="thread"),
        config={
            "mode": args.mode,
            "manifest_path": args.manifest_path,
            "tokenizer_path": args.tokenizer_path,
            "checkpoint_dir": checkpoint_dir,
            "checkpoint_glob": args.checkpoint_glob,
            "checkpoint_count": len(checkpoints),
            "batch_size": args.batch_size,
            "max_seq_len": args.max_seq_len,
            "max_examples_per_pillar": args.max_examples_per_pillar,
        },
    )

    try:
        for checkpoint_path in checkpoints:
            # evaluate_checkpoint prints a warning and returns False for
            # unreadable checkpoints; nothing else to do here, so the
            # original dead `if not success: continue` is removed.
            evaluate_checkpoint(model, tokenizer, checkpoint_path, pillars, args)
    finally:
        # Always close the WandB run, even if evaluation raises.
        wandb.finish()

    print(f"Validation complete for {len(checkpoints)} checkpoint(s).")
    return 0

416 

417 

if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status.
    raise SystemExit(main())