Coverage for train_tokenizer.py: 73%
279 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 23:06 +0000
1# Copyright 2026 venim1103
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15import collections
16import json
17import os
18import re
19import sys
21import pyarrow as pa
22import pyarrow.parquet as pq
23from datasets import load_dataset
24from transformers import AutoTokenizer
26# ==========================================
27# 1. Tokenizer Configuration
28# ==========================================
29# We use DeepSeek-Coder as the "template" because it contains highly optimized
30# Regular Expressions (Regex) for splitting code, indentations, and whitespace properly.
31# We are NOT keeping its vocabulary; we are just borrowing its splitting rules.
32TEMPLATE_TOKENIZER = "deepseek-ai/deepseek-coder-1.3b-base"
34# The "Goldilocks" size for a 500M parameter model
35VOCAB_SIZE = 64_000
37DATA_DIR = "local_data/train"
38OUTPUT_DIR = "custom_agentic_tokenizer"
39def _resolve_profile():
40 """Select tokenizer training profile from env or runtime platform."""
41 explicit = os.getenv("TOKENIZER_PROFILE", "").strip().lower()
42 if explicit in {"kaggle", "standard"}:
43 return explicit
45 if os.getenv("KAGGLE_KERNEL_RUN_TYPE"):
46 return "kaggle"
48 return "standard"
# Resolved once at import time; every profile-dependent knob below keys off it.
PROFILE = _resolve_profile()

# Per-profile tuning defaults. Values are strings so that env-var overrides
# and these defaults flow through the same int()/float() parsing below.
PROFILE_DEFAULTS = {
    "standard": {
        "max_batch_examples": "128",
        "max_batch_characters": "2000000",
        "parquet_batch_size": "256",
        "max_text_characters": "2000",
        "min_frequency": "3",
        "max_unique_words": "200000",
        "max_ram_gb": "30",
    },
    "kaggle": {
        # Smaller batches and a tighter RAM budget for low-RAM Kaggle kernels.
        "max_batch_examples": "96",
        "max_batch_characters": "1200000",
        "parquet_batch_size": "192",
        "max_text_characters": "1600",
        "min_frequency": "3",
        "max_unique_words": "200000",
        "max_ram_gb": "13",
    },
}
def _profile_default(name):
    """Look up the default string value for *name* under the active PROFILE."""
    profile_settings = PROFILE_DEFAULTS[PROFILE]
    return profile_settings[name]
# Batch-size and filtering knobs; each can be overridden via a TOKENIZER_* env var.
MAX_BATCH_EXAMPLES = int(os.getenv("TOKENIZER_MAX_BATCH_EXAMPLES", _profile_default("max_batch_examples")))
MAX_BATCH_CHARACTERS = int(os.getenv("TOKENIZER_MAX_BATCH_CHARACTERS", _profile_default("max_batch_characters")))
PARQUET_BATCH_SIZE = int(os.getenv("TOKENIZER_PARQUET_BATCH_SIZE", _profile_default("parquet_batch_size")))
MAX_TEXT_CHARACTERS = int(os.getenv("TOKENIZER_MAX_TEXT_CHARACTERS", _profile_default("max_text_characters")))
MIN_FREQUENCY = int(os.getenv("TOKENIZER_MIN_FREQUENCY", _profile_default("min_frequency")))

# Target unique word count that fits in RAM. Each unique word costs ~15-20 KB
# in the Rust BPE trainer (word string + character-level tokenisation + pair
# counts + priority queue entries). Both built-in profiles use conservative
# defaults because the Rust BPE merge phase can still OOM below 64 GB RAM.
MAX_UNIQUE_WORDS = int(os.getenv("TOKENIZER_MAX_UNIQUE_WORDS", _profile_default("max_unique_words")))
89def _corpus_bytes_for_ram(ram_gb):
90 """Return a generous corpus byte cap as a first-pass limit.
92 This is a coarse outer bound; the real memory control is
93 MAX_UNIQUE_WORDS which is enforced by the two-pass approach.
94 """
95 usable = max(ram_gb - 4, 0.5)
96 return int(usable * 1_073_741_824) # 1x — generous, since pass-2 filters
# RAM budget (GB) drives the coarse corpus byte cap; see _corpus_bytes_for_ram.
MAX_RAM_GB = float(os.getenv("TOKENIZER_MAX_RAM_GB", _profile_default("max_ram_gb")))
MAX_CORPUS_BYTES = _corpus_bytes_for_ram(MAX_RAM_GB)
102def _resolve_backend(profile, max_ram_gb):
103 """Choose tokenizer backend.
105 - TOKENIZER_BACKEND=hf|spm forces a backend.
106 - TOKENIZER_BACKEND=auto (or unset) auto-selects SentencePiece on lower-RAM
107 setups where HF Rust BPE commonly OOMs.
108 """
109 explicit = os.getenv("TOKENIZER_BACKEND", "auto").strip().lower()
110 if explicit in {"hf", "spm"}:
111 return explicit
113 if explicit not in {"", "auto"}: 113 ↛ 114line 113 didn't jump to line 114 because the condition on line 113 was never true
114 print(f"Unknown TOKENIZER_BACKEND={explicit!r}; falling back to auto.")
116 if profile == "kaggle" or max_ram_gb < 64:
117 return "spm"
118 return "hf"
# Backend chosen once at import time: "hf" or "spm".
BACKEND = _resolve_backend(PROFILE, MAX_RAM_GB)

# These must be preserved as single tokens so the model can reason and format perfectly!
SPECIAL_TOKENS = [
    "<|im_start|>",
    "<|im_end|>",
    "<think>",
    "</think>",
    "<|eos|>",
    "<|pad|>"
]

# File extensions the data scanner accepts under DATA_DIR.
SUPPORTED_SUFFIXES = (".jsonl", ".json", ".parquet")
# Splits text into whitespace runs, word-character runs, or single punctuation marks.
TRAINING_PIECE_RE = re.compile(r"\s+|\w+|[^\w\s]", re.UNICODE)
136# ==========================================
137# 2. RAM-Friendly Data Iterator
138# ==========================================
def get_text_column(sample_row):
    """Automatically detects which column holds the text data."""
    # Preferred, conventionally-named columns are tried first, in order.
    for candidate in ('text', 'content', 'trajectory', 'prompt', 'response'):
        if isinstance(sample_row.get(candidate), str):
            return candidate
    # Otherwise fall back to the first string-valued column in dict order.
    for column, value in sample_row.items():
        if isinstance(value, str):
            return column
    return None
def maybe_trim_text(text):
    """Strip *text* and cap its length at MAX_TEXT_CHARACTERS.

    Returns None for non-strings and for strings that are empty after
    stripping; a non-positive MAX_TEXT_CHARACTERS disables the cap.
    """
    cleaned = text.strip() if isinstance(text, str) else ""
    if not cleaned:
        return None
    return cleaned[:MAX_TEXT_CHARACTERS] if MAX_TEXT_CHARACTERS > 0 else cleaned
def iter_data_files():
    """Yield paths of supported dataset files under DATA_DIR, deterministically."""
    for root, sub_dirs, file_names in os.walk(DATA_DIR):
        # Sorting dirs in place makes os.walk descend in a stable order.
        sub_dirs.sort()
        for name in sorted(file_names):
            if name.endswith(SUPPORTED_SUFFIXES):
                yield os.path.join(root, name)
def iter_text_from_rows(rows, source_name):
    """Yield trimmed text values from dict rows, auto-detecting the text column.

    The column is detected from the first row; if none is found the whole
    source is skipped with a log line.
    """
    column = None

    for row in rows:
        if column is None:
            column = get_text_column(row)
            if not column:
                print(f" -> Skipping {source_name} (No text column found)")
                return

        value = maybe_trim_text(row.get(column))
        if value:
            yield value
def iter_jsonl_rows(file_path):
    """Yield dict rows from a JSONL file, skipping blank and malformed lines.

    Non-dict JSON values (arrays, scalars) are silently ignored.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        for line_number, raw_line in enumerate(handle, start=1):
            stripped = raw_line.strip()
            if not stripped:
                continue

            try:
                parsed = json.loads(stripped)
            except json.JSONDecodeError as exc:
                print(f" -> Skipping malformed JSON line {line_number} in {file_path}: {exc}")
                continue

            if isinstance(parsed, dict):
                yield parsed
def iter_json_texts(file_path):
    """Stream text values from a JSON file via the datasets streaming loader."""
    stream = load_dataset("json", data_files=file_path, split="train", streaming=True)
    yield from iter_text_from_rows(stream, os.path.basename(file_path))
def iter_parquet_texts(file_path):
    """Stream trimmed text values from a Parquet file, reading one column only."""
    parquet_file = pq.ParquetFile(file_path)

    # Detect the text column from the schema so we only read that single column,
    # avoiding materialising every column into Python memory. Only columns whose
    # Arrow type is string/large_string are offered to get_text_column, so its
    # fallback branch cannot accidentally pick a numeric column.
    candidate_columns = {
        field.name: ""
        for field in parquet_file.schema_arrow
        if pa.types.is_string(field.type) or pa.types.is_large_string(field.type)
    }
    text_col = get_text_column(candidate_columns)
    if not text_col:
        print(f" -> Skipping {os.path.basename(file_path)} (No text column found)")
        return

    for record_batch in parquet_file.iter_batches(batch_size=PARQUET_BATCH_SIZE, columns=[text_col]):
        for raw_text in record_batch.column(text_col).to_pylist():
            trimmed = maybe_trim_text(raw_text)
            if trimmed:
                yield trimmed
def iter_file_texts(file_path):
    """Dispatch to the reader matching *file_path*'s suffix and yield its texts."""
    print(f"Streaming data from {file_path}...")

    if file_path.endswith(".jsonl"):
        base_name = os.path.basename(file_path)
        yield from iter_text_from_rows(iter_jsonl_rows(file_path), base_name)
    elif file_path.endswith(".parquet"):
        yield from iter_parquet_texts(file_path)
    else:
        # Plain .json files go through the datasets streaming loader.
        yield from iter_json_texts(file_path)
def batch_iterator(max_batch_examples=MAX_BATCH_EXAMPLES, max_batch_characters=MAX_BATCH_CHARACTERS):
    """
    Streams local data files with bounded buffering so RAM usage stays stable.
    Stops after MAX_CORPUS_BYTES total text has been fed to the trainer so
    the Rust BPE word-frequency map does not exhaust system memory.
    """
    print(f"Scanning {DATA_DIR} for datasets...")
    corpus_limit = MAX_CORPUS_BYTES
    print(f"RAM budget: {MAX_RAM_GB:.0f} GB -> corpus cap: "
          f"{corpus_limit / 1_073_741_824:.1f} GB "
          f"(set TOKENIZER_MAX_RAM_GB to change)")

    pending = []
    pending_characters = 0
    consumed_bytes = 0

    for file_path in iter_data_files():
        for text in iter_file_texts(file_path):
            consumed_bytes += len(text.encode("utf-8", errors="replace"))
            pending.append(text)
            pending_characters += len(text)

            # Flush when either batch bound is hit.
            if len(pending) >= max_batch_examples or pending_characters >= max_batch_characters:
                yield pending
                pending = []
                pending_characters = 0

            if consumed_bytes >= corpus_limit:
                if pending:
                    yield pending
                print(f"Reached corpus cap ({consumed_bytes / 1_073_741_824:.2f} GB). "
                      f"Stopping data feed.")
                return

    if pending:
        yield pending
    print(f"All data consumed ({consumed_bytes / 1_073_741_824:.2f} GB).")
278# ==========================================
279# 3. Two-Pass Training
280# ==========================================
281def _prune_counter(word_counts, target_size):
282 """Remove low-frequency words until the counter is at or below target_size."""
283 min_count = 2
284 while len(word_counts) > target_size:
285 word_counts = collections.Counter(
286 {w: c for w, c in word_counts.items() if c >= min_count}
287 )
288 min_count += 1
289 return word_counts, min_count - 1
def _iter_training_pieces(text):
    """Yield (kind, piece) pairs used for pass-1 counting and pass-2 filtering.

    We intentionally split more finely than the template pre-tokenizer because
    structured strings like JSON blobs can otherwise appear as one giant unique
    token, which both explodes memory and causes pass 2 to drop important
    punctuation wholesale.
    """
    for match in TRAINING_PIECE_RE.finditer(text):
        piece = match.group(0)
        if piece.isspace():
            kind = "space"
        elif piece[0].isalnum() or piece[0] == "_":
            kind = "word"
        else:
            kind = "punct"
        yield kind, piece
def _normalize_text_for_training(text, allowed_words):
    """Keep punctuation and frequent words while forcing smaller token units.

    The emitted text is intentionally normalized with separator spaces around
    punctuation and kept words. This prevents the downstream pre-tokenizer from
    re-collapsing whole JSON/code snippets into single giant unique tokens.
    """
    out = []
    pending_separator = False

    for kind, piece in _iter_training_pieces(text):
        if kind == "space":
            # Collapse each whitespace run to one char, preferring newline.
            if not out or out[-1] != "\n":
                out.append("\n" if "\n" in piece else " ")
            pending_separator = False
        elif kind == "word" and piece not in allowed_words:
            # Rare word: replace with a single space, never doubling spaces.
            if out and out[-1] != " ":
                out.append(" ")
            pending_separator = False
        else:
            # Kept word or punctuation; separate from the previous kept piece.
            if pending_separator and out and not out[-1].isspace():
                out.append(" ")
            out.append(piece)
            pending_separator = True

    normalized = "".join(out).strip()
    return normalized if normalized else None
def _count_word_frequencies(tokenizer):
    """Pass 1: stream all data, pre-tokenize, count word frequencies.

    Uses the template tokenizer's pre-tokenizer (the DeepSeek code-regex
    rules) so the word splits match exactly what the BPE trainer will see.

    To keep memory bounded, the counter is periodically pruned: when it
    exceeds 3× MAX_UNIQUE_WORDS, low-frequency words are evicted until
    it is back to 2× MAX_UNIQUE_WORDS. This means only genuinely
    frequent words survive, which is exactly what we want.
    """
    # NOTE(review): *tokenizer* is currently unused — splitting is done by
    # _iter_training_pieces; parameter kept for interface compatibility.
    word_counts = collections.Counter()
    total_bytes = 0
    corpus_limit = MAX_CORPUS_BYTES
    prune_trigger = MAX_UNIQUE_WORDS * 3
    prune_target = MAX_UNIQUE_WORDS * 2
    prune_count = 0

    print(f"Pass 1 — counting word frequencies (corpus cap: "
          f"{corpus_limit / 1_073_741_824:.1f} GB, "
          f"prune at {prune_trigger:,} unique words)...")

    hit_corpus_cap = False
    for file_path in iter_data_files():
        for text in iter_file_texts(file_path):
            total_bytes += len(text.encode("utf-8", errors="replace"))

            for kind, piece in _iter_training_pieces(text):
                if kind == "word":
                    word_counts[piece] += 1

            if len(word_counts) > prune_trigger:
                word_counts, threshold = _prune_counter(word_counts, prune_target)
                prune_count += 1
                print(f" Pruned to {len(word_counts):,} words "
                      f"(dropped count<{threshold}, "
                      f"{total_bytes / 1_073_741_824:.2f} GB so far)")

            if total_bytes >= corpus_limit:
                print(f"Reached corpus cap ({total_bytes / 1_073_741_824:.2f} GB). "
                      f"Stopping pass 1.")
                hit_corpus_cap = True
                break
        if hit_corpus_cap:
            break

    print(f"Pass 1 done: {len(word_counts):,} unique words from "
          f"{total_bytes / 1_073_741_824:.2f} GB of text "
          f"({prune_count} prune cycles).")
    return word_counts
def _build_allowed_words(word_counts):
    """Select the top words by frequency, capped at MAX_UNIQUE_WORDS.

    First filters by MIN_FREQUENCY, then takes the most common words up
    to MAX_UNIQUE_WORDS. Returns a frozenset for O(1) lookup in pass 2.
    """
    frequent = frozenset(word for word, count in word_counts.items() if count >= MIN_FREQUENCY)
    dropped = len(word_counts) - len(frequent)
    print(f"After min_frequency={MIN_FREQUENCY} filter: {len(frequent):,} words "
          f"(dropped {dropped:,} rare words).")

    if len(frequent) <= MAX_UNIQUE_WORDS:
        print(f"Within MAX_UNIQUE_WORDS={MAX_UNIQUE_WORDS:,} limit.")
        return frequent

    # Keep only the most frequent words
    top_words = frozenset(
        word for word, _ in word_counts.most_common(MAX_UNIQUE_WORDS) if word in frequent
    )
    print(f"Trimmed to top {len(top_words):,} words by frequency.")
    return top_words
def _filtered_batch_iterator(tokenizer, allowed_words):
    """Pass 2: stream data again, but replace rare words with a space.

    The BPE trainer will only see the allowed words, keeping unique-word
    count bounded regardless of corpus size or diversity.
    """
    # NOTE(review): *tokenizer* is unused here; kept for interface symmetry
    # with _count_word_frequencies.
    total_bytes = 0
    corpus_limit = MAX_CORPUS_BYTES

    print(f"\nPass 2 — streaming filtered text to BPE trainer "
          f"({len(allowed_words):,} allowed words)...")

    pending = []
    pending_characters = 0

    for file_path in iter_data_files():
        for text in iter_file_texts(file_path):
            total_bytes += len(text.encode("utf-8", errors="replace"))

            filtered_text = _normalize_text_for_training(text, allowed_words)
            if not filtered_text:
                continue

            pending.append(filtered_text)
            pending_characters += len(filtered_text)

            # Flush when either batch bound is hit.
            if len(pending) >= MAX_BATCH_EXAMPLES or pending_characters >= MAX_BATCH_CHARACTERS:
                yield pending
                pending = []
                pending_characters = 0

            if total_bytes >= corpus_limit:
                if pending:
                    yield pending
                print(f"Reached corpus cap ({total_bytes / 1_073_741_824:.2f} GB). "
                      f"Stopping pass 2.")
                return

    if pending:
        yield pending
    print(f"Pass 2 done ({total_bytes / 1_073_741_824:.2f} GB).")
def _run_sentencepiece_backend():
    """Re-exec this process as the SentencePiece trainer script.

    Builds the argv for train_tokenizer_spm.py from the module configuration
    plus optional TOKENIZER_SPM_* env overrides, then replaces the current
    process via os.execv (this function does not return on success).
    """
    script_path = os.path.join(os.path.dirname(__file__), "train_tokenizer_spm.py")
    command = [
        sys.executable,
        script_path,
        "--profile", PROFILE,
        "--vocab-size", str(VOCAB_SIZE),
        "--output-dir", OUTPUT_DIR,
        "--model-type", os.getenv("TOKENIZER_SPM_MODEL_TYPE", "unigram"),
    ]

    def _flag_enabled(value):
        # Accepts the usual "on" spellings for boolean env flags.
        return bool(value) and value.strip().lower() in {"1", "true", "yes", "on"}

    input_sentence_size = os.getenv("TOKENIZER_SPM_INPUT_SENTENCE_SIZE")
    if input_sentence_size:
        command += ["--input-sentence-size", input_sentence_size]

    max_sentence_length = os.getenv("TOKENIZER_SPM_MAX_SENTENCE_LENGTH")
    if max_sentence_length:
        command += ["--max-sentence-length", max_sentence_length]

    if _flag_enabled(os.getenv("TOKENIZER_SPM_CODE_FIDELITY")):
        command.append("--code-fidelity-mode")

    if _flag_enabled(os.getenv("TOKENIZER_SPM_DETERMINISTIC")):
        command.append("--deterministic")

    print(
        "Tokenizer backend: SentencePiece "
        f"(selected via TOKENIZER_BACKEND={os.getenv('TOKENIZER_BACKEND', 'auto')})."
    )
    print("Delegating to train_tokenizer_spm.py for low-RAM-safe training...")
    # execv replaces the current process image; command[0] becomes argv[0].
    os.execv(sys.executable, command)
def _train_hf_backend():
    """Train a fresh vocabulary with the HuggingFace tokenizers backend.

    Two-pass flow: count word frequencies with bounded memory, build the
    allowed-word set, train BPE on the filtered stream, then save the
    tokenizer to OUTPUT_DIR.
    """
    print(f"Loading template tokenizer ({TEMPLATE_TOKENIZER})...")
    template_tokenizer = AutoTokenizer.from_pretrained(TEMPLATE_TOKENIZER)

    # Pass 1: count word frequencies (streaming, bounded memory)
    word_counts = _count_word_frequencies(template_tokenizer)

    # Build allowed word set
    allowed_words = _build_allowed_words(word_counts)

    # Free the counter before pass 2
    del word_counts

    print(
        f"\nTraining new vocabulary of size {VOCAB_SIZE}.\n"
        f"Profile: {PROFILE} | "
        f"Backend: hf | "
        f"RAM budget: {MAX_RAM_GB:.0f} GB | "
        f"max unique words: {MAX_UNIQUE_WORDS:,} | "
        f"min_frequency: {MIN_FREQUENCY}"
    )

    # Pass 2: train tokenizer on filtered text
    trained_tokenizer = template_tokenizer.train_new_from_iterator(
        text_iterator=_filtered_batch_iterator(template_tokenizer, allowed_words),
        vocab_size=VOCAB_SIZE,
        new_special_tokens=SPECIAL_TOKENS,
        min_frequency=MIN_FREQUENCY,
    )

    # Set the pad and eos tokens
    trained_tokenizer.pad_token = "<|pad|>"
    trained_tokenizer.eos_token = "<|eos|>"

    print(f"Saving custom tokenizer to ./{OUTPUT_DIR} ...")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    trained_tokenizer.save_pretrained(OUTPUT_DIR)

    print("Done! Your model's vocabulary is now perfectly mathematically tuned to your data.")
def main():
    """Entry point: dispatch to whichever tokenizer backend was auto-selected."""
    if BACKEND != "spm":
        print(
            "Tokenizer backend: HuggingFace BPE "
            f"(selected via TOKENIZER_BACKEND={os.getenv('TOKENIZER_BACKEND', 'auto')})."
        )
        _train_hf_backend()
    else:
        # _run_sentencepiece_backend execs a new process and does not return.
        _run_sentencepiece_backend()
# Allow importing this module (e.g. for its helpers) without starting training.
if __name__ == "__main__":
    main()