Coverage for compare_tokenizers.py: 93%
156 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 23:06 +0000
1# Copyright 2026 venim1103
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""Tokenizer comparison and manual regression-gate utility.

Quick start examples:

1) Basic comparison against local custom tokenizer:
    /opt/conda/envs/ai/bin/python compare_tokenizers.py \
        --custom-tokenizer custom_agentic_tokenizer

2) Save a JSON report for later inspection:
    /opt/conda/envs/ai/bin/python compare_tokenizers.py \
        --custom-tokenizer custom_agentic_tokenizer \
        --report-json /tmp/tokenizer_report.json

3) Enable regression gates (exit code 1 on failure):
    /opt/conda/envs/ai/bin/python compare_tokenizers.py \
        --custom-tokenizer custom_agentic_tokenizer \
        --max-average-ratio 1.10 \
        --max-sample-ratio 1.40 \
        --require-roundtrip

4) Exclude one or more samples during manual checks:
    /opt/conda/envs/ai/bin/python compare_tokenizers.py \
        --custom-tokenizer custom_agentic_tokenizer \
        --exclude-sample python_code
    /opt/conda/envs/ai/bin/python compare_tokenizers.py \
        --custom-tokenizer custom_agentic_tokenizer \
        --exclude-sample python_code,tool_calling

5) Use a custom sample set:
    /opt/conda/envs/ai/bin/python compare_tokenizers.py \
        --custom-tokenizer custom_agentic_tokenizer \
        --samples-file tests/tokenizer_samples.json

6) Generic diagnostics gates (UNK/byte markers/normalized roundtrip):
    /opt/conda/envs/ai/bin/python compare_tokenizers.py \
        --custom-tokenizer custom_agentic_tokenizer \
        --max-average-ratio 1.10 \
        --max-unk-ratio 0.01 \
        --max-byte-markers 0 \
        --require-normalized-roundtrip \
        --report-json /tmp/tokenizer_report_generic_checks.json

Notes:
- Ratios are custom/base token counts per sample.
- Lower ratio is better (fewer custom tokens than base).
- If all base token counts are zero, average ratio is unavailable.
- Normalized roundtrip compares text after collapsing whitespace.
- Byte markers refer to decoded literals like <0x0A>.
"""
66import argparse
67import json
68import re
69import sys
70from pathlib import Path
71from statistics import mean
73from transformers import AutoTokenizer
# Default model IDs / local paths used when the corresponding CLI flags are omitted.
DEFAULT_BASE_TOKENIZER = "deepseek-ai/deepseek-coder-1.3b-base"
DEFAULT_CUSTOM_TOKENIZER = "custom_agentic_tokenizer"
# Matches byte-fallback literals such as "<0x0A>" left in decoded text.
BYTE_MARKER_RE = re.compile(r"<0x[0-9A-Fa-f]{2}>")
# Built-in representative samples used when --samples-file is not provided.
# They span code, math reasoning, tool-calling markup, web prose, and a mixed
# agentic transcript, so token-count ratios reflect the target domains.
DEFAULT_SAMPLES = [
    {
        "name": "python_code",
        "text": (
            "def binary_search(arr, target):\n"
            "    left, right = 0, len(arr) - 1\n"
            "    while left <= right:\n"
            "        mid = (left + right) // 2\n"
            "        if arr[mid] == target:\n"
            "            return mid\n"
            "        if arr[mid] < target:\n"
            "            left = mid + 1\n"
            "        else:\n"
            "            right = mid - 1\n"
            "    return -1"
        ),
    },
    {
        "name": "math_reasoning",
        "text": (
            "Solve the quadratic equation x^2 + 5x + 6 = 0. "
            "We factor it as (x + 2)(x + 3) = 0, so the solutions are x = -2 and x = -3."
        ),
    },
    {
        "name": "tool_calling",
        "text": (
            "<|im_start|>assistant\n"
            "<think>I should call the weather tool with the provided city and unit.</think>\n"
            '{"name":"get_weather","arguments":{"city":"Berlin","unit":"celsius"}}'
            "<|im_end|>"
        ),
    },
    {
        "name": "web_text",
        "text": (
            "A transformer language model predicts the next token in a sequence. "
            "Training quality depends on data curation, tokenizer coverage, and "
            "optimization stability across domains like code, math, and natural language."
        ),
    },
    {
        "name": "mixed_agentic",
        "text": (
            "User: write a Python function that parses JSON and retries on timeout. "
            "Assistant: <think>I need robust error handling and exponential backoff.</think> "
            "Sure, here is an implementation using try/except and time.sleep."
        ),
    },
]
def parse_args():
    """Build the CLI and return the parsed argument namespace.

    All gate-related options default to None/False, i.e. every regression
    gate is opt-in.
    """
    p = argparse.ArgumentParser(
        description="Compare token counts between two tokenizers on representative samples."
    )
    p.add_argument(
        "--base-tokenizer",
        default=DEFAULT_BASE_TOKENIZER,
        help="Base tokenizer model ID or local path.",
    )
    p.add_argument(
        "--custom-tokenizer",
        default=DEFAULT_CUSTOM_TOKENIZER,
        help="Custom tokenizer model ID or local path.",
    )
    p.add_argument(
        "--samples-file",
        help="Optional path to a JSON file containing a list of objects with "
        "'name' and 'text' fields.",
    )
    p.add_argument(
        "--show-ids",
        action="store_true",
        help="Print the first 20 token IDs for each tokenizer per sample.",
    )
    p.add_argument(
        "--max-average-ratio",
        type=float,
        help="Optional regression gate. Fail with exit code 1 if average "
        "custom/base ratio exceeds this value.",
    )
    p.add_argument(
        "--max-sample-ratio",
        type=float,
        help="Optional regression gate. Fail with exit code 1 if any sample "
        "custom/base ratio exceeds this value.",
    )
    p.add_argument(
        "--require-roundtrip",
        action="store_true",
        help="Optional regression gate. Fail if any sample does not roundtrip exactly.",
    )
    p.add_argument(
        "--require-normalized-roundtrip",
        action="store_true",
        help="Optional regression gate. Fail if any sample does not roundtrip after "
        "collapsing whitespace differences.",
    )
    p.add_argument(
        "--max-unk-ratio",
        type=float,
        help="Optional regression gate. Fail if custom unknown-token ratio for any "
        "sample exceeds this threshold.",
    )
    p.add_argument(
        "--max-byte-markers",
        type=int,
        help="Optional regression gate. Fail if decoded custom text contains more than "
        "this many <0x..> byte-fallback markers for any sample.",
    )
    p.add_argument(
        "--report-json",
        help="Optional output path for a JSON report with per-sample metrics.",
    )
    p.add_argument(
        "--exclude-sample",
        action="append",
        default=[],
        help="Exclude sample(s) by name. Can be repeated and/or passed as comma-separated "
        "values, e.g. --exclude-sample python_code --exclude-sample web_text,tool_calling.",
    )
    return p.parse_args()
def load_samples(samples_file):
    """Load comparison samples, falling back to DEFAULT_SAMPLES.

    Args:
        samples_file: Optional path to a JSON file holding a list of objects
            with string "name" and "text" fields. Falsy -> defaults.

    Returns:
        List of {"name": str, "text": str} dicts (extra keys are dropped).

    Raises:
        ValueError: When the file is not a JSON list or an entry is malformed.
    """
    if not samples_file:
        return DEFAULT_SAMPLES

    raw = json.loads(Path(samples_file).read_text(encoding="utf-8"))
    if not isinstance(raw, list):
        raise ValueError("Samples file must contain a JSON list.")

    cleaned = []
    for position, entry in enumerate(raw, start=1):
        if not isinstance(entry, dict):
            raise ValueError(f"Sample {position} is not an object.")
        name, text = entry.get("name"), entry.get("text")
        if not (isinstance(name, str) and isinstance(text, str)):
            raise ValueError(f"Sample {position} must have string 'name' and 'text' fields.")
        cleaned.append({"name": name, "text": text})
    return cleaned
def filter_samples(samples, excluded_names):
    """Drop samples whose name matches any exclusion.

    ``excluded_names`` may mix repeated flag values and comma-separated
    lists; surrounding whitespace is ignored and empty pieces are skipped.
    """
    if not excluded_names:
        return samples

    blocked = set()
    for raw in excluded_names:
        blocked.update(piece.strip() for piece in raw.split(",") if piece.strip())
    return [entry for entry in samples if entry["name"] not in blocked]
def format_ratio(base_count, custom_count):
    """Format custom/base token-count ratio for display.

    Returns "n/a" when the base produced no tokens; otherwise the ratio to
    three decimals plus a percentage phrased as fewer/more tokens.
    """
    if base_count == 0:
        return "n/a"

    ratio = custom_count / base_count
    direction = "fewer" if ratio <= 1 else "more"
    delta = abs(1 - ratio) * 100
    return f"{ratio:.3f} ({delta:.1f}% {direction} tokens)"
def normalize_whitespace(text):
    """Collapse every whitespace run to one space and trim both ends."""
    pieces = text.split()
    return " ".join(pieces)
def count_byte_markers(text):
    """Count byte-fallback markers (e.g. "<0x0A>") present in *text*."""
    return sum(1 for _ in BYTE_MARKER_RE.finditer(text))
def compare_tokenizers(base_tokenizer, custom_tokenizer, samples, show_ids=False):
    """Encode every sample with both tokenizers, print metrics, and build a report.

    Args:
        base_tokenizer: Reference tokenizer; must provide ``encode``, ``decode``,
            and ``name_or_path`` (e.g. a HuggingFace tokenizer).
        custom_tokenizer: Candidate tokenizer measured against the base; its
            ``unk_token_id`` attribute (if any) is used for the UNK metrics.
        samples: List of dicts with string "name" and "text" keys.
        show_ids: When True, also print the first 20 token IDs per tokenizer.

    Returns:
        Report dict with both tokenizer identifiers, "average_ratio" (None when
        no sample yielded base tokens), and a "samples" list of per-sample
        metric dicts.
    """
    ratios = []
    results = []

    print("Comparison: base tokenizer vs custom tokenizer")
    print(f"Base: {base_tokenizer.name_or_path}")
    print(f"Custom: {custom_tokenizer.name_or_path}")
    print()

    custom_unk_id = getattr(custom_tokenizer, "unk_token_id", None)

    for sample in samples:
        name = sample["name"]
        text = sample["text"]
        base_ids = base_tokenizer.encode(text, add_special_tokens=False)
        custom_ids = custom_tokenizer.encode(text, add_special_tokens=False)
        decoded = custom_tokenizer.decode(custom_ids, skip_special_tokens=False)
        roundtrip = decoded == text
        normalized_roundtrip = normalize_whitespace(decoded) == normalize_whitespace(text)

        custom_unk_count = 0
        if custom_unk_id is not None:
            custom_unk_count = custom_ids.count(custom_unk_id)
        custom_unk_ratio = (custom_unk_count / len(custom_ids)) if custom_ids else 0.0
        byte_marker_count = count_byte_markers(decoded)

        # Fix: compute the per-sample ratio exactly once (the original computed
        # len(custom_ids) / len(base_ids) twice in two separate `if base_ids`
        # branches). None means the base produced no tokens for this sample.
        ratio = len(custom_ids) / len(base_ids) if base_ids else None
        if ratio is not None:
            ratios.append(ratio)

        results.append(
            {
                "name": name,
                "chars": len(text),
                "base_tokens": len(base_ids),
                "custom_tokens": len(custom_ids),
                "ratio": ratio,
                "roundtrip_exact": roundtrip,
                "roundtrip_normalized": normalized_roundtrip,
                "custom_unk_count": custom_unk_count,
                "custom_unk_ratio": custom_unk_ratio,
                "byte_marker_count": byte_marker_count,
                "decoded_preview": decoded[:160],
            }
        )

        print(f"{name}:")
        print(f" chars: {len(text)}")
        print(f" base: {len(base_ids):4d} tokens")
        print(f" custom: {len(custom_ids):4d} tokens")
        print(f" ratio: {format_ratio(len(base_ids), len(custom_ids))}")
        print(f" roundtrip exact: {roundtrip}")
        print(f" roundtrip normalized: {normalized_roundtrip}")
        print(f" custom unk: {custom_unk_count} ({custom_unk_ratio:.3%})")
        print(f" decoded byte markers: {byte_marker_count}")
        if not roundtrip:
            print(f" decoded preview: {decoded[:160]!r}")
        if show_ids:
            print(f" base ids: {base_ids[:20]}")
            print(f" custom ids: {custom_ids[:20]}")
        print()

    average_ratio = mean(ratios) if ratios else None
    if average_ratio is not None:
        print(f"Average custom/base ratio: {average_ratio:.3f}")

    return {
        "base_tokenizer": base_tokenizer.name_or_path,
        "custom_tokenizer": custom_tokenizer.name_or_path,
        "average_ratio": average_ratio,
        "samples": results,
    }
def evaluate_regressions(
    report,
    max_average_ratio=None,
    max_sample_ratio=None,
    require_roundtrip=False,
    require_normalized_roundtrip=False,
    max_unk_ratio=None,
    max_byte_markers=None,
):
    """Apply the enabled regression gates to *report*.

    Gates are checked in a fixed order (average ratio, per-sample ratio, exact
    roundtrip, normalized roundtrip, UNK ratio, byte markers); every violation
    contributes one human-readable message. An empty list means all enabled
    gates passed (or none were enabled).
    """
    problems = []
    samples = report.get("samples", [])

    if max_average_ratio is not None:
        avg = report.get("average_ratio")
        if avg is None:
            problems.append("average ratio unavailable (all base token counts were zero)")
        elif avg > max_average_ratio:
            problems.append(
                f"average ratio {avg:.3f} exceeded max-average-ratio {max_average_ratio:.3f}"
            )

    if max_sample_ratio is not None:
        problems.extend(
            f"sample '{s['name']}' ratio {s['ratio']:.3f} exceeded "
            f"max-sample-ratio {max_sample_ratio:.3f}"
            for s in samples
            if s.get("ratio") is not None and s["ratio"] > max_sample_ratio
        )

    if require_roundtrip:
        problems.extend(
            f"sample '{s['name']}' failed exact roundtrip"
            for s in samples
            if not s.get("roundtrip_exact", False)
        )

    if require_normalized_roundtrip:
        problems.extend(
            f"sample '{s['name']}' failed normalized roundtrip"
            for s in samples
            if not s.get("roundtrip_normalized", False)
        )

    if max_unk_ratio is not None:
        problems.extend(
            f"sample '{s['name']}' custom unk ratio {s['custom_unk_ratio']:.3f} exceeded "
            f"max-unk-ratio {max_unk_ratio:.3f}"
            for s in samples
            if s.get("custom_unk_ratio") is not None and s["custom_unk_ratio"] > max_unk_ratio
        )

    if max_byte_markers is not None:
        problems.extend(
            f"sample '{s['name']}' byte marker count {s['byte_marker_count']} exceeded "
            f"max-byte-markers {max_byte_markers}"
            for s in samples
            if s.get("byte_marker_count") is not None
            and s["byte_marker_count"] > max_byte_markers
        )

    return problems
def main():
    """CLI entry point: load samples, compare tokenizers, apply optional gates.

    Exits with status 1 when any enabled regression gate fails.
    """
    args = parse_args()

    samples = filter_samples(load_samples(args.samples_file), args.exclude_sample)
    if not samples:
        raise ValueError("No samples left after applying --exclude-sample filters.")

    # Load base first, then custom, matching the order reported to the user.
    base_tok = AutoTokenizer.from_pretrained(args.base_tokenizer)
    custom_tok = AutoTokenizer.from_pretrained(args.custom_tokenizer)
    report = compare_tokenizers(
        base_tokenizer=base_tok,
        custom_tokenizer=custom_tok,
        samples=samples,
        show_ids=args.show_ids,
    )

    if args.report_json:
        Path(args.report_json).write_text(json.dumps(report, indent=2), encoding="utf-8")
        print(f"Saved JSON report to {args.report_json}")

    gate_failures = evaluate_regressions(
        report,
        max_average_ratio=args.max_average_ratio,
        max_sample_ratio=args.max_sample_ratio,
        require_roundtrip=args.require_roundtrip,
        require_normalized_roundtrip=args.require_normalized_roundtrip,
        max_unk_ratio=args.max_unk_ratio,
        max_byte_markers=args.max_byte_markers,
    )
    if gate_failures:
        print("\nRegression check failed:")
        for failure in gate_failures:
            print(f" - {failure}")
        sys.exit(1)
# Run as a script; importing this module performs no work.
if __name__ == "__main__":
    main()