Coverage for sft_data.py: 86%
150 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 23:06 +0000
1# Copyright 2026 venim1103
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15import torch
16import os
17import random
18import re
19from torch.utils.data import Dataset, DataLoader, ConcatDataset
20from torch.utils.data.distributed import DistributedSampler
21import torch.distributed as dist
22from datasets import load_dataset
23from context_config import CONTEXT_LENGTH
# ---------------------------------------------------------------------------
# 3-Stage SFT Configuration (Llama-Nemotron Nano §4.2, Nanbeige4-3B §3.1-3.2)
# ---------------------------------------------------------------------------
# Stage 1 (Cold-start): reasoning-only data, multiple epochs, no toggle
# Stage 2 (Mixed): reasoning + general chat, reasoning toggle training
# Stage 3 (Polish): tool-use focus, low epoch count
# ---------------------------------------------------------------------------
# Per-stage keys:
#   "paths"              source specs; each {"path": ..., "format": ...} is fed
#                        to create_sft_dataloader / SFTChatDataset (the
#                        "format" value becomes format_hint)
#   "epochs", "lr"       training-loop hyperparameters (read by the trainer,
#                        not used in this module)
#   "reasoning_off_prob" probability a sample is rendered reasoning-OFF
#   "max_seq_len"        truncation length for tokenized samples
SFT_STAGES = [
    {
        "name": "cold_start",
        "paths": [
            {"path": "local_data/sft/reasoning/open-math-reasoning", "format": "parquet"},
            {"path": "local_data/sft/reasoning/nemotron-post-training", "format": "jsonl"},
            {"path": "local_data/sft/reasoning/openr1-math", "format": "parquet"},
        ],
        "epochs": 4,
        "lr": 1e-4,
        "reasoning_off_prob": 0.0,  # reasoning always ON
        "max_seq_len": CONTEXT_LENGTH,
    },
    {
        "name": "mixed",
        "paths": [
            {"path": "local_data/sft/reasoning/open-math-reasoning", "format": "parquet"},
            {"path": "local_data/sft/reasoning/nemotron-post-training", "format": "jsonl"},
            {"path": "local_data/sft/reasoning/openr1-math", "format": "parquet"},
            {"path": "local_data/sft/mixed/smol-smoltalk", "format": "parquet"},
        ],
        "epochs": 2,
        "lr": 5e-5,
        "reasoning_off_prob": 0.3,  # 30% reasoning toggle
        "max_seq_len": CONTEXT_LENGTH,
    },
    {
        "name": "polish",
        "paths": [
            {"path": "local_data/sft/tool_calling/apigen-fc", "format": "parquet"},
            {"path": "local_data/sft/tool_calling/xlam-irrelevance", "format": "json"},
        ],
        "epochs": 2,
        "lr": 2e-5,
        "reasoning_off_prob": 0.1,  # mostly reasoning ON for tool planning
        "max_seq_len": CONTEXT_LENGTH,
    },
]
73def _first_token_id(tokenizer, text, fallback_text=None):
74 token_ids = tokenizer.encode(text, add_special_tokens=False)
75 if token_ids:
76 return token_ids[0]
78 if fallback_text is not None:
79 fallback_ids = tokenizer.encode(fallback_text, add_special_tokens=False)
80 if fallback_ids: 80 ↛ 83line 80 didn't jump to line 83 because the condition on line 80 was always true
81 return fallback_ids[0]
83 for attr in ("eos_token_id", "pad_token_id", "unk_token_id"):
84 token_id = getattr(tokenizer, attr, None)
85 if token_id is not None:
86 return token_id
88 return 0
class SFTChatDataset(Dataset):
    """Chat-format SFT dataset with auto schema detection and reasoning toggle.

    Loads parquet / json / jsonl data from a file or directory, converts each
    row from one of several known schemas (chat messages, math reasoning,
    tool-calling, instruction pairs) into a chat-message list, and tokenizes
    it with ChatML-style <|im_start|>/<|im_end|> framing. With probability
    ``reasoning_off_prob`` a sample is rendered "reasoning OFF": the direct-
    answer system prompt is used and <think>...</think> spans are stripped
    from assistant turns. Labels are -100 everywhere except assistant content.
    """

    def __init__(self, data_path, tokenizer, max_seq_len=CONTEXT_LENGTH, reasoning_off_prob=0.3, format_hint=None):
        """
        Args:
            data_path: file or directory holding .parquet/.json/.jsonl data.
            tokenizer: HF-style tokenizer exposing ``encode(..., add_special_tokens=...)``.
            max_seq_len: truncation length for input_ids/labels.
            reasoning_off_prob: probability of formatting a sample reasoning-OFF.
            format_hint: "parquet" | "jsonl" | "json" to route mixed-format
                directories explicitly; None picks parquet when present,
                otherwise all json/jsonl files.

        Raises:
            FileNotFoundError: if no files matching the (hinted) format exist.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.reasoning_off_prob = reasoning_off_prob

        self.sys_reasoning_on = "You are a deductive reasoning agent. You must analyze the user's request step-by-step within <think> tags before acting."
        self.sys_reasoning_off = "You are a direct, concise agent. Provide the final answer immediately without internal monologue."

        if os.path.isdir(data_path):
            files = []
            for root, _, filenames in os.walk(data_path):
                for f in filenames:
                    if f.endswith((".jsonl", ".json", ".parquet")):
                        files.append(os.path.join(root, f))
            if not files:
                raise FileNotFoundError(f"No .json/.jsonl/.parquet files found under {data_path}")

            parquet_files = [f for f in files if f.endswith(".parquet")]
            jsonl_files = [f for f in files if f.endswith(".jsonl")]
            # NOTE: deliberately includes both .json and .jsonl — used only by
            # the no-hint fallback branch, which loads every JSON-ish file.
            json_files = [f for f in files if f.endswith((".json", ".jsonl"))]

            # Explicit per-source routing for mixed-format directories.
            if format_hint == "parquet":
                if not parquet_files:
                    raise FileNotFoundError(f"No .parquet files found under {data_path}")
                self.raw_data = load_dataset("parquet", data_files=parquet_files, split="train")
            elif format_hint == "jsonl":
                if not jsonl_files:
                    raise FileNotFoundError(f"No .jsonl files found under {data_path}")
                self.raw_data = load_dataset("json", data_files=jsonl_files, split="train")
            elif format_hint == "json":
                json_only_files = [f for f in files if f.endswith(".json")]
                if not json_only_files:
                    raise FileNotFoundError(f"No .json files found under {data_path}")
                self.raw_data = load_dataset("json", data_files=json_only_files, split="train")
            elif parquet_files:
                # Default behavior: prefer parquet when no format hint is provided.
                self.raw_data = load_dataset("parquet", data_files=parquet_files, split="train")
            else:
                self.raw_data = load_dataset("json", data_files=json_files, split="train")
        else:
            # Single file: honor the hint, else infer from the extension.
            if format_hint == "parquet":
                fmt = "parquet"
            elif format_hint in ("json", "jsonl"):
                fmt = "json"
            else:
                fmt = "parquet" if data_path.endswith(".parquet") else "json"
            self.raw_data = load_dataset(fmt, data_files=data_path, split="train")

        # ChatML delimiters; _first_token_id degrades gracefully when the
        # tokenizer lacks the exact special tokens.
        self.im_start = _first_token_id(tokenizer, "<|im_start|>", fallback_text="<")
        self.im_end = _first_token_id(tokenizer, "<|im_end|>", fallback_text=">")
        self.nl = _first_token_id(tokenizer, "\n", fallback_text=" ")

    def __len__(self):
        return len(self.raw_data)

    def _row_to_messages(self, row):
        """Auto-detect data format and convert to a list of chat messages."""
        # Format 1: Chat — already has messages/conversations
        if "messages" in row and row["messages"]:
            return row["messages"]
        if "conversations" in row and row["conversations"]:
            return row["conversations"]

        # Format 2: Math reasoning — problem + solution (OpenMathReasoning, OpenR1-Math)
        problem = row.get("problem") or row.get("question") or row.get("prompt", "")
        solution = row.get("generated_solution") or row.get("solution") or row.get("response", "")
        if problem and solution:
            return [
                {"role": "user", "content": problem},
                {"role": "assistant", "content": f"<think>\n{solution}\n</think>"},
            ]

        # Format 3: Tool-calling — query + tools + answers (xLAM, APIGen)
        query = row.get("query", "")
        tools = row.get("tools", "")
        answers = row.get("answers", "")
        if query and tools:
            tool_desc = tools if isinstance(tools, str) else str(tools)
            answer_str = answers if isinstance(answers, str) else str(answers)
            return [
                {"role": "system", "content": f"You have access to the following tools:\n{tool_desc}"},
                {"role": "user", "content": query},
                {"role": "assistant", "content": answer_str},
            ]

        # Fallback: try to find any text pair
        for q_key in ["instruction", "input"]:
            for a_key in ["output", "response"]:
                if q_key in row and a_key in row:
                    return [
                        {"role": "user", "content": str(row[q_key])},
                        {"role": "assistant", "content": str(row[a_key])},
                    ]
        return []

    def __getitem__(self, idx):
        """Tokenize one row; returns (input_ids, labels) LongTensors."""
        row = self.raw_data[idx]
        messages = self._row_to_messages(row)
        if not messages:
            # Return an empty-ish sample that the collate fn can pad
            return torch.tensor([self.tokenizer.eos_token_id], dtype=torch.long), torch.tensor([-100], dtype=torch.long)

        input_ids, labels = [], []

        strip_reasoning = random.random() < self.reasoning_off_prob

        # Only inject system prompt if the messages don't already start with one
        # (tool-calling format injects its own system prompt with tool defs).
        # BUGFIX: ShareGPT-style rows use "from"/"value" keys, so check both —
        # otherwise a second system prompt was injected ahead of theirs.
        first_role = messages[0].get("role", messages[0].get("from"))
        if first_role != "system":
            system_prompt = self.sys_reasoning_off if strip_reasoning else self.sys_reasoning_on
            messages = [{"role": "system", "content": system_prompt}] + messages

        for msg in messages:
            role = msg.get("role", msg.get("from", ""))
            # Guard against null content in malformed rows (re.sub/encode
            # would raise on None).
            content = msg.get("content", msg.get("value", "")) or ""
            if role in ["human", "user"]: role = "user"
            if role in ["gpt", "assistant"]: role = "assistant"

            if role == "assistant" and strip_reasoning:
                content = re.sub(r'<think>.*?</think>\s*', '', content, flags=re.DOTALL).strip()

            header_tokens = [self.im_start] + self.tokenizer.encode(role, add_special_tokens=False) + [self.nl]
            content_tokens = self.tokenizer.encode(content, add_special_tokens=False) + [self.im_end, self.nl]
            msg_tokens = header_tokens + content_tokens
            input_ids.extend(msg_tokens)

            if role == "user" or role == "system":
                # Prompt tokens are masked out of the loss.
                labels.extend([-100] * len(msg_tokens))
            elif role == "assistant":
                # Train only on assistant content (header is masked).
                labels.extend([-100] * len(header_tokens) + content_tokens)

        input_ids = input_ids[:self.max_seq_len]
        labels = labels[:self.max_seq_len]
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
def sft_collate_fn(batch, pad_token_id=0):
    """Collate (input_ids, labels) pairs, right-padding to the batch maximum.

    Padding is dynamic — to the longest sequence in this batch, not a global
    max length. input_ids are padded with ``pad_token_id``; labels with -100
    so the loss ignores padded positions.
    """
    ids_seqs, label_seqs = zip(*batch)
    target_len = max(seq.size(0) for seq in ids_seqs)

    def _pad(seq, fill_value):
        # Extend *seq* to target_len with fill_value (no-op if already there).
        shortfall = target_len - seq.size(0)
        if shortfall <= 0:
            return seq
        return torch.cat([seq, torch.full((shortfall,), fill_value, dtype=torch.long)])

    stacked_ids = torch.stack([_pad(seq, pad_token_id) for seq in ids_seqs])
    stacked_labels = torch.stack([_pad(seq, -100) for seq in label_seqs])
    return stacked_ids, stacked_labels
def create_sft_dataloader(data_paths, tokenizer, max_seq_len=CONTEXT_LENGTH, batch_size=2, reasoning_off_prob=0.3):
    """Create a DataLoader from one or more data directories/files.

    Args:
        data_paths: a single path string, or a list whose items are either
            path strings or {"path": ..., "format": ...} dicts (the optional
            "format" value is forwarded to SFTChatDataset as format_hint).
        tokenizer: tokenizer forwarded to SFTChatDataset; its pad_token_id
            (or 0 when None) is used for batch padding.
        max_seq_len: truncation length for each sample.
        batch_size: per-process batch size.
        reasoning_off_prob: probability of formatting a sample reasoning-OFF.

    Returns:
        DataLoader over the concatenation of all sources, driven by a
        DistributedSampler when torch.distributed has been initialized.
    """
    import functools

    if isinstance(data_paths, str):
        data_paths = [data_paths]

    datasets = []
    for source in data_paths:
        if isinstance(source, dict):
            path = source["path"]
            format_hint = source.get("format")
        else:
            path = source
            format_hint = None
        datasets.append(SFTChatDataset(path, tokenizer, max_seq_len, reasoning_off_prob, format_hint=format_hint))

    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    collate = functools.partial(sft_collate_fn, pad_token_id=pad_id)
    combined = ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]

    # Use DistributedSampler when running multi-GPU via torchrun.
    # BUGFIX: gate on is_available() first — torch builds without distributed
    # support do not expose process-group state, and calling is_initialized()
    # unguarded is unsupported there per the torch.distributed docs.
    sampler = None
    shuffle = True
    if dist.is_available() and dist.is_initialized():
        sampler = DistributedSampler(combined, shuffle=True)
        shuffle = False  # sampler handles shuffling

    return DataLoader(
        combined, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
        pin_memory=True, num_workers=4, collate_fn=collate
    )