Coverage for synth_data.py: 78%

82 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

1# Copyright 2026 venim1103 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15"""Synthetic data generation pipeline (Nemotron-H §2.3, Nanbeige4-3B §2.2). 

16 

17Generates augmented training samples from raw source documents using 

18the trained model itself as a rewriter. Five prompt strategies are 

19supported (following Nemotron-H): 

20 

211. diverse_qa — Generate diverse QA pairs from a source passage. 

222. distill — Rewrite into concise, information-dense form. 

233. extract — Pull structured knowledge from unstructured text. 

244. knowledge — Produce bulleted knowledge lists. 

255. rephrase — Rewrite low-quality text in clean, encyclopedic style. 

26 

27Usage: 

28 python synth_data.py --strategy diverse_qa --input local_data/train/web --output local_data/synth/web_qa 

29 python synth_data.py --strategy distill --input local_data/train/code --output local_data/synth/code_distill 

30""" 

31 

32import argparse 

33import json 

34import os 

35import torch 

36import torch.nn.functional as F 

37from datasets import load_dataset 

38from transformers import AutoTokenizer 

39 

40from model import BitMambaLLM 

41 

# Run generation on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Default parent-model checkpoint used as the rewriter.
DEFAULT_CKPT = "checkpoints/bitmamba_parent/step_1000000.pt"

# BitMambaLLM architecture hyperparameters — must match the checkpoint being
# loaded (see run_pipeline); presumably the parent model's training config.
MODEL_CONFIG = dict(vocab_size=64000, dim=1024, n_layers=40, d_state=128, expand=2)

45 

46# --------------------------------------------------------------------------- 

47# Prompt templates for each strategy (Nemotron-H §2.3) 

48# --------------------------------------------------------------------------- 

49 

# Maps strategy name -> ChatML-style prompt template. Each template embeds the
# source passage via str.format(text=...) and ends with an open
# "<|im_start|>assistant\n" turn so the model generates the response; the
# closing "<|im_end|>" token is used as the stop token during generation.
STRATEGY_PROMPTS = {
    # Three QA pairs probing different comprehension levels of one passage.
    "diverse_qa": (
        "<|im_start|>system\nYou are a meticulous educator. "
        "Given a passage, create 3 diverse question-answer pairs that test different "
        "levels of understanding (factual recall, inference, and application).<|im_end|>\n"
        "<|im_start|>user\nPassage:\n{text}\n\n"
        "Generate 3 QA pairs in the format:\nQ1: ...\nA1: ...\nQ2: ...\nA2: ...\nQ3: ...\nA3: ...<|im_end|>\n"
        "<|im_start|>assistant\n"
    ),
    # Concise, information-dense rewrite that preserves key facts.
    "distill": (
        "<|im_start|>system\nYou are a technical writer who distills long documents into "
        "concise, information-dense summaries. Preserve all key facts and relationships.<|im_end|>\n"
        "<|im_start|>user\nRewrite the following text in a concise, clear form. "
        "Keep all essential information but remove redundancy:\n\n{text}<|im_end|>\n"
        "<|im_start|>assistant\n"
    ),
    # Structured entity/relationship/fact extraction from unstructured text.
    "extract": (
        "<|im_start|>system\nYou are a knowledge extraction engine. "
        "Extract structured facts from unstructured text.<|im_end|>\n"
        "<|im_start|>user\nExtract all key entities, relationships, and facts from this text "
        "as a structured list:\n\n{text}<|im_end|>\n"
        "<|im_start|>assistant\n"
    ),
    # Bulleted list of self-contained facts.
    "knowledge": (
        "<|im_start|>system\nYou are an encyclopedic knowledge organizer.<|im_end|>\n"
        "<|im_start|>user\nConvert the following text into a bulleted knowledge list. "
        "Each bullet should be a self-contained fact:\n\n{text}<|im_end|>\n"
        "<|im_start|>assistant\n"
    ),
    # Clean, encyclopedic rewrite of low-quality or informal text.
    "rephrase": (
        "<|im_start|>system\nYou are a Wikipedia editor. Rewrite low-quality or informal text "
        "into clean, encyclopedic prose. Fix grammar, improve clarity, and add structure.<|im_end|>\n"
        "<|im_start|>user\nRewrite the following text in a clear, encyclopedic style:\n\n{text}<|im_end|>\n"
        "<|im_start|>assistant\n"
    ),
}

86 

87 

88# --------------------------------------------------------------------------- 

89# Generation — uses model.generate() for O(n) cached inference 

90# --------------------------------------------------------------------------- 

91 

92 

def truncate_source(text, tokenizer, max_source_tokens=1024):
    """Clip *text* so its tokenized form fits within the prompt budget.

    Returns the text unchanged when it already fits; otherwise decodes the
    first ``max_source_tokens`` tokens back into a string and returns that.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) <= max_source_tokens:
        return text
    return tokenizer.decode(token_ids[:max_source_tokens], skip_special_tokens=True)

100 

101 

102# --------------------------------------------------------------------------- 

103# Data loading helpers 

104# --------------------------------------------------------------------------- 

105 

def iter_source_texts(input_path):
    """Yield raw text strings from a directory of parquet/json/jsonl files.

    Files are discovered recursively and sorted so iteration order is
    deterministic across runs and filesystems (``os.walk`` returns entries
    in arbitrary order, which previously made sample selection
    nondeterministic). Parquet files take precedence: when any are present,
    JSON/JSONL files are ignored. Rows are streamed; for each row, the first
    sufficiently long string value among a list of common text column names
    is yielded (at most one field per row).

    Raises:
        FileNotFoundError: if no supported data files exist under input_path.
    """
    files = sorted(
        os.path.join(root, name)
        for root, _, filenames in os.walk(input_path)
        for name in filenames
        if name.endswith((".parquet", ".json", ".jsonl"))
    )
    if not files:
        raise FileNotFoundError(f"No data files found under {input_path}")

    parquet_files = [f for f in files if f.endswith(".parquet")]
    json_files = [f for f in files if f.endswith((".json", ".jsonl"))]

    # Streaming avoids materializing potentially huge corpora in memory.
    if parquet_files:
        ds = load_dataset("parquet", data_files=parquet_files, split="train", streaming=True)
    else:
        ds = load_dataset("json", data_files=json_files, split="train", streaming=True)

    for row in ds:
        # Try common text column names; the length check filters out
        # trivially short fragments not worth rewriting.
        for col in ("text", "content", "passage", "question", "problem", "instruction"):
            if col in row and isinstance(row[col], str) and len(row[col]) > 50:
                yield row[col]
                break

130 

131 

132# --------------------------------------------------------------------------- 

133# Main pipeline 

134# --------------------------------------------------------------------------- 

135 

def run_pipeline(args):
    """Generate synthetic samples with the trained model and write JSONL.

    Loads the tokenizer and BitMamba checkpoint, streams source texts from
    ``args.input``, formats each through the chosen strategy's prompt
    template, generates a completion, and writes one JSON record per sample
    to ``<args.output>/<strategy>.jsonl``.

    Expected args fields: strategy, input, output, checkpoint, num_samples,
    max_new_tokens, max_source_tokens, temperature (see main()).
    """
    print(f"Strategy: {args.strategy}")
    print(f"Input: {args.input}")
    print(f"Output: {args.output}")
    print(f"Samples: {args.num_samples}")

    tokenizer = AutoTokenizer.from_pretrained("custom_agentic_tokenizer")
    # The ChatML turn delimiter doubles as the generation stop token.
    eos_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]

    print("Loading model...")
    model = BitMambaLLM(**MODEL_CONFIG).to(DEVICE)
    # NOTE(review): checkpoint is assumed trusted (torch.load unpickles
    # arbitrary objects); consider weights_only=True if the file holds only
    # tensors — confirm against the training-side save format.
    ckpt = torch.load(args.checkpoint, map_location=DEVICE)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    prompt_template = STRATEGY_PROMPTS[args.strategy]
    os.makedirs(args.output, exist_ok=True)
    out_path = os.path.join(args.output, f"{args.strategy}.jsonl")

    count = 0
    # Explicit utf-8: records are serialized with ensure_ascii=False, which
    # can emit non-ASCII characters that a platform-default encoding (e.g.
    # cp1252 on Windows) would fail to write.
    with open(out_path, "w", encoding="utf-8") as fout, torch.no_grad():
        for text in iter_source_texts(args.input):
            if count >= args.num_samples:
                break

            text = truncate_source(text, tokenizer, max_source_tokens=args.max_source_tokens)
            prompt = prompt_template.format(text=text)
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)

            if input_ids.shape[1] > 2048:
                continue  # skip overly long prompts

            # temperature == 0 selects greedy decoding via do_sample=False.
            output_ids = model.generate(
                input_ids, max_new_tokens=args.max_new_tokens,
                temperature=args.temperature, do_sample=(args.temperature > 0),
                eos_token_id=eos_id,
            )
            # Decode only the newly generated continuation, not the prompt.
            gen_text = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True)

            record = {
                "strategy": args.strategy,
                "source": text[:500],  # truncated source for provenance
                "generated": gen_text,
            }
            fout.write(json.dumps(record, ensure_ascii=False) + "\n")
            count += 1

            if count % 100 == 0:
                print(f" Generated {count}/{args.num_samples}")

    print(f"Done. Wrote {count} samples to {out_path}")

188 

189 

def main():
    """CLI entry point: parse command-line arguments and run the pipeline."""
    ap = argparse.ArgumentParser(description="Synthetic data generation pipeline")
    ap.add_argument("--strategy", required=True, choices=list(STRATEGY_PROMPTS.keys()),
                    help="Generation strategy to use")
    ap.add_argument("--input", required=True, help="Path to source data directory")
    ap.add_argument("--output", required=True, help="Path to output directory")
    ap.add_argument("--checkpoint", default=DEFAULT_CKPT, help="Model checkpoint path")
    ap.add_argument("--num_samples", type=int, default=10000, help="Number of samples to generate")
    ap.add_argument("--max_new_tokens", type=int, default=512, help="Max tokens per generation")
    ap.add_argument("--max_source_tokens", type=int, default=1024, help="Max source tokens in prompt")
    ap.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature (0=greedy)")
    run_pipeline(ap.parse_args())

203 

204 

# Script entry point: parse CLI args and run the generation pipeline.
if __name__ == "__main__":
    main()