Coverage for train_tokenizer.py: 73%

279 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

1# Copyright 2026 venim1103 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import collections 

16import json 

17import os 

18import re 

19import sys 

20 

21import pyarrow as pa 

22import pyarrow.parquet as pq 

23from datasets import load_dataset 

24from transformers import AutoTokenizer 

25 

# ==========================================
# 1. Tokenizer Configuration
# ==========================================
# We use DeepSeek-Coder as the "template" because it contains highly optimized
# Regular Expressions (Regex) for splitting code, indentations, and whitespace properly.
# We are NOT keeping its vocabulary; we are just borrowing its splitting rules.
TEMPLATE_TOKENIZER = "deepseek-ai/deepseek-coder-1.3b-base"

# The "Goldilocks" size for a 500M parameter model
VOCAB_SIZE = 64_000

# Root directory scanned recursively for .jsonl/.json/.parquet training files.
DATA_DIR = "local_data/train"
# Directory the trained tokenizer is saved into (created if missing).
OUTPUT_DIR = "custom_agentic_tokenizer"

39def _resolve_profile(): 

40 """Select tokenizer training profile from env or runtime platform.""" 

41 explicit = os.getenv("TOKENIZER_PROFILE", "").strip().lower() 

42 if explicit in {"kaggle", "standard"}: 

43 return explicit 

44 

45 if os.getenv("KAGGLE_KERNEL_RUN_TYPE"): 

46 return "kaggle" 

47 

48 return "standard" 

49 

50 

# Resolved once at import time; every profile-dependent default below keys off it.
PROFILE = _resolve_profile()
# Built-in defaults per profile. Values are kept as strings so they can be
# merged uniformly with os.getenv overrides (which always return strings).
PROFILE_DEFAULTS = {
    "standard": {
        "max_batch_examples": "128",
        "max_batch_characters": "2000000",
        "parquet_batch_size": "256",
        "max_text_characters": "2000",
        "min_frequency": "3",
        "max_unique_words": "200000",
        "max_ram_gb": "30",
    },
    "kaggle": {
        "max_batch_examples": "96",
        "max_batch_characters": "1200000",
        "parquet_batch_size": "192",
        "max_text_characters": "1600",
        "min_frequency": "3",
        "max_unique_words": "200000",
        "max_ram_gb": "13",
    },
}


def _profile_default(name):
    """Return the built-in default (as a string) for the active profile."""
    return PROFILE_DEFAULTS[PROFILE][name]


# Each knob can be overridden via the corresponding TOKENIZER_* env variable.
MAX_BATCH_EXAMPLES = int(os.getenv("TOKENIZER_MAX_BATCH_EXAMPLES", _profile_default("max_batch_examples")))
MAX_BATCH_CHARACTERS = int(os.getenv("TOKENIZER_MAX_BATCH_CHARACTERS", _profile_default("max_batch_characters")))
PARQUET_BATCH_SIZE = int(os.getenv("TOKENIZER_PARQUET_BATCH_SIZE", _profile_default("parquet_batch_size")))
MAX_TEXT_CHARACTERS = int(os.getenv("TOKENIZER_MAX_TEXT_CHARACTERS", _profile_default("max_text_characters")))
MIN_FREQUENCY = int(os.getenv("TOKENIZER_MIN_FREQUENCY", _profile_default("min_frequency")))
# Target unique word count that fits in RAM. Each unique word costs ~15-20 KB
# in the Rust BPE trainer (word string + character-level tokenisation + pair
# counts + priority queue entries). Both built-in profiles use conservative
# defaults because the Rust BPE merge phase can still OOM below 64 GB RAM.
MAX_UNIQUE_WORDS = int(os.getenv("TOKENIZER_MAX_UNIQUE_WORDS", _profile_default("max_unique_words")))

88 

89def _corpus_bytes_for_ram(ram_gb): 

90 """Return a generous corpus byte cap as a first-pass limit. 

91 

92 This is a coarse outer bound; the real memory control is 

93 MAX_UNIQUE_WORDS which is enforced by the two-pass approach. 

94 """ 

95 usable = max(ram_gb - 4, 0.5) 

96 return int(usable * 1_073_741_824) # 1x — generous, since pass-2 filters 

97 

# RAM budget in GB (env-overridable); drives the first-pass corpus byte cap.
MAX_RAM_GB = float(os.getenv("TOKENIZER_MAX_RAM_GB", _profile_default("max_ram_gb")))
# Hard ceiling on total text bytes streamed to the trainer per pass.
MAX_CORPUS_BYTES = _corpus_bytes_for_ram(MAX_RAM_GB)

100 

101 

102def _resolve_backend(profile, max_ram_gb): 

103 """Choose tokenizer backend. 

104 

105 - TOKENIZER_BACKEND=hf|spm forces a backend. 

106 - TOKENIZER_BACKEND=auto (or unset) auto-selects SentencePiece on lower-RAM 

107 setups where HF Rust BPE commonly OOMs. 

108 """ 

109 explicit = os.getenv("TOKENIZER_BACKEND", "auto").strip().lower() 

110 if explicit in {"hf", "spm"}: 

111 return explicit 

112 

113 if explicit not in {"", "auto"}: 113 ↛ 114line 113 didn't jump to line 114 because the condition on line 113 was never true

114 print(f"Unknown TOKENIZER_BACKEND={explicit!r}; falling back to auto.") 

115 

116 if profile == "kaggle" or max_ram_gb < 64: 

117 return "spm" 

118 return "hf" 

119 

120 

# Chosen once at import time from profile + RAM budget (env-overridable).
BACKEND = _resolve_backend(PROFILE, MAX_RAM_GB)

# These must be preserved as single tokens so the model can reason and format perfectly!
SPECIAL_TOKENS = [
    "<|im_start|>",
    "<|im_end|>",
    "<think>",
    "</think>",
    "<|eos|>",
    "<|pad|>"
]

# File extensions picked up when scanning DATA_DIR.
SUPPORTED_SUFFIXES = (".jsonl", ".json", ".parquet")
# Splits text into whitespace runs, word-character runs, or single punctuation marks.
TRAINING_PIECE_RE = re.compile(r"\s+|\w+|[^\w\s]", re.UNICODE)

135 

136# ========================================== 

137# 2. RAM-Friendly Data Iterator 

138# ========================================== 

def get_text_column(sample_row):
    """Return the name of the column that holds text, or None if there is none.

    Well-known column names are preferred; otherwise the first string-valued
    column (in dict order) wins.
    """
    preferred = ('text', 'content', 'trajectory', 'prompt', 'response')
    for candidate in preferred:
        if isinstance(sample_row.get(candidate), str):
            return candidate
    return next(
        (column for column, value in sample_row.items() if isinstance(value, str)),
        None,
    )

148 

def maybe_trim_text(text):
    """Strip and length-cap a candidate text; return None for non-strings/blanks."""
    if not isinstance(text, str):
        return None

    stripped = text.strip()
    if not stripped:
        return None

    # A cap of 0 (or negative) disables truncation entirely.
    return stripped[:MAX_TEXT_CHARACTERS] if MAX_TEXT_CHARACTERS > 0 else stripped

161 

def iter_data_files():
    """Yield supported dataset file paths under DATA_DIR in a stable order."""
    for root, dirs, files in os.walk(DATA_DIR):
        # In-place sort makes os.walk descend subdirectories deterministically.
        dirs.sort()
        for candidate in sorted(files):
            if not candidate.endswith(SUPPORTED_SUFFIXES):
                continue
            yield os.path.join(root, candidate)

168 

def iter_text_from_rows(rows, source_name):
    """Yield trimmed text values from dict rows, detecting the text column lazily.

    The column is chosen from the first row seen; if that row exposes no
    string column, the whole source is skipped with a console note.
    """
    column = None
    for row in rows:
        if column is None:
            # Detect once, on the first row.
            column = get_text_column(row)
            if not column:
                print(f" -> Skipping {source_name} (No text column found)")
                return

        cleaned = maybe_trim_text(row.get(column))
        if cleaned:
            yield cleaned

182 

def iter_jsonl_rows(file_path):
    """Yield dict rows from a JSON-Lines file, skipping blank/malformed lines."""
    with open(file_path, "r", encoding="utf-8") as handle:
        for line_number, raw in enumerate(handle, start=1):
            candidate = raw.strip()
            if not candidate:
                continue

            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError as exc:
                print(f" -> Skipping malformed JSON line {line_number} in {file_path}: {exc}")
                continue

            # Non-dict JSON values (lists, scalars) are silently ignored.
            if isinstance(parsed, dict):
                yield parsed

198 

def iter_json_texts(file_path):
    """Stream text values from a JSON file via `datasets` without materialising it.

    streaming=True keeps memory bounded even for very large files.
    """
    dataset = load_dataset("json", data_files=file_path, split="train", streaming=True)
    yield from iter_text_from_rows(dataset, os.path.basename(file_path))

202 

def iter_parquet_texts(file_path):
    """Stream trimmed text values from a Parquet file, reading one column only."""
    parquet_file = pq.ParquetFile(file_path)

    # Detect the text column from the schema so we only read that single column,
    # avoiding materialising every column into Python memory.
    # Only consider columns whose Arrow type is string/large_string so the
    # fallback branch of get_text_column cannot accidentally pick an int column.
    string_col_names = [
        field.name for field in parquet_file.schema_arrow
        if pa.types.is_string(field.type) or pa.types.is_large_string(field.type)
    ]
    # Synthetic row of empty strings lets get_text_column apply its usual
    # preference order to the schema's string columns.
    sample_row = {name: "" for name in string_col_names}
    text_col = get_text_column(sample_row)
    if not text_col:
        print(f" -> Skipping {os.path.basename(file_path)} (No text column found)")
        return

    # Batched reads keep peak memory proportional to PARQUET_BATCH_SIZE rows.
    for batch in parquet_file.iter_batches(batch_size=PARQUET_BATCH_SIZE, columns=[text_col]):
        for text in batch.column(text_col).to_pylist():
            text = maybe_trim_text(text)
            if text:
                yield text

225 

def iter_file_texts(file_path):
    """Dispatch to the right streaming text reader based on the file extension."""
    print(f"Streaming data from {file_path}...")

    if file_path.endswith(".jsonl"):
        rows = iter_jsonl_rows(file_path)
        yield from iter_text_from_rows(rows, os.path.basename(file_path))
    elif file_path.endswith(".parquet"):
        yield from iter_parquet_texts(file_path)
    else:
        # Anything else that passed the suffix filter is plain .json.
        yield from iter_json_texts(file_path)

238 

def batch_iterator(max_batch_examples=MAX_BATCH_EXAMPLES, max_batch_characters=MAX_BATCH_CHARACTERS):
    """
    Streams local data files with bounded buffering so RAM usage stays stable.
    Stops after MAX_CORPUS_BYTES total text has been fed to the trainer so
    the Rust BPE word-frequency map does not exhaust system memory.

    Yields:
        list[str]: batches of raw text, flushed whenever either the example
        count or the character count threshold is reached.
    """
    print(f"Scanning {DATA_DIR} for datasets...")
    corpus_limit = MAX_CORPUS_BYTES
    print(f"RAM budget: {MAX_RAM_GB:.0f} GB -> corpus cap: "
          f"{corpus_limit / 1_073_741_824:.1f} GB "
          f"(set TOKENIZER_MAX_RAM_GB to change)")

    batch = []
    batch_characters = 0
    total_bytes = 0  # UTF-8 bytes of all text seen so far (corpus-cap accounting)

    for file_path in iter_data_files():
        for text in iter_file_texts(file_path):
            text_bytes = len(text.encode("utf-8", errors="replace"))
            batch.append(text)
            batch_characters += len(text)
            total_bytes += text_bytes

            # Flush on either bound so one huge document cannot balloon the buffer.
            if len(batch) >= max_batch_examples or batch_characters >= max_batch_characters:
                yield batch
                batch = []
                batch_characters = 0

            # Cap check runs per text so we stop promptly mid-file.
            if total_bytes >= corpus_limit:
                if batch:
                    yield batch
                print(f"Reached corpus cap ({total_bytes / 1_073_741_824:.2f} GB). "
                      f"Stopping data feed.")
                return

    # Flush whatever is left once every file has been consumed.
    if batch:
        yield batch
    print(f"All data consumed ({total_bytes / 1_073_741_824:.2f} GB).")

277 

278# ========================================== 

279# 3. Two-Pass Training 

280# ========================================== 

281def _prune_counter(word_counts, target_size): 

282 """Remove low-frequency words until the counter is at or below target_size.""" 

283 min_count = 2 

284 while len(word_counts) > target_size: 

285 word_counts = collections.Counter( 

286 {w: c for w, c in word_counts.items() if c >= min_count} 

287 ) 

288 min_count += 1 

289 return word_counts, min_count - 1 

290 

291 

292def _iter_training_pieces(text): 

293 """Yield normalized pieces used for pass-1 counting and pass-2 filtering. 

294 

295 We intentionally split more finely than the template pre-tokenizer because 

296 structured strings like JSON blobs can otherwise appear as one giant unique 

297 token, which both explodes memory and causes pass 2 to drop important 

298 punctuation wholesale. 

299 """ 

300 for match in TRAINING_PIECE_RE.finditer(text): 

301 piece = match.group(0) 

302 if piece.isspace(): 

303 yield "space", piece 

304 elif piece[0].isalnum() or piece[0] == "_": 

305 yield "word", piece 

306 else: 

307 yield "punct", piece 

308 

309 

def _normalize_text_for_training(text, allowed_words):
    """Keep punctuation and frequent words while forcing smaller token units.

    The emitted text is intentionally normalized with separator spaces around
    punctuation and kept words. This prevents the downstream pre-tokenizer from
    re-collapsing whole JSON/code snippets into single giant unique tokens.

    Returns the normalized string, or None if nothing survives filtering.
    """
    pieces = []
    needs_separator = False

    for kind, piece in _iter_training_pieces(text):
        if kind == "space":
            # Collapse a whitespace run to one newline (if it contained one)
            # or one space; avoid stacking right after an emitted newline.
            if not pieces or pieces[-1] != "\n":
                pieces.append("\n" if "\n" in piece else " ")
            needs_separator = False
            continue

        if kind == "word" and piece not in allowed_words:
            # Rare word: replace it with a single space instead of emitting it.
            if pieces and pieces[-1] != " ":
                pieces.append(" ")
            needs_separator = False
            continue

        # Separate adjacent kept pieces so BPE cannot merge them back into
        # one giant token.
        if needs_separator and pieces and not pieces[-1].isspace():
            pieces.append(" ")

        pieces.append(piece)
        needs_separator = True

    normalized = "".join(pieces).strip()
    return normalized or None

341 

342 

def _count_word_frequencies(tokenizer):
    """Pass 1: stream all data, split into pieces, count word frequencies.

    NOTE(review): splitting is done by _iter_training_pieces (this module's
    own regex splitter), not by the template tokenizer's pre-tokenizer; the
    `tokenizer` argument is currently unused and kept for interface stability.

    To keep memory bounded, the counter is periodically pruned: when it
    exceeds 3× MAX_UNIQUE_WORDS, low-frequency words are evicted until
    it is back to 2× MAX_UNIQUE_WORDS. This means only genuinely
    frequent words survive, which is exactly what we want.
    """
    word_counts = collections.Counter()
    total_bytes = 0
    corpus_limit = MAX_CORPUS_BYTES
    prune_trigger = MAX_UNIQUE_WORDS * 3
    prune_target = MAX_UNIQUE_WORDS * 2
    prune_count = 0

    print(f"Pass 1 — counting word frequencies (corpus cap: "
          f"{corpus_limit / 1_073_741_824:.1f} GB, "
          f"prune at {prune_trigger:,} unique words)...")

    for file_path in iter_data_files():
        for text in iter_file_texts(file_path):
            text_bytes = len(text.encode("utf-8", errors="replace"))
            total_bytes += text_bytes

            for kind, piece in _iter_training_pieces(text):
                if kind == "word":
                    word_counts[piece] += 1

            if len(word_counts) > prune_trigger:
                word_counts, threshold = _prune_counter(word_counts, prune_target)
                prune_count += 1
                print(f" Pruned to {len(word_counts):,} words "
                      f"(dropped count<{threshold}, "
                      f"{total_bytes / 1_073_741_824:.2f} GB so far)")

            if total_bytes >= corpus_limit:
                print(f"Reached corpus cap ({total_bytes / 1_073_741_824:.2f} GB). "
                      f"Stopping pass 1.")
                break
        else:
            # Inner loop finished without hitting the cap: move to the next file.
            continue
        # Inner loop broke on the corpus cap: stop scanning files too.
        break

    print(f"Pass 1 done: {len(word_counts):,} unique words from "
          f"{total_bytes / 1_073_741_824:.2f} GB of text "
          f"({prune_count} prune cycles).")
    return word_counts

393 

394 

def _build_allowed_words(word_counts):
    """Select the top words by frequency, capped at MAX_UNIQUE_WORDS.

    First filters by MIN_FREQUENCY, then takes the most common words up
    to MAX_UNIQUE_WORDS. Returns a frozenset for O(1) lookup in pass 2.
    """
    frequent = {word for word, count in word_counts.items() if count >= MIN_FREQUENCY}
    dropped = len(word_counts) - len(frequent)
    print(f"After min_frequency={MIN_FREQUENCY} filter: {len(frequent):,} words "
          f"(dropped {dropped:,} rare words).")

    if len(frequent) <= MAX_UNIQUE_WORDS:
        print(f"Within MAX_UNIQUE_WORDS={MAX_UNIQUE_WORDS:,} limit.")
        return frozenset(frequent)

    # Over the cap: keep only the highest-frequency survivors.
    top_words = {
        word
        for word, _count in word_counts.most_common(MAX_UNIQUE_WORDS)
        if word in frequent
    }
    print(f"Trimmed to top {len(top_words):,} words by frequency.")
    return frozenset(top_words)

414 

def _filtered_batch_iterator(tokenizer, allowed_words):
    """Pass 2: stream data again, but replace rare words with a space.

    The BPE trainer will only see the allowed words, keeping unique-word
    count bounded regardless of corpus size or diversity.

    NOTE(review): the `tokenizer` argument is currently unused; filtering
    relies on _normalize_text_for_training. Kept for interface stability.
    """
    total_bytes = 0
    corpus_limit = MAX_CORPUS_BYTES

    print(f"\nPass 2 — streaming filtered text to BPE trainer "
          f"({len(allowed_words):,} allowed words)...")

    batch = []
    batch_characters = 0

    for file_path in iter_data_files():
        for text in iter_file_texts(file_path):
            # Cap accounting uses the ORIGINAL text size so pass 2 reads the
            # same slice of the corpus that pass 1 counted.
            text_bytes = len(text.encode("utf-8", errors="replace"))
            total_bytes += text_bytes

            filtered_text = _normalize_text_for_training(text, allowed_words)
            if not filtered_text:
                continue

            batch.append(filtered_text)
            batch_characters += len(filtered_text)

            # Flush on either bound to keep the buffered batch small.
            if len(batch) >= MAX_BATCH_EXAMPLES or batch_characters >= MAX_BATCH_CHARACTERS:
                yield batch
                batch = []
                batch_characters = 0

            if total_bytes >= corpus_limit:
                if batch:
                    yield batch
                print(f"Reached corpus cap ({total_bytes / 1_073_741_824:.2f} GB). "
                      f"Stopping pass 2.")
                return

    # Flush the final partial batch after all files are consumed.
    if batch:
        yield batch
    print(f"Pass 2 done ({total_bytes / 1_073_741_824:.2f} GB).")

457 

458 

def _run_sentencepiece_backend():
    """Replace the current process with the SentencePiece training script.

    Forwards profile/vocab/output settings plus any TOKENIZER_SPM_* env
    overrides to train_tokenizer_spm.py, then execs it — this function
    does not return.
    """
    def env_flag_enabled(variable):
        value = os.getenv(variable)
        return bool(value) and value.strip().lower() in {"1", "true", "yes", "on"}

    script_path = os.path.join(os.path.dirname(__file__), "train_tokenizer_spm.py")
    command = [
        sys.executable,
        script_path,
        "--profile", PROFILE,
        "--vocab-size", str(VOCAB_SIZE),
        "--output-dir", OUTPUT_DIR,
        "--model-type", os.getenv("TOKENIZER_SPM_MODEL_TYPE", "unigram"),
    ]

    # Optional numeric overrides forwarded verbatim when set.
    for option, variable in (
        ("--input-sentence-size", "TOKENIZER_SPM_INPUT_SENTENCE_SIZE"),
        ("--max-sentence-length", "TOKENIZER_SPM_MAX_SENTENCE_LENGTH"),
    ):
        override = os.getenv(variable)
        if override:
            command.extend([option, override])

    if env_flag_enabled("TOKENIZER_SPM_CODE_FIDELITY"):
        command.append("--code-fidelity-mode")
    if env_flag_enabled("TOKENIZER_SPM_DETERMINISTIC"):
        command.append("--deterministic")

    print(
        "Tokenizer backend: SentencePiece "
        f"(selected via TOKENIZER_BACKEND={os.getenv('TOKENIZER_BACKEND', 'auto')})."
    )
    print("Delegating to train_tokenizer_spm.py for low-RAM-safe training...")
    # execv never returns: this Python process becomes the SPM trainer.
    os.execv(sys.executable, command)

496 

497 

def _train_hf_backend():
    """Train a new BPE tokenizer with the HuggingFace backend and save it.

    Two-pass flow: (1) count word frequencies with bounded memory,
    (2) train the template tokenizer's BPE on the frequency-filtered text
    stream, then write the result to OUTPUT_DIR.
    """
    print(f"Loading template tokenizer ({TEMPLATE_TOKENIZER})...")
    old_tokenizer = AutoTokenizer.from_pretrained(TEMPLATE_TOKENIZER)

    # Pass 1: count word frequencies (streaming, bounded memory)
    word_counts = _count_word_frequencies(old_tokenizer)

    # Build allowed word set
    allowed_words = _build_allowed_words(word_counts)

    # Free the counter before pass 2 to lower peak RSS.
    del word_counts

    print(
        f"\nTraining new vocabulary of size {VOCAB_SIZE}.\n"
        f"Profile: {PROFILE} | "
        f"Backend: hf | "
        f"RAM budget: {MAX_RAM_GB:.0f} GB | "
        f"max unique words: {MAX_UNIQUE_WORDS:,} | "
        f"min_frequency: {MIN_FREQUENCY}"
    )

    # Pass 2: train tokenizer on filtered text
    new_tokenizer = old_tokenizer.train_new_from_iterator(
        text_iterator=_filtered_batch_iterator(old_tokenizer, allowed_words),
        vocab_size=VOCAB_SIZE,
        new_special_tokens=SPECIAL_TOKENS,
        min_frequency=MIN_FREQUENCY,
    )

    # Set the pad and eos tokens so downstream training code can rely on them.
    new_tokenizer.pad_token = "<|pad|>"
    new_tokenizer.eos_token = "<|eos|>"

    print(f"Saving custom tokenizer to ./{OUTPUT_DIR} ...")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    new_tokenizer.save_pretrained(OUTPUT_DIR)

    print("Done! Your model's vocabulary is now perfectly mathematically tuned to your data.")

537 

538 

def main():
    """Entry point: dispatch to whichever backend was resolved at import time."""
    if BACKEND == "spm":
        # Does not return on success: the process is replaced by the SPM trainer.
        _run_sentencepiece_backend()
        return

    backend_note = (
        "Tokenizer backend: HuggingFace BPE "
        f"(selected via TOKENIZER_BACKEND={os.getenv('TOKENIZER_BACKEND', 'auto')})."
    )
    print(backend_note)
    _train_hf_backend()


if __name__ == "__main__":
    main()

552