Coverage for sft_data.py: 86%

150 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

1# Copyright 2026 venim1103 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import torch 

16import os 

17import random 

18import re 

19from torch.utils.data import Dataset, DataLoader, ConcatDataset 

20from torch.utils.data.distributed import DistributedSampler 

21import torch.distributed as dist 

22from datasets import load_dataset 

23from context_config import CONTEXT_LENGTH 

24 

25 

# ---------------------------------------------------------------------------
# 3-Stage SFT Configuration (Llama-Nemotron Nano §4.2, Nanbeige4-3B §3.1-3.2)
# ---------------------------------------------------------------------------
# Stage 1 (Cold-start): reasoning-only data, multiple epochs, no toggle
# Stage 2 (Mixed): reasoning + general chat, reasoning toggle training
# Stage 3 (Polish): tool-use focus, low epoch count
# ---------------------------------------------------------------------------
# Per-stage keys consumed by the training driver:
#   paths              - data sources; the optional "format" value routes
#                        mixed-format directories (see SFTChatDataset's
#                        format_hint handling)
#   epochs / lr        - per-stage schedule; note the LR decays across stages
#   reasoning_off_prob - fraction of samples rendered in "reasoning off" mode
#                        (<think> spans stripped from assistant turns)
#   max_seq_len        - truncation length, tied to the model context window
SFT_STAGES = [
    {
        "name": "cold_start",
        "paths": [
            {"path": "local_data/sft/reasoning/open-math-reasoning", "format": "parquet"},
            {"path": "local_data/sft/reasoning/nemotron-post-training", "format": "jsonl"},
            {"path": "local_data/sft/reasoning/openr1-math", "format": "parquet"},
        ],
        "epochs": 4,
        "lr": 1e-4,
        "reasoning_off_prob": 0.0,  # reasoning always ON
        "max_seq_len": CONTEXT_LENGTH,
    },
    {
        "name": "mixed",
        "paths": [
            {"path": "local_data/sft/reasoning/open-math-reasoning", "format": "parquet"},
            {"path": "local_data/sft/reasoning/nemotron-post-training", "format": "jsonl"},
            {"path": "local_data/sft/reasoning/openr1-math", "format": "parquet"},
            {"path": "local_data/sft/mixed/smol-smoltalk", "format": "parquet"},
        ],
        "epochs": 2,
        "lr": 5e-5,
        "reasoning_off_prob": 0.3,  # 30% reasoning toggle
        "max_seq_len": CONTEXT_LENGTH,
    },
    {
        "name": "polish",
        "paths": [
            {"path": "local_data/sft/tool_calling/apigen-fc", "format": "parquet"},
            {"path": "local_data/sft/tool_calling/xlam-irrelevance", "format": "json"},
        ],
        "epochs": 2,
        "lr": 2e-5,
        "reasoning_off_prob": 0.1,  # mostly reasoning ON for tool planning
        "max_seq_len": CONTEXT_LENGTH,
    },
]

71 

72 

73def _first_token_id(tokenizer, text, fallback_text=None): 

74 token_ids = tokenizer.encode(text, add_special_tokens=False) 

75 if token_ids: 

76 return token_ids[0] 

77 

78 if fallback_text is not None: 

79 fallback_ids = tokenizer.encode(fallback_text, add_special_tokens=False) 

80 if fallback_ids: 80 ↛ 83line 80 didn't jump to line 83 because the condition on line 80 was always true

81 return fallback_ids[0] 

82 

83 for attr in ("eos_token_id", "pad_token_id", "unk_token_id"): 

84 token_id = getattr(tokenizer, attr, None) 

85 if token_id is not None: 

86 return token_id 

87 

88 return 0 

89 

90 

class SFTChatDataset(Dataset):
    """Chat-format SFT dataset producing (input_ids, labels) token tensors.

    Rows are loaded from a file or a directory of .json/.jsonl/.parquet
    sources and auto-converted to chat messages (see ``_row_to_messages``).
    Each sample is serialized as ChatML-style turns
    (``<|im_start|>role\\ncontent<|im_end|>\\n``); labels are -100 everywhere
    except assistant content tokens, so only assistant replies contribute to
    the loss.

    Reasoning toggle: with probability ``reasoning_off_prob`` a sample uses
    the "reasoning off" system prompt and has <think>...</think> spans
    stripped from assistant turns.
    """

    def __init__(self, data_path, tokenizer, max_seq_len=CONTEXT_LENGTH, reasoning_off_prob=0.3, format_hint=None):
        """Load raw rows from *data_path*.

        Args:
            data_path: a single file, or a directory walked recursively for
                .json/.jsonl/.parquet files.
            tokenizer: HF-style tokenizer exposing ``encode``.
            max_seq_len: truncation length for input_ids and labels.
            reasoning_off_prob: per-sample probability of reasoning-off mode.
            format_hint: "parquet" | "jsonl" | "json" | None — explicit
                routing for directories that mix several file formats.

        Raises:
            FileNotFoundError: when no files matching the (hinted) format
                exist under a directory *data_path*.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.reasoning_off_prob = reasoning_off_prob

        self.sys_reasoning_on = "You are a deductive reasoning agent. You must analyze the user's request step-by-step within <think> tags before acting."
        self.sys_reasoning_off = "You are a direct, concise agent. Provide the final answer immediately without internal monologue."

        if os.path.isdir(data_path):
            files = []
            for root, _, filenames in os.walk(data_path):
                for f in filenames:
                    if f.endswith((".jsonl", ".json", ".parquet")):
                        files.append(os.path.join(root, f))
            if not files:
                raise FileNotFoundError(f"No .json/.jsonl/.parquet files found under {data_path}")

            parquet_files = [f for f in files if f.endswith(".parquet")]
            jsonl_files = [f for f in files if f.endswith(".jsonl")]
            # NOTE: deliberately includes both .json and .jsonl; only used by
            # the no-hint fallback branch below.
            json_files = [f for f in files if f.endswith((".json", ".jsonl"))]

            # Explicit per-source routing for mixed-format directories.
            if format_hint == "parquet":
                if not parquet_files:
                    raise FileNotFoundError(f"No .parquet files found under {data_path}")
                self.raw_data = load_dataset("parquet", data_files=parquet_files, split="train")
            elif format_hint == "jsonl":
                if not jsonl_files:
                    raise FileNotFoundError(f"No .jsonl files found under {data_path}")
                self.raw_data = load_dataset("json", data_files=jsonl_files, split="train")
            elif format_hint == "json":
                # Strictly .json here — exclude .jsonl siblings.
                json_only_files = [f for f in files if f.endswith(".json")]
                if not json_only_files:
                    raise FileNotFoundError(f"No .json files found under {data_path}")
                self.raw_data = load_dataset("json", data_files=json_only_files, split="train")
            elif parquet_files:
                # Default behavior: prefer parquet when no format hint is provided.
                self.raw_data = load_dataset("parquet", data_files=parquet_files, split="train")
            else:
                self.raw_data = load_dataset("json", data_files=json_files, split="train")
        else:
            # Single file: honor the hint, else infer from the extension.
            if format_hint == "parquet":
                fmt = "parquet"
            elif format_hint in ("json", "jsonl"):
                fmt = "json"
            else:
                fmt = "parquet" if data_path.endswith(".parquet") else "json"
            self.raw_data = load_dataset(fmt, data_files=data_path, split="train")

        # ChatML delimiter token ids; fallbacks keep this usable with
        # tokenizers that lack the <|im_start|>/<|im_end|> special tokens.
        self.im_start = _first_token_id(tokenizer, "<|im_start|>", fallback_text="<")
        self.im_end = _first_token_id(tokenizer, "<|im_end|>", fallback_text=">")
        self.nl = _first_token_id(tokenizer, "\n", fallback_text=" ")

    def __len__(self):
        """Number of raw rows in the loaded dataset."""
        return len(self.raw_data)

    def _row_to_messages(self, row):
        """Auto-detect data format and convert to a list of chat messages."""
        # Format 1: Chat — already has messages/conversations
        if "messages" in row and row["messages"]:
            return row["messages"]
        if "conversations" in row and row["conversations"]:
            return row["conversations"]

        # Format 2: Math reasoning — problem + solution (OpenMathReasoning, OpenR1-Math)
        problem = row.get("problem") or row.get("question") or row.get("prompt", "")
        solution = row.get("generated_solution") or row.get("solution") or row.get("response", "")
        if problem and solution:
            return [
                {"role": "user", "content": problem},
                {"role": "assistant", "content": f"<think>\n{solution}\n</think>"},
            ]

        # Format 3: Tool-calling — query + tools + answers (xLAM, APIGen)
        query = row.get("query", "")
        tools = row.get("tools", "")
        answers = row.get("answers", "")
        if query and tools:
            tool_desc = tools if isinstance(tools, str) else str(tools)
            answer_str = answers if isinstance(answers, str) else str(answers)
            return [
                {"role": "system", "content": f"You have access to the following tools:\n{tool_desc}"},
                {"role": "user", "content": query},
                {"role": "assistant", "content": answer_str},
            ]

        # Fallback: try to find any text pair
        for q_key in ["instruction", "input"]:
            for a_key in ["output", "response"]:
                if q_key in row and a_key in row:
                    return [
                        {"role": "user", "content": str(row[q_key])},
                        {"role": "assistant", "content": str(row[a_key])},
                    ]
        return []

    def __getitem__(self, idx):
        """Tokenize row *idx* into (input_ids, labels) 1-D long tensors."""
        row = self.raw_data[idx]
        messages = self._row_to_messages(row)
        if not messages:
            # Return an empty-ish sample that the collate fn can pad
            return torch.tensor([self.tokenizer.eos_token_id], dtype=torch.long), torch.tensor([-100], dtype=torch.long)

        input_ids, labels = [], []

        strip_reasoning = random.random() < self.reasoning_off_prob

        # Only inject system prompt if the messages don't already start with one
        # (tool-calling format injects its own system prompt with tool defs).
        # FIX: ShareGPT-style rows mark roles with "from" instead of "role";
        # the per-message loop below normalizes both keys, so the
        # duplicate-system check must use the same normalization or a second,
        # contradictory system prompt gets prepended.
        first_role = messages[0].get("role", messages[0].get("from", ""))
        if first_role != "system":
            system_prompt = self.sys_reasoning_off if strip_reasoning else self.sys_reasoning_on
            messages = [{"role": "system", "content": system_prompt}] + messages

        for msg in messages:
            role = msg.get("role", msg.get("from", ""))
            content = msg.get("content", msg.get("value", ""))
            if role in ["human", "user"]: role = "user"
            if role in ["gpt", "assistant"]: role = "assistant"

            if role == "assistant" and strip_reasoning:
                # NOTE(review): math rows wrap the entire solution in <think>
                # tags, so stripping can leave an empty assistant turn —
                # confirm this is intended for the mixed stage.
                content = re.sub(r'<think>.*?</think>\s*', '', content, flags=re.DOTALL).strip()

            header_tokens = [self.im_start] + self.tokenizer.encode(role, add_special_tokens=False) + [self.nl]
            content_tokens = self.tokenizer.encode(content, add_special_tokens=False) + [self.im_end, self.nl]
            msg_tokens = header_tokens + content_tokens
            input_ids.extend(msg_tokens)

            if role == "user" or role == "system":
                labels.extend([-100] * len(msg_tokens))
            elif role == "assistant":
                # Loss on assistant content (incl. <|im_end|>), not the header.
                labels.extend([-100] * len(header_tokens) + content_tokens)

        input_ids = input_ids[:self.max_seq_len]
        labels = labels[:self.max_seq_len]
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)

227 

228 

def sft_collate_fn(batch, pad_token_id=0):
    """Pad a batch of (input_ids, labels) pairs to the batch's longest sequence.

    input_ids are padded with *pad_token_id*; labels are padded with -100 so
    padded positions are ignored by the loss.  Returns stacked 2-D tensors.
    """
    ids_seqs = [pair[0] for pair in batch]
    label_seqs = [pair[1] for pair in batch]
    target_len = max(seq.size(0) for seq in ids_seqs)

    out_ids, out_labels = [], []
    for seq, lab in zip(ids_seqs, label_seqs):
        missing = target_len - seq.size(0)
        if missing > 0:
            seq = torch.cat([seq, torch.full((missing,), pad_token_id, dtype=torch.long)])
            lab = torch.cat([lab, torch.full((missing,), -100, dtype=torch.long)])
        out_ids.append(seq)
        out_labels.append(lab)

    return torch.stack(out_ids), torch.stack(out_labels)

244 

245 

def create_sft_dataloader(data_paths, tokenizer, max_seq_len=CONTEXT_LENGTH, batch_size=2, reasoning_off_prob=0.3):
    """Create a DataLoader from one or more data directories/files.

    ``data_paths`` may be a single path string, or a list whose entries are
    either plain paths or ``{"path": ..., "format": ...}`` dicts (the
    optional "format" routes mixed-format directories).  A
    DistributedSampler is used automatically under an initialized
    torch.distributed process group (e.g. torchrun multi-GPU).
    """
    import functools

    sources = [data_paths] if isinstance(data_paths, str) else data_paths

    parts = []
    for entry in sources:
        if isinstance(entry, dict):
            path, fmt = entry["path"], entry.get("format")
        else:
            path, fmt = entry, None
        parts.append(SFTChatDataset(path, tokenizer, max_seq_len, reasoning_off_prob, format_hint=fmt))

    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    # functools.partial (not a lambda) keeps the collate_fn picklable for
    # num_workers > 0.
    collate = functools.partial(sft_collate_fn, pad_token_id=pad_id)
    combined = parts[0] if len(parts) == 1 else ConcatDataset(parts)

    # DistributedSampler takes over shuffling in multi-GPU runs.
    if dist.is_initialized():
        sampler, shuffle = DistributedSampler(combined, shuffle=True), False
    else:
        sampler, shuffle = None, True

    return DataLoader(
        combined,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        pin_memory=True,
        num_workers=4,
        collate_fn=collate,
    )

277