Coverage for train_tokenizer_spm.py: 81%
269 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 23:06 +0000
1# Copyright 2026 venim1103
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15import argparse
16import copy
17import json
18import os
19import tempfile
20from collections import defaultdict
21from pathlib import Path
23import sentencepiece as spm
25import train_tokenizer as base
26from context_config import CONTEXT_LENGTH
# Recognized corpus domains. Order matters: it fixes iteration/rounding order
# in quota derivation and the listing order in error messages.
DOMAINS = ("logic", "code", "tools", "web", "other")
32def _resolve_profile(explicit_profile, profile_registry):
33 if explicit_profile:
34 if explicit_profile in profile_registry:
35 return explicit_profile
36 available = ", ".join(sorted(profile_registry))
37 raise ValueError(f"Unknown profile {explicit_profile!r}. Available profiles: {available}")
38 if os.getenv("KAGGLE_KERNEL_RUN_TYPE"):
39 if "kaggle" in profile_registry: 39 ↛ 41line 39 didn't jump to line 41 because the condition on line 39 was always true
40 return "kaggle"
41 return next(iter(sorted(profile_registry)))
42 if "standard" in profile_registry: 42 ↛ 44line 42 didn't jump to line 44 because the condition on line 42 was always true
43 return "standard"
44 return next(iter(sorted(profile_registry)))
def _parse_domain_map(text, name):
    """Parse a comma-separated ``domain=value`` string into {domain: float}.

    ``name`` is the CLI flag name, used only in error messages. Blank entries
    are skipped; malformed entries or unknown domains raise ValueError.
    """
    parsed = {}
    for entry in (piece.strip() for piece in text.split(",")):
        if not entry:
            continue
        if "=" not in entry:
            raise ValueError(f"{name} must use key=value entries. Got {entry!r}.")
        domain, _, raw_value = entry.partition("=")
        domain = domain.strip().lower()
        if domain not in DOMAINS:
            raise ValueError(f"Unsupported domain {domain!r} in {name}. Expected one of: {', '.join(DOMAINS)}")
        parsed[domain] = float(raw_value.strip())
    return parsed
def _normalize_domain_weights(raw_weights):
    """Validate per-domain weights and normalize them to sum to 1.0.

    Every domain in DOMAINS must be present with a non-negative weight, and
    the overall sum must be positive; otherwise ValueError is raised. Note
    the normalizing total includes any extra (non-DOMAINS) keys, matching
    sum() over all values.
    """
    for domain in DOMAINS:
        try:
            weight = raw_weights[domain]
        except KeyError:
            raise ValueError(f"Missing domain weight for {domain!r}.") from None
        if weight < 0:
            raise ValueError(f"Domain weight for {domain!r} must be non-negative.")
    total = sum(raw_weights.values())
    if total <= 0:
        raise ValueError("Domain weights must sum to a positive value.")
    return {domain: raw_weights[domain] / total for domain in DOMAINS}
def _calculate_domain_quotas(total_lines, normalized_weights):
    """Derive integer per-domain line quotas that sum exactly to total_lines.

    Each quota is the floor of its exact share; leftover lines from flooring
    are handed out one at a time to the domains with the largest fractional
    remainder (largest-remainder apportionment).
    """
    if total_lines <= 0:
        raise ValueError("quota_total_lines must be > 0.")
    exact_shares = {domain: total_lines * normalized_weights[domain] for domain in DOMAINS}
    quotas = {domain: int(exact_shares[domain]) for domain in DOMAINS}
    leftover = total_lines - sum(quotas.values())
    if leftover > 0:
        ranked = sorted(
            DOMAINS,
            key=lambda domain: exact_shares[domain] - quotas[domain],
            reverse=True,
        )
        for idx in range(leftover):
            ranked_domain = ranked[idx % len(ranked)]
            quotas[ranked_domain] += 1
    return quotas
def _load_profile_registry(profile_file):
    """Return built-in SPM profiles, optionally merged with a JSON profile file.

    The file may be either ``{name: settings}`` or
    ``{"profiles": {name: settings}}``; custom entries replace built-ins with
    the same name. The built-in defaults are deep-copied so callers can
    mutate the result safely.
    """
    registry = copy.deepcopy(SPM_PROFILE_DEFAULTS)
    if not profile_file:
        return registry

    with open(profile_file, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    custom_profiles = payload.get("profiles", payload)
    if not isinstance(custom_profiles, dict):
        raise ValueError("Profile file must contain a JSON object (or {'profiles': {...}}).")

    for name, settings in custom_profiles.items():
        if not isinstance(settings, dict):
            raise ValueError(f"Profile {name!r} must map to an object.")
        registry[name] = settings

    return registry
def _resolve_profile_settings(args, profile_name, profile_registry):
    """Combine CLI flags and profile defaults into effective SPM settings.

    CLI values win over profile defaults. Domain weights come from the
    profile's "domain_weights", or are derived from a legacy "domain_quota"
    mapping; --quota-weights then overrides individual weights before
    normalization, and absolute --quota-overrides are applied last.
    Returns a dict with sampling sizes, normalized weights, and quotas.
    """
    defaults = profile_registry[profile_name]

    if args.input_sentence_size is not None:
        input_sentence_size = args.input_sentence_size
    else:
        input_sentence_size = defaults["input_sentence_size"]

    if args.max_sentence_length is not None:
        max_sentence_length = args.max_sentence_length
    else:
        max_sentence_length = defaults["max_sentence_length"]

    if args.quota_total_lines is not None:
        quota_total_lines = args.quota_total_lines
    else:
        quota_total_lines = defaults.get("quota_total_lines", sum(defaults.get("domain_quota", {}).values()))

    if "domain_weights" in defaults:
        domain_weights = dict(defaults["domain_weights"])
    else:
        # Legacy profiles carried absolute per-domain quotas; turn them into weights.
        legacy_quota = defaults.get("domain_quota", {})
        legacy_total = sum(legacy_quota.get(domain, 0) for domain in DOMAINS)
        if legacy_total <= 0:
            raise ValueError(f"Profile {profile_name!r} does not define usable domain quotas/weights.")
        domain_weights = {domain: legacy_quota.get(domain, 0) / legacy_total for domain in DOMAINS}

    if args.quota_weights:
        domain_weights.update(_parse_domain_map(args.quota_weights, "--quota-weights"))
    normalized_weights = _normalize_domain_weights(domain_weights)
    quotas = _calculate_domain_quotas(quota_total_lines, normalized_weights)

    if args.quota_overrides:
        overrides = _parse_domain_map(args.quota_overrides, "--quota-overrides")
        for domain, value in overrides.items():
            quotas[domain] = max(0, int(value))

    return {
        "input_sentence_size": input_sentence_size,
        "max_sentence_length": max_sentence_length,
        "quota_total_lines": quota_total_lines,
        "domain_weights": normalized_weights,
        "domain_quotas": quotas,
    }
150def _auto_tune_input_sentence_size(profile, input_sentence_size):
151 """Scale SPM sampling upward on higher-RAM machines when using kaggle profile.
153 This keeps the kaggle defaults safe for low-RAM environments while making
154 better use of local machines that intentionally run with kaggle settings.
155 """
156 if profile != "kaggle":
157 return input_sentence_size
159 raw_ram = os.getenv("TOKENIZER_MAX_RAM_GB")
160 if not raw_ram:
161 return input_sentence_size
163 try:
164 ram_gb = float(raw_ram)
165 except ValueError:
166 return input_sentence_size
168 base = SPM_PROFILE_DEFAULTS["kaggle"]["input_sentence_size"]
169 if ram_gb <= 13: 169 ↛ 170line 169 didn't jump to line 170 because the condition on line 169 was never true
170 return input_sentence_size
172 # Linear scale from 13 GB -> 30 GB, capped for OOM safety.
173 # With current settings, 30 GB machines were underutilized (~11 GB peak),
174 # so this default cap is intentionally more aggressive.
175 target_cap = 2_800_000
176 raw_cap = os.getenv("TOKENIZER_SPM_AUTO_MAX_INPUT_SENTENCE_SIZE")
177 if raw_cap:
178 try:
179 target_cap = max(base, int(raw_cap))
180 except ValueError:
181 pass
183 scaled = int(base + (ram_gb - 13.0) * (target_cap - base) / (30.0 - 13.0))
184 tuned = max(base, min(target_cap, scaled))
186 if tuned != input_sentence_size: 186 ↛ 191line 186 didn't jump to line 191 because the condition on line 186 was always true
187 print(
188 "Auto-tuning input_sentence_size for kaggle profile based on "
189 f"TOKENIZER_MAX_RAM_GB={ram_gb:g}: {input_sentence_size:,} -> {tuned:,}"
190 )
191 return tuned
# Built-in training profiles. "standard" targets local machines; "kaggle"
# shrinks SPM trainer sampling for low-RAM Kaggle kernels (and is eligible
# for RAM-based auto-tuning in _auto_tune_input_sentence_size).
# domain_weights are normalized downstream, so only relative proportions matter.
SPM_PROFILE_DEFAULTS = {
    "standard": {
        "input_sentence_size": 1_500_000,  # SPM trainer sampling cap
        "max_sentence_length": 2048,
        "quota_total_lines": 2_500_000,  # total sampled corpus lines
        "domain_weights": {
            "logic": 0.22,
            "code": 0.30,
            "tools": 0.16,
            "web": 0.28,
            "other": 0.04,
        },
    },
    "kaggle": {
        "input_sentence_size": 750_000,  # half of standard, for RAM safety
        "max_sentence_length": 2048,
        "quota_total_lines": 2_100_000,
        "domain_weights": {
            "logic": 0.22,
            "code": 0.30,
            "tools": 0.16,
            "web": 0.28,
            "other": 0.04,
        },
    },
}
def parse_args():
    """Build the CLI parser and parse sys.argv into a Namespace."""
    cli = argparse.ArgumentParser(description="Train a low-RAM SentencePiece tokenizer.")

    # Profile selection.
    cli.add_argument("--profile", help="Training profile name (built-in or from --profile-file).")
    cli.add_argument(
        "--profile-file",
        help="Optional JSON file with custom profiles. Accepts either {name: {...}} or {'profiles': {name: {...}}}.",
    )

    # Core tokenizer hyperparameters.
    cli.add_argument("--vocab-size", type=int, default=64_000)
    cli.add_argument("--model-type", choices=["bpe", "unigram"], default="unigram")
    cli.add_argument("--character-coverage", type=float, default=0.9995)
    cli.add_argument(
        "--byte-fallback",
        action=argparse.BooleanOptionalAction,
        default=True,
        help=(
            "Enable SentencePiece byte fallback. Recommended for code so unseen bytes "
            "are represented as byte pieces instead of <|unk|>."
        ),
    )
    cli.add_argument("--output-dir", default="custom_agentic_tokenizer_spm")

    # Sampling / corpus-size knobs; profile defaults apply when omitted.
    cli.add_argument(
        "--input-sentence-size",
        type=int,
        help="Max sentence count for SPM trainer sampling. Uses profile default if omitted.",
    )
    cli.add_argument(
        "--max-sentence-length",
        type=int,
        help="Max sentence length passed to SPM trainer. Uses profile default if omitted.",
    )
    cli.add_argument(
        "--quota-total-lines",
        type=int,
        help="Total sampled corpus lines before SPM training. Per-domain quotas are derived from normalized weights.",
    )
    cli.add_argument(
        "--quota-weights",
        help=(
            "Comma-separated domain weights (logic=0.2,code=0.4,tools=0.1,web=0.25,other=0.05). "
            "Values are normalized automatically."
        ),
    )
    cli.add_argument(
        "--quota-overrides",
        help="Comma-separated absolute per-domain quotas (e.g. code=400000,web=500000).",
    )

    # Behavior toggles.
    cli.add_argument(
        "--code-fidelity-mode",
        action=argparse.BooleanOptionalAction,
        default=False,
        help=(
            "Preserve code whitespace structure better by avoiding newline flattening in corpus "
            "construction and disabling aggressive whitespace normalization in SPM."
        ),
    )
    cli.add_argument(
        "--deterministic",
        action=argparse.BooleanOptionalAction,
        default=False,
        help=(
            "Make SPM training deterministic/reproducible by disabling sentence shuffle, "
            "using a single trainer thread, and sampling the full built corpus."
        ),
    )
    return cli.parse_args()
289def _infer_domain(file_path):
290 normalized = file_path.replace("\\", "/").lower()
291 if "/logic/" in normalized:
292 return "logic"
293 if "/code/" in normalized:
294 return "code"
295 if "/tools/" in normalized:
296 return "tools"
297 if "/web/" in normalized:
298 return "web"
299 return "other"
def _build_temp_corpus(profile, quotas, max_sentence_length, code_fidelity_mode=False):
    """Sample up to quotas[domain] cleaned lines per domain into a temp file.

    Iterates the project's data files (via the ``base`` training module),
    assigns each file a domain from its path, and writes one corpus line per
    text. In ``code_fidelity_mode`` only line endings are normalized so
    indentation survives; otherwise newlines are flattened to spaces.

    Returns (temp_path, total_lines, domain_counts). The caller must delete
    temp_path after a successful return; on exception the file is removed
    here so a failed build cannot leak a (potentially huge) temp file.
    """
    domain_counts = defaultdict(int)
    total = 0

    temp = tempfile.NamedTemporaryFile(mode="w", delete=False, encoding="utf-8")
    temp_path = temp.name

    print(f"Building temporary corpus file: {temp_path}")
    print(f"Profile: {profile} | domain quotas: {quotas}")

    try:
        for file_path in base.iter_data_files():
            domain = _infer_domain(file_path)
            # Skip whole files once their domain quota is already met.
            if domain_counts[domain] >= quotas[domain]:
                continue

            for text in base.iter_file_texts(file_path):
                if domain_counts[domain] >= quotas[domain]:
                    break

                if code_fidelity_mode:
                    # Keep line structure/indentation signal for code-heavy text.
                    text = text.replace("\r\n", "\n").replace("\r", "\n")
                    if not text.strip():
                        continue
                else:
                    text = text.replace("\n", " ").strip()
                    if not text:
                        continue

                if max_sentence_length > 0:
                    text = text[:max_sentence_length]

                temp.write(text)
                temp.write("\n")

                domain_counts[domain] += 1
                total += 1

        temp.flush()
        print(f"Temporary corpus built with {total:,} lines.")
        for domain in sorted(quotas):
            print(f" {domain:>5}: {domain_counts[domain]:,} / {quotas[domain]:,}")
        return temp_path, total, dict(domain_counts)
    except BaseException:
        # Bug fix: a failure mid-build previously leaked the temp file, since
        # the caller only removes it after a successful return.
        temp.close()
        try:
            os.remove(temp_path)
        except OSError:
            pass
        raise
    finally:
        temp.close()
def _write_run_manifest(
    output_dir,
    profile,
    settings,
    requested_input_sentence_size,
    effective_input_sentence_size,
    total_lines,
    domain_counts,
    args,
):
    """Write training_manifest.json capturing the effective run configuration.

    Records both the requested and effective SPM sampling sizes plus the
    realized per-domain corpus counts so a run can be audited or reproduced.
    """
    manifest = {"profile": profile}
    # CLI flags copied verbatim from the parsed namespace.
    for attr in (
        "vocab_size",
        "model_type",
        "character_coverage",
        "byte_fallback",
        "code_fidelity_mode",
        "deterministic",
    ):
        manifest[attr] = getattr(args, attr)

    manifest["requested_input_sentence_size"] = requested_input_sentence_size
    manifest["effective_input_sentence_size"] = effective_input_sentence_size
    manifest["max_sentence_length"] = settings["max_sentence_length"]
    manifest["quota_total_lines"] = settings["quota_total_lines"]
    manifest["domain_weights"] = settings["domain_weights"]
    manifest["derived_quotas"] = settings["domain_quotas"]
    manifest["corpus_total_lines"] = total_lines
    manifest["corpus_domain_counts"] = domain_counts
    manifest["tokenizer_model_max_length"] = _resolve_model_max_length()

    manifest_path = Path(output_dir) / "training_manifest.json"
    manifest_path.write_text(json.dumps(manifest, indent=2, sort_keys=True), encoding="utf-8")
    print(f"Saved training manifest to ./{manifest_path}")
383def _resolve_model_max_length():
384 raw = os.getenv("TOKENIZER_MODEL_MAX_LENGTH")
385 if raw:
386 try:
387 parsed = int(raw)
388 if parsed > 0: 388 ↛ 392line 388 didn't jump to line 392 because the condition on line 388 was always true
389 return parsed
390 except ValueError:
391 pass
392 return CONTEXT_LENGTH
def _export_hf_tokenizer(spm_model_path, output_dir, model_max_length=CONTEXT_LENGTH):
    """Export the trained SentencePiece model as HuggingFace fast-tokenizer files.

    Two strategies are tried in order:
      1. ``SentencePieceUnigramTokenizer.from_spm`` on the .model file, which
         preserves SentencePiece behavior most faithfully.
      2. Manual conversion: rebuild a ``tokenizers.Unigram`` backend from the
         processor's pieces and scores.
    Export is skipped with a message (not an error) when the optional
    tokenizers/transformers dependencies are unavailable.
    """
    processor = spm.SentencePieceProcessor(model_file=spm_model_path)

    try:
        import sys

        import sentencepiece.sentencepiece_model_pb2 as sp_pb2
        from tokenizers import Tokenizer
        from tokenizers import decoders
        from tokenizers.decoders import Metaspace as MetaspaceDecoder
        from tokenizers.implementations import SentencePieceUnigramTokenizer
        from tokenizers.models import Unigram
        from tokenizers.normalizers import NFKC
        from tokenizers.pre_tokenizers import Metaspace
        from transformers import PreTrainedTokenizerFast
    except Exception as exc:  # pragma: no cover
        print(f"Skipping HuggingFace export (fast tokenizer dependencies unavailable): {exc}")
        return

    try:
        # Preferred path: import directly from the .model to preserve SentencePiece behavior.
        # NOTE(review): presumably from_spm() looks up a top-level
        # "sentencepiece_model_pb2" module; this alias makes the bundled
        # protobuf module satisfy that lookup — confirm against tokenizers docs.
        sys.modules.setdefault("sentencepiece_model_pb2", sp_pb2)
        impl = SentencePieceUnigramTokenizer.from_spm(spm_model_path)
        # ByteFallback runs before Metaspace so byte pieces are merged back into
        # real bytes before "▁" markers are restored to spaces.
        impl._tokenizer.decoder = decoders.Sequence(
            [
                decoders.ByteFallback(),
                decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True),
            ]
        )

        tokenizer = PreTrainedTokenizerFast(
            tokenizer_object=impl._tokenizer,
            bos_token="<s>",
            eos_token="<|eos|>",
            unk_token="<|unk|>",
            pad_token="<|pad|>",
            model_max_length=model_max_length,
        )
        tokenizer.save_pretrained(output_dir)
        print(f"Saved HuggingFace tokenizer files to ./{output_dir}")
        return
    except Exception as spm_exc:
        print(f"Direct SentencePiece export path unavailable, falling back to manual conversion: {spm_exc}")

    try:
        byte_fallback = False
        try:
            from sentencepiece import sentencepiece_model_pb2 as sp_pb2

            # Read byte_fallback from the trained model's trainer_spec so the
            # converted backend matches the trainer's setting.
            model_proto = sp_pb2.ModelProto()
            with open(spm_model_path, "rb") as fh:
                model_proto.ParseFromString(fh.read())
            byte_fallback = bool(model_proto.trainer_spec.byte_fallback)
        except Exception:
            # Best-effort: if protobuf parsing is unavailable, keep legacy default.
            pass

        # Rebuild the vocab as (piece, score) pairs straight from the processor.
        vocab = [(processor.id_to_piece(i), processor.get_score(i)) for i in range(processor.vocab_size())]
        backend = Tokenizer(Unigram(vocab, unk_id=processor.unk_id(), byte_fallback=byte_fallback))
        backend.normalizer = NFKC()
        backend.pre_tokenizer = Metaspace(replacement="▁", prepend_scheme="first")
        backend.decoder = decoders.Sequence(
            [
                decoders.ByteFallback(),
                MetaspaceDecoder(replacement="▁", prepend_scheme="first"),
            ]
        )

        tokenizer = PreTrainedTokenizerFast(
            tokenizer_object=backend,
            bos_token="<s>",
            eos_token="<|eos|>",
            unk_token="<|unk|>",
            pad_token="<|pad|>",
            model_max_length=model_max_length,
        )
        tokenizer.save_pretrained(output_dir)
        print(f"Saved HuggingFace tokenizer files to ./{output_dir}")
    except Exception as exc:  # pragma: no cover
        print(f"Error during HuggingFace tokenizer export: {exc}")
        raise
def main():
    """CLI entry point: build a sampled corpus, train SPM, export artifacts."""
    args = parse_args()
    profile_registry = _load_profile_registry(args.profile_file)
    profile = _resolve_profile(args.profile, profile_registry)
    settings = _resolve_profile_settings(args, profile, profile_registry)
    input_sentence_size = settings["input_sentence_size"]
    max_sentence_length = settings["max_sentence_length"]

    # Only auto-tune when the user did not pin the size explicitly on the CLI.
    if args.input_sentence_size is None:
        input_sentence_size = _auto_tune_input_sentence_size(profile, input_sentence_size)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    corpus_path, total_lines, domain_counts = _build_temp_corpus(
        profile=profile,
        quotas=settings["domain_quotas"],
        max_sentence_length=max_sentence_length,
        code_fidelity_mode=args.code_fidelity_mode,
    )

    try:
        if total_lines == 0:
            raise RuntimeError("No training text found for SentencePiece corpus.")

        requested_input_sentence_size = input_sentence_size
        if args.deterministic:
            # Deterministic mode: no shuffling, one trainer thread, and the
            # full built corpus is consumed so runs are reproducible.
            effective_input_sentence_size = max(1, total_lines)
            shuffle_input_sentence = False
            num_threads = 1
        else:
            effective_input_sentence_size = input_sentence_size
            shuffle_input_sentence = True
            num_threads = max(os.cpu_count() or 1, 1)

        model_prefix = output_dir / "spm"

        print(
            f"Training SentencePiece ({args.model_type}) with vocab={args.vocab_size:,}, "
            f"input_sentence_size={effective_input_sentence_size:,}, max_sentence_length={max_sentence_length}, "
            f"byte_fallback={args.byte_fallback}, code_fidelity_mode={args.code_fidelity_mode}, "
            f"deterministic={args.deterministic}."
        )
        print(
            "Quota design: total_lines="
            f"{settings['quota_total_lines']:,}, normalized_weights={settings['domain_weights']}, "
            f"derived_quotas={settings['domain_quotas']}"
        )

        # Special-token ids are pinned (pad=0, unk=1, bos=2, eos=3) so
        # downstream code can rely on stable ids across retrainings.
        spm.SentencePieceTrainer.train(
            input=corpus_path,
            model_prefix=str(model_prefix),
            model_type=args.model_type,
            vocab_size=args.vocab_size,
            character_coverage=args.character_coverage,
            input_sentence_size=effective_input_sentence_size,
            shuffle_input_sentence=shuffle_input_sentence,
            max_sentence_length=max_sentence_length,
            pad_id=0,
            unk_id=1,
            bos_id=2,
            eos_id=3,
            pad_piece="<|pad|>",
            unk_piece="<|unk|>",
            bos_piece="<s>",
            eos_piece="<|eos|>",
            user_defined_symbols=["<|im_start|>", "<|im_end|>", "<think>", "</think>"],
            byte_fallback=args.byte_fallback,
            # Code-fidelity mode keeps SPM's whitespace handling conservative.
            split_by_whitespace=not args.code_fidelity_mode,
            remove_extra_whitespaces=not args.code_fidelity_mode,
            num_threads=num_threads,
        )

        _write_run_manifest(
            output_dir=str(output_dir),
            profile=profile,
            settings=settings,
            requested_input_sentence_size=requested_input_sentence_size,
            effective_input_sentence_size=effective_input_sentence_size,
            total_lines=total_lines,
            domain_counts=domain_counts,
            args=args,
        )
    finally:
        # The sampled corpus can be large; always delete it, success or not.
        os.remove(corpus_path)

    spm_model_path = str(model_prefix) + ".model"
    _export_hf_tokenizer(
        spm_model_path,
        str(output_dir),
        model_max_length=_resolve_model_max_length(),
    )

    print(f"Done. SentencePiece artifacts in ./{output_dir}")
# Script entry point.
if __name__ == "__main__":
    main()