Coverage for synth_data.py: 78% (82 statements)
« prev ^ index » next — coverage.py v7.13.5, created at 2026-04-21 23:06 +0000
1# Copyright 2026 venim1103
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15"""Synthetic data generation pipeline (Nemotron-H §2.3, Nanbeige4-3B §2.2).
17Generates augmented training samples from raw source documents using
18the trained model itself as a rewriter. Five prompt strategies are
19supported (following Nemotron-H):
211. diverse_qa — Generate diverse QA pairs from a source passage.
222. distill — Rewrite into concise, information-dense form.
233. extract — Pull structured knowledge from unstructured text.
244. knowledge — Produce bulleted knowledge lists.
255. rephrase — Rewrite low-quality text in clean, encyclopedic style.
27Usage:
28 python synth_data.py --strategy diverse_qa --input local_data/train/web --output local_data/synth/web_qa
29 python synth_data.py --strategy distill --input local_data/train/code --output local_data/synth/code_distill
30"""
32import argparse
33import json
34import os
35import torch
36import torch.nn.functional as F
37from datasets import load_dataset
38from transformers import AutoTokenizer
40from model import BitMambaLLM
# Run on GPU when one is available; fall back to CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Default parent-model checkpoint used when --checkpoint is not supplied.
DEFAULT_CKPT = "checkpoints/bitmamba_parent/step_1000000.pt"
# BitMambaLLM architecture hyperparameters — must match the checkpoint being loaded.
MODEL_CONFIG = dict(vocab_size=64000, dim=1024, n_layers=40, d_state=128, expand=2)
46# ---------------------------------------------------------------------------
47# Prompt templates for each strategy (Nemotron-H §2.3)
48# ---------------------------------------------------------------------------
# Maps strategy name -> chat-formatted prompt template. Each template has a
# single "{text}" placeholder filled with the (truncated) source passage.
STRATEGY_PROMPTS = {
    # QA pairs at three levels of understanding (recall / inference / application).
    "diverse_qa": (
        "<|im_start|>system\nYou are a meticulous educator. "
        "Given a passage, create 3 diverse question-answer pairs that test different "
        "levels of understanding (factual recall, inference, and application).<|im_end|>\n"
        "<|im_start|>user\nPassage:\n{text}\n\n"
        "Generate 3 QA pairs in the format:\nQ1: ...\nA1: ...\nQ2: ...\nA2: ...\nQ3: ...\nA3: ...<|im_end|>\n"
        "<|im_start|>assistant\n"
    ),
    # Concise, information-dense rewrite that preserves key facts.
    "distill": (
        "<|im_start|>system\nYou are a technical writer who distills long documents into "
        "concise, information-dense summaries. Preserve all key facts and relationships.<|im_end|>\n"
        "<|im_start|>user\nRewrite the following text in a concise, clear form. "
        "Keep all essential information but remove redundancy:\n\n{text}<|im_end|>\n"
        "<|im_start|>assistant\n"
    ),
    # Structured entity/relationship/fact extraction from free text.
    "extract": (
        "<|im_start|>system\nYou are a knowledge extraction engine. "
        "Extract structured facts from unstructured text.<|im_end|>\n"
        "<|im_start|>user\nExtract all key entities, relationships, and facts from this text "
        "as a structured list:\n\n{text}<|im_end|>\n"
        "<|im_start|>assistant\n"
    ),
    # Bulleted list of self-contained facts.
    "knowledge": (
        "<|im_start|>system\nYou are an encyclopedic knowledge organizer.<|im_end|>\n"
        "<|im_start|>user\nConvert the following text into a bulleted knowledge list. "
        "Each bullet should be a self-contained fact:\n\n{text}<|im_end|>\n"
        "<|im_start|>assistant\n"
    ),
    # Clean encyclopedic rewrite of low-quality/informal text.
    "rephrase": (
        "<|im_start|>system\nYou are a Wikipedia editor. Rewrite low-quality or informal text "
        "into clean, encyclopedic prose. Fix grammar, improve clarity, and add structure.<|im_end|>\n"
        "<|im_start|>user\nRewrite the following text in a clear, encyclopedic style:\n\n{text}<|im_end|>\n"
        "<|im_start|>assistant\n"
    ),
}
88# ---------------------------------------------------------------------------
89# Generation — uses model.generate() for O(n) cached inference
90# ---------------------------------------------------------------------------
def truncate_source(text, tokenizer, max_source_tokens=1024):
    """Clip *text* to at most ``max_source_tokens`` tokens.

    Text that already fits is returned unchanged; longer text is
    re-decoded from its truncated token sequence, so the result is
    guaranteed to tokenize within the budget.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) <= max_source_tokens:
        return text
    clipped = token_ids[:max_source_tokens]
    return tokenizer.decode(clipped, skip_special_tokens=True)
102# ---------------------------------------------------------------------------
103# Data loading helpers
104# ---------------------------------------------------------------------------
def iter_source_texts(input_path):
    """Yield raw text strings from a directory of parquet/json/jsonl files."""
    # Recursively collect every supported data file under input_path.
    discovered = [
        os.path.join(root, name)
        for root, _, names in os.walk(input_path)
        for name in names
        if name.endswith((".parquet", ".json", ".jsonl"))
    ]
    if not discovered:
        raise FileNotFoundError(f"No data files found under {input_path}")

    parquet_paths = [p for p in discovered if p.endswith(".parquet")]
    json_paths = [p for p in discovered if p.endswith((".json", ".jsonl"))]

    # Parquet takes precedence: when both formats are present, only the
    # parquet files are streamed.
    if parquet_paths:
        stream = load_dataset("parquet", data_files=parquet_paths, split="train", streaming=True)
    else:
        stream = load_dataset("json", data_files=json_paths, split="train", streaming=True)

    # For each row, yield the first plausible text column: string-valued
    # and longer than 50 characters.
    candidate_columns = ["text", "content", "passage", "question", "problem", "instruction"]
    for row in stream:
        for col in candidate_columns:
            if col in row and isinstance(row[col], str) and len(row[col]) > 50:
                yield row[col]
                break
132# ---------------------------------------------------------------------------
133# Main pipeline
134# ---------------------------------------------------------------------------
def run_pipeline(args):
    """Generate synthetic samples with the model and write them as JSONL.

    Loads the tokenizer and a BitMambaLLM checkpoint, then for each source
    text: builds the strategy prompt, samples a completion, and appends a
    ``{strategy, source, generated}`` record to ``<output>/<strategy>.jsonl``.
    Stops after ``args.num_samples`` records.
    """
    print(f"Strategy: {args.strategy}")
    print(f"Input: {args.input}")
    print(f"Output: {args.output}")
    print(f"Samples: {args.num_samples}")

    tokenizer = AutoTokenizer.from_pretrained("custom_agentic_tokenizer")
    # First token id of "<|im_end|>" is used as the generation stop token.
    eos_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]

    print("Loading model...")
    model = BitMambaLLM(**MODEL_CONFIG).to(DEVICE)
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — acceptable only because the checkpoint path is trusted/local.
    ckpt = torch.load(args.checkpoint, map_location=DEVICE)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    prompt_template = STRATEGY_PROMPTS[args.strategy]
    os.makedirs(args.output, exist_ok=True)
    out_path = os.path.join(args.output, f"{args.strategy}.jsonl")

    count = 0
    # BUG FIX: open with an explicit UTF-8 encoding. Records are serialized
    # with ensure_ascii=False, so non-ASCII text would crash (or be mangled)
    # under a non-UTF-8 platform default encoding.
    with open(out_path, "w", encoding="utf-8") as fout, torch.no_grad():
        for text in iter_source_texts(args.input):
            if count >= args.num_samples:
                break

            text = truncate_source(text, tokenizer, max_source_tokens=args.max_source_tokens)
            prompt = prompt_template.format(text=text)
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)

            if input_ids.shape[1] > 2048:
                continue  # skip overly long prompts

            output_ids = model.generate(
                input_ids, max_new_tokens=args.max_new_tokens,
                temperature=args.temperature, do_sample=(args.temperature > 0),
                eos_token_id=eos_id,
            )
            # Decode only the newly generated continuation, not the prompt.
            gen_text = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True)

            record = {
                "strategy": args.strategy,
                "source": text[:500],  # truncated source for provenance
                "generated": gen_text,
            }
            fout.write(json.dumps(record, ensure_ascii=False) + "\n")
            count += 1

            if count % 100 == 0:
                print(f"  Generated {count}/{args.num_samples}")

    print(f"Done. Wrote {count} samples to {out_path}")
def main():
    """Parse command-line arguments and launch the generation pipeline."""
    cli = argparse.ArgumentParser(description="Synthetic data generation pipeline")
    # (flag, kwargs) specs keep registration compact; order matches --help output.
    arg_specs = [
        ("--strategy", dict(required=True, choices=list(STRATEGY_PROMPTS.keys()),
                            help="Generation strategy to use")),
        ("--input", dict(required=True, help="Path to source data directory")),
        ("--output", dict(required=True, help="Path to output directory")),
        ("--checkpoint", dict(default=DEFAULT_CKPT, help="Model checkpoint path")),
        ("--num_samples", dict(type=int, default=10000, help="Number of samples to generate")),
        ("--max_new_tokens", dict(type=int, default=512, help="Max tokens per generation")),
        ("--max_source_tokens", dict(type=int, default=1024, help="Max source tokens in prompt")),
        ("--temperature", dict(type=float, default=0.7, help="Sampling temperature (0=greedy)")),
    ]
    for flag, kwargs in arg_specs:
        cli.add_argument(flag, **kwargs)
    run_pipeline(cli.parse_args())


if __name__ == "__main__":
    main()