Coverage for train_tokenizer_spm.py: 81%

269 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

1# Copyright 2026 venim1103 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import argparse 

16import copy 

17import json 

18import os 

19import tempfile 

20from collections import defaultdict 

21from pathlib import Path 

22 

23import sentencepiece as spm 

24 

25import train_tokenizer as base 

26from context_config import CONTEXT_LENGTH 

27 

28 

# Sampling domains recognized by the quota/weight CLI options. Tuple order is
# also the stable tie-break order when fractional quota remainders are handed out.
DOMAINS = ("logic", "code", "tools", "web", "other")

30 

31 

32def _resolve_profile(explicit_profile, profile_registry): 

33 if explicit_profile: 

34 if explicit_profile in profile_registry: 

35 return explicit_profile 

36 available = ", ".join(sorted(profile_registry)) 

37 raise ValueError(f"Unknown profile {explicit_profile!r}. Available profiles: {available}") 

38 if os.getenv("KAGGLE_KERNEL_RUN_TYPE"): 

39 if "kaggle" in profile_registry: 39 ↛ 41line 39 didn't jump to line 41 because the condition on line 39 was always true

40 return "kaggle" 

41 return next(iter(sorted(profile_registry))) 

42 if "standard" in profile_registry: 42 ↛ 44line 42 didn't jump to line 44 because the condition on line 42 was always true

43 return "standard" 

44 return next(iter(sorted(profile_registry))) 

45 

46 

def _parse_domain_map(text, name):
    """Parse a comma-separated "domain=value" string into a {domain: float} dict.

    Empty entries are skipped. Raises ValueError for entries without "=" and
    for domain names outside DOMAINS; *name* identifies the CLI flag in errors.
    """
    parsed = {}
    for entry in (chunk.strip() for chunk in text.split(",")):
        if not entry:
            continue
        key, sep, raw_value = entry.partition("=")
        if not sep:
            raise ValueError(f"{name} must use key=value entries. Got {entry!r}.")
        domain = key.strip().lower()
        if domain not in DOMAINS:
            raise ValueError(f"Unsupported domain {domain!r} in {name}. Expected one of: {', '.join(DOMAINS)}")
        parsed[domain] = float(raw_value.strip())
    return parsed

61 

62 

def _normalize_domain_weights(raw_weights):
    """Validate per-domain weights and return them normalized to sum to 1.0.

    Every domain in DOMAINS must be present with a non-negative weight, and the
    combined weights must be positive. Note that the normalizing denominator
    sums every entry in *raw_weights*, including any keys outside DOMAINS.
    """
    for domain in DOMAINS:
        if domain not in raw_weights:
            raise ValueError(f"Missing domain weight for {domain!r}.")
        if raw_weights[domain] < 0:
            raise ValueError(f"Domain weight for {domain!r} must be non-negative.")
    denominator = sum(raw_weights.values())
    if denominator <= 0:
        raise ValueError("Domain weights must sum to a positive value.")
    return {domain: raw_weights[domain] / denominator for domain in DOMAINS}

73 

74 

def _calculate_domain_quotas(total_lines, normalized_weights):
    """Apportion *total_lines* across DOMAINS via largest-remainder rounding.

    Each domain first receives the floor of its exact share; any leftover lines
    go to the domains with the largest fractional parts (ties broken by DOMAINS
    order, since Python's sort is stable).
    """
    if total_lines <= 0:
        raise ValueError("quota_total_lines must be > 0.")
    exact = {domain: total_lines * normalized_weights[domain] for domain in DOMAINS}
    quotas = {domain: int(share) for domain, share in exact.items()}
    leftover = total_lines - sum(quotas.values())
    if leftover > 0:
        by_fraction = sorted(DOMAINS, key=lambda d: exact[d] - quotas[d], reverse=True)
        for i in range(leftover):
            quotas[by_fraction[i % len(by_fraction)]] += 1
    return quotas

90 

91 

def _load_profile_registry(profile_file):
    """Return the built-in profiles, optionally overlaid with a custom JSON file.

    The file may contain either {name: {...}} directly or a wrapper object
    {"profiles": {name: {...}}}. Custom profiles replace built-ins of the same
    name. The built-in defaults are deep-copied so callers can mutate freely.
    """
    registry = copy.deepcopy(SPM_PROFILE_DEFAULTS)
    if not profile_file:
        return registry

    with open(profile_file, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    custom = payload.get("profiles", payload)
    if not isinstance(custom, dict):
        raise ValueError("Profile file must contain a JSON object (or {'profiles': {...}}).")

    for profile_name, profile_settings in custom.items():
        if not isinstance(profile_settings, dict):
            raise ValueError(f"Profile {profile_name!r} must map to an object.")
        registry[profile_name] = profile_settings

    return registry

110 

111 

def _resolve_profile_settings(args, profile_name, profile_registry):
    """Merge CLI overrides with profile defaults into the effective settings.

    Returns a dict with the effective input_sentence_size, max_sentence_length,
    quota_total_lines, normalized domain_weights, and derived domain_quotas.
    """
    defaults = profile_registry[profile_name]

    def pick(cli_value, fallback):
        # Explicit CLI flags always win over profile defaults.
        return fallback if cli_value is None else cli_value

    input_sentence_size = pick(args.input_sentence_size, defaults["input_sentence_size"])
    max_sentence_length = pick(args.max_sentence_length, defaults["max_sentence_length"])
    legacy_quota = defaults.get("domain_quota", {})
    quota_total_lines = pick(
        args.quota_total_lines,
        defaults.get("quota_total_lines", sum(legacy_quota.values())),
    )

    if "domain_weights" in defaults:
        domain_weights = dict(defaults["domain_weights"])
    else:
        # Older profiles expressed absolute per-domain quotas; convert to weights.
        legacy_total = sum(legacy_quota.get(domain, 0) for domain in DOMAINS)
        if legacy_total <= 0:
            raise ValueError(f"Profile {profile_name!r} does not define usable domain quotas/weights.")
        domain_weights = {domain: legacy_quota.get(domain, 0) / legacy_total for domain in DOMAINS}

    if args.quota_weights:
        domain_weights.update(_parse_domain_map(args.quota_weights, "--quota-weights"))
    normalized_weights = _normalize_domain_weights(domain_weights)
    quotas = _calculate_domain_quotas(quota_total_lines, normalized_weights)

    if args.quota_overrides:
        # Absolute overrides are applied after weight-derived quotas, clamped at 0.
        overrides = _parse_domain_map(args.quota_overrides, "--quota-overrides")
        for domain, value in overrides.items():
            quotas[domain] = max(0, int(value))

    return {
        "input_sentence_size": input_sentence_size,
        "max_sentence_length": max_sentence_length,
        "quota_total_lines": quota_total_lines,
        "domain_weights": normalized_weights,
        "domain_quotas": quotas,
    }

148 

149 

150def _auto_tune_input_sentence_size(profile, input_sentence_size): 

151 """Scale SPM sampling upward on higher-RAM machines when using kaggle profile. 

152 

153 This keeps the kaggle defaults safe for low-RAM environments while making 

154 better use of local machines that intentionally run with kaggle settings. 

155 """ 

156 if profile != "kaggle": 

157 return input_sentence_size 

158 

159 raw_ram = os.getenv("TOKENIZER_MAX_RAM_GB") 

160 if not raw_ram: 

161 return input_sentence_size 

162 

163 try: 

164 ram_gb = float(raw_ram) 

165 except ValueError: 

166 return input_sentence_size 

167 

168 base = SPM_PROFILE_DEFAULTS["kaggle"]["input_sentence_size"] 

169 if ram_gb <= 13: 169 ↛ 170line 169 didn't jump to line 170 because the condition on line 169 was never true

170 return input_sentence_size 

171 

172 # Linear scale from 13 GB -> 30 GB, capped for OOM safety. 

173 # With current settings, 30 GB machines were underutilized (~11 GB peak), 

174 # so this default cap is intentionally more aggressive. 

175 target_cap = 2_800_000 

176 raw_cap = os.getenv("TOKENIZER_SPM_AUTO_MAX_INPUT_SENTENCE_SIZE") 

177 if raw_cap: 

178 try: 

179 target_cap = max(base, int(raw_cap)) 

180 except ValueError: 

181 pass 

182 

183 scaled = int(base + (ram_gb - 13.0) * (target_cap - base) / (30.0 - 13.0)) 

184 tuned = max(base, min(target_cap, scaled)) 

185 

186 if tuned != input_sentence_size: 186 ↛ 191line 186 didn't jump to line 191 because the condition on line 186 was always true

187 print( 

188 "Auto-tuning input_sentence_size for kaggle profile based on " 

189 f"TOKENIZER_MAX_RAM_GB={ram_gb:g}: {input_sentence_size:,} -> {tuned:,}" 

190 ) 

191 return tuned 

192 

193 

# Built-in training profiles. "kaggle" uses a smaller SPM sampling size and a
# smaller total corpus to fit low-RAM kernels; both profiles share the same
# normalized domain weighting.
SPM_PROFILE_DEFAULTS = {
    "standard": {
        "input_sentence_size": 1_500_000,
        "max_sentence_length": 2048,
        "quota_total_lines": 2_500_000,
        "domain_weights": {
            "logic": 0.22,
            "code": 0.30,
            "tools": 0.16,
            "web": 0.28,
            "other": 0.04,
        },
    },
    "kaggle": {
        "input_sentence_size": 750_000,
        "max_sentence_length": 2048,
        "quota_total_lines": 2_100_000,
        "domain_weights": {
            "logic": 0.22,
            "code": 0.30,
            "tools": 0.16,
            "web": 0.28,
            "other": 0.04,
        },
    },
}

220 

221 

def parse_args():
    """Parse the CLI for low-RAM SentencePiece tokenizer training.

    Most size/quota options default to None so the selected profile's defaults
    apply when the flag is omitted (see _resolve_profile_settings).
    """
    parser = argparse.ArgumentParser(description="Train a low-RAM SentencePiece tokenizer.")
    # Profile selection: built-in defaults plus optional custom JSON profiles.
    parser.add_argument("--profile", help="Training profile name (built-in or from --profile-file).")
    parser.add_argument(
        "--profile-file",
        help="Optional JSON file with custom profiles. Accepts either {name: {...}} or {'profiles': {name: {...}}}.",
    )
    # Core SentencePiece trainer knobs.
    parser.add_argument("--vocab-size", type=int, default=64_000)
    parser.add_argument("--model-type", choices=["bpe", "unigram"], default="unigram")
    parser.add_argument("--character-coverage", type=float, default=0.9995)
    parser.add_argument(
        "--byte-fallback",
        action=argparse.BooleanOptionalAction,
        default=True,
        help=(
            "Enable SentencePiece byte fallback. Recommended for code so unseen bytes "
            "are represented as byte pieces instead of <|unk|>."
        ),
    )
    parser.add_argument("--output-dir", default="custom_agentic_tokenizer_spm")
    parser.add_argument(
        "--input-sentence-size",
        type=int,
        help="Max sentence count for SPM trainer sampling. Uses profile default if omitted.",
    )
    parser.add_argument(
        "--max-sentence-length",
        type=int,
        help="Max sentence length passed to SPM trainer. Uses profile default if omitted.",
    )
    # Corpus sampling quotas: total lines, relative weights, and absolute overrides.
    parser.add_argument(
        "--quota-total-lines",
        type=int,
        help="Total sampled corpus lines before SPM training. Per-domain quotas are derived from normalized weights.",
    )
    parser.add_argument(
        "--quota-weights",
        help=(
            "Comma-separated domain weights (logic=0.2,code=0.4,tools=0.1,web=0.25,other=0.05). "
            "Values are normalized automatically."
        ),
    )
    parser.add_argument(
        "--quota-overrides",
        help="Comma-separated absolute per-domain quotas (e.g. code=400000,web=500000).",
    )
    parser.add_argument(
        "--code-fidelity-mode",
        action=argparse.BooleanOptionalAction,
        default=False,
        help=(
            "Preserve code whitespace structure better by avoiding newline flattening in corpus "
            "construction and disabling aggressive whitespace normalization in SPM."
        ),
    )
    parser.add_argument(
        "--deterministic",
        action=argparse.BooleanOptionalAction,
        default=False,
        help=(
            "Make SPM training deterministic/reproducible by disabling sentence shuffle, "
            "using a single trainer thread, and sampling the full built corpus."
        ),
    )
    return parser.parse_args()

287 

288 

289def _infer_domain(file_path): 

290 normalized = file_path.replace("\\", "/").lower() 

291 if "/logic/" in normalized: 

292 return "logic" 

293 if "/code/" in normalized: 

294 return "code" 

295 if "/tools/" in normalized: 

296 return "tools" 

297 if "/web/" in normalized: 

298 return "web" 

299 return "other" 

300 

301 

def _build_temp_corpus(profile, quotas, max_sentence_length, code_fidelity_mode=False):
    """Sample per-domain quota-limited text into a temp file for SPM training.

    Returns (temp_path, total_lines, domain_counts). On success the caller owns
    the temp file and is responsible for deleting it; on failure the file is
    removed here so it cannot leak.
    """
    domain_counts = defaultdict(int)
    total = 0

    # delete=False because SentencePiece needs to reopen the file by path later.
    temp = tempfile.NamedTemporaryFile(mode="w", delete=False, encoding="utf-8")
    temp_path = temp.name

    print(f"Building temporary corpus file: {temp_path}")
    print(f"Profile: {profile} | domain quotas: {quotas}")

    try:
        for file_path in base.iter_data_files():
            domain = _infer_domain(file_path)
            if domain_counts[domain] >= quotas[domain]:
                continue

            for text in base.iter_file_texts(file_path):
                if domain_counts[domain] >= quotas[domain]:
                    break

                if code_fidelity_mode:
                    # Keep line structure/indentation signal for code-heavy text.
                    text = text.replace("\r\n", "\n").replace("\r", "\n")
                    if not text.strip():
                        continue
                else:
                    text = text.replace("\n", " ").strip()
                    if not text:
                        continue

                if max_sentence_length > 0:
                    text = text[:max_sentence_length]

                temp.write(text)
                temp.write("\n")

                domain_counts[domain] += 1
                total += 1

        temp.flush()
        print(f"Temporary corpus built with {total:,} lines.")
        for domain in sorted(quotas):
            print(f" {domain:>5}: {domain_counts[domain]:,} / {quotas[domain]:,}")
        return temp_path, total, dict(domain_counts)
    except BaseException:
        # Fix: previously the temp file leaked when corpus construction failed,
        # because the caller only unlinks the path it receives on success.
        temp.close()
        try:
            os.unlink(temp_path)
        except OSError:
            pass
        raise
    finally:
        # Safe double-close when the except branch already closed the handle.
        temp.close()

348 

349 

def _write_run_manifest(
    output_dir,
    profile,
    settings,
    requested_input_sentence_size,
    effective_input_sentence_size,
    total_lines,
    domain_counts,
    args,
):
    """Write training_manifest.json capturing the effective run configuration."""
    # CLI-level choices.
    record = {
        "profile": profile,
        "vocab_size": args.vocab_size,
        "model_type": args.model_type,
        "character_coverage": args.character_coverage,
        "byte_fallback": args.byte_fallback,
        "code_fidelity_mode": args.code_fidelity_mode,
        "deterministic": args.deterministic,
    }
    # Resolved sampling settings and what the corpus build actually produced.
    record.update(
        {
            "requested_input_sentence_size": requested_input_sentence_size,
            "effective_input_sentence_size": effective_input_sentence_size,
            "max_sentence_length": settings["max_sentence_length"],
            "quota_total_lines": settings["quota_total_lines"],
            "domain_weights": settings["domain_weights"],
            "derived_quotas": settings["domain_quotas"],
            "corpus_total_lines": total_lines,
            "corpus_domain_counts": domain_counts,
            "tokenizer_model_max_length": _resolve_model_max_length(),
        }
    )
    target = Path(output_dir) / "training_manifest.json"
    target.write_text(json.dumps(record, indent=2, sort_keys=True), encoding="utf-8")
    print(f"Saved training manifest to ./{target}")

381 

382 

383def _resolve_model_max_length(): 

384 raw = os.getenv("TOKENIZER_MODEL_MAX_LENGTH") 

385 if raw: 

386 try: 

387 parsed = int(raw) 

388 if parsed > 0: 388 ↛ 392line 388 didn't jump to line 392 because the condition on line 388 was always true

389 return parsed 

390 except ValueError: 

391 pass 

392 return CONTEXT_LENGTH 

393 

394 

def _export_hf_tokenizer(spm_model_path, output_dir, model_max_length=CONTEXT_LENGTH):
    """Export the trained SentencePiece model as HuggingFace fast-tokenizer files.

    Two strategies, in order:
      1. Import the .model directly via SentencePieceUnigramTokenizer.from_spm
         (preserves SentencePiece behavior), then attach a ByteFallback +
         Metaspace decoder chain.
      2. On failure, manually rebuild a Unigram backend from the piece table,
         reading byte_fallback from the model protobuf when possible.

    If the optional tokenizers/transformers dependencies are missing the export
    is skipped with a message instead of raising.
    """
    processor = spm.SentencePieceProcessor(model_file=spm_model_path)

    try:
        import sys

        import sentencepiece.sentencepiece_model_pb2 as sp_pb2
        from tokenizers import Tokenizer
        from tokenizers import decoders
        from tokenizers.decoders import Metaspace as MetaspaceDecoder
        from tokenizers.implementations import SentencePieceUnigramTokenizer
        from tokenizers.models import Unigram
        from tokenizers.normalizers import NFKC
        from tokenizers.pre_tokenizers import Metaspace
        from transformers import PreTrainedTokenizerFast
    except Exception as exc:  # pragma: no cover
        print(f"Skipping HuggingFace export (fast tokenizer dependencies unavailable): {exc}")
        return

    try:
        # Preferred path: import directly from the .model to preserve SentencePiece behavior.
        # from_spm looks the protobuf module up under this legacy top-level name.
        sys.modules.setdefault("sentencepiece_model_pb2", sp_pb2)
        impl = SentencePieceUnigramTokenizer.from_spm(spm_model_path)
        impl._tokenizer.decoder = decoders.Sequence(
            [
                decoders.ByteFallback(),
                decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True),
            ]
        )

        tokenizer = PreTrainedTokenizerFast(
            tokenizer_object=impl._tokenizer,
            bos_token="<s>",
            eos_token="<|eos|>",
            unk_token="<|unk|>",
            pad_token="<|pad|>",
            model_max_length=model_max_length,
        )
        tokenizer.save_pretrained(output_dir)
        print(f"Saved HuggingFace tokenizer files to ./{output_dir}")
        return
    except Exception as spm_exc:
        print(f"Direct SentencePiece export path unavailable, falling back to manual conversion: {spm_exc}")

    try:
        byte_fallback = False
        try:
            from sentencepiece import sentencepiece_model_pb2 as sp_pb2

            model_proto = sp_pb2.ModelProto()
            with open(spm_model_path, "rb") as fh:
                model_proto.ParseFromString(fh.read())
            byte_fallback = bool(model_proto.trainer_spec.byte_fallback)
        except Exception:
            # Best-effort: if protobuf parsing is unavailable, keep legacy default.
            pass

        # Rebuild the Unigram model from (piece, score) pairs held by the processor.
        vocab = [(processor.id_to_piece(i), processor.get_score(i)) for i in range(processor.vocab_size())]
        backend = Tokenizer(Unigram(vocab, unk_id=processor.unk_id(), byte_fallback=byte_fallback))
        backend.normalizer = NFKC()
        # NOTE(review): this fallback uses prepend_scheme="first" while the direct
        # path above uses "always" — presumably intentional, but worth confirming.
        backend.pre_tokenizer = Metaspace(replacement="▁", prepend_scheme="first")
        backend.decoder = decoders.Sequence(
            [
                decoders.ByteFallback(),
                MetaspaceDecoder(replacement="▁", prepend_scheme="first"),
            ]
        )

        tokenizer = PreTrainedTokenizerFast(
            tokenizer_object=backend,
            bos_token="<s>",
            eos_token="<|eos|>",
            unk_token="<|unk|>",
            pad_token="<|pad|>",
            model_max_length=model_max_length,
        )
        tokenizer.save_pretrained(output_dir)
        print(f"Saved HuggingFace tokenizer files to ./{output_dir}")
    except Exception as exc:  # pragma: no cover
        print(f"Error during HuggingFace tokenizer export: {exc}")
        raise

476 

477 

def main():
    """End-to-end driver: resolve settings, build corpus, train SPM, export HF files."""
    args = parse_args()
    profile_registry = _load_profile_registry(args.profile_file)
    profile = _resolve_profile(args.profile, profile_registry)
    settings = _resolve_profile_settings(args, profile, profile_registry)
    input_sentence_size = settings["input_sentence_size"]
    max_sentence_length = settings["max_sentence_length"]

    # Only auto-tune when the user did not pin --input-sentence-size explicitly.
    if args.input_sentence_size is None:
        input_sentence_size = _auto_tune_input_sentence_size(profile, input_sentence_size)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    corpus_path, total_lines, domain_counts = _build_temp_corpus(
        profile=profile,
        quotas=settings["domain_quotas"],
        max_sentence_length=max_sentence_length,
        code_fidelity_mode=args.code_fidelity_mode,
    )

    try:
        if total_lines == 0:
            raise RuntimeError("No training text found for SentencePiece corpus.")

        requested_input_sentence_size = input_sentence_size
        if args.deterministic:
            # Deterministic runs consume the entire corpus in order on one thread.
            effective_input_sentence_size = max(1, total_lines)
            shuffle_input_sentence = False
            num_threads = 1
        else:
            effective_input_sentence_size = input_sentence_size
            shuffle_input_sentence = True
            num_threads = max(os.cpu_count() or 1, 1)

        model_prefix = output_dir / "spm"

        print(
            f"Training SentencePiece ({args.model_type}) with vocab={args.vocab_size:,}, "
            f"input_sentence_size={effective_input_sentence_size:,}, max_sentence_length={max_sentence_length}, "
            f"byte_fallback={args.byte_fallback}, code_fidelity_mode={args.code_fidelity_mode}, "
            f"deterministic={args.deterministic}."
        )
        print(
            "Quota design: total_lines="
            f"{settings['quota_total_lines']:,}, normalized_weights={settings['domain_weights']}, "
            f"derived_quotas={settings['domain_quotas']}"
        )

        # Special-token ids are pinned so downstream model configs can rely on them.
        spm.SentencePieceTrainer.train(
            input=corpus_path,
            model_prefix=str(model_prefix),
            model_type=args.model_type,
            vocab_size=args.vocab_size,
            character_coverage=args.character_coverage,
            input_sentence_size=effective_input_sentence_size,
            shuffle_input_sentence=shuffle_input_sentence,
            max_sentence_length=max_sentence_length,
            pad_id=0,
            unk_id=1,
            bos_id=2,
            eos_id=3,
            pad_piece="<|pad|>",
            unk_piece="<|unk|>",
            bos_piece="<s>",
            eos_piece="<|eos|>",
            user_defined_symbols=["<|im_start|>", "<|im_end|>", "<think>", "</think>"],
            byte_fallback=args.byte_fallback,
            # Code-fidelity mode keeps whitespace structure intact in SPM too.
            split_by_whitespace=not args.code_fidelity_mode,
            remove_extra_whitespaces=not args.code_fidelity_mode,
            num_threads=num_threads,
        )

        _write_run_manifest(
            output_dir=str(output_dir),
            profile=profile,
            settings=settings,
            requested_input_sentence_size=requested_input_sentence_size,
            effective_input_sentence_size=effective_input_sentence_size,
            total_lines=total_lines,
            domain_counts=domain_counts,
            args=args,
        )
    finally:
        # Always reclaim the sampled corpus file, even when training fails.
        os.remove(corpus_path)

    spm_model_path = str(model_prefix) + ".model"
    _export_hf_tokenizer(
        spm_model_path,
        str(output_dir),
        model_max_length=_resolve_model_max_length(),
    )

    print(f"Done. SentencePiece artifacts in ./{output_dir}")


if __name__ == "__main__":
    main()