Coverage for data.py: 91%
112 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 23:06 +0000
1# Copyright 2026 venim1103
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15import os
16import torch
17from torch.utils.data import IterableDataset, DataLoader
18from datasets import load_dataset
19from transformers import AutoTokenizer
20import random
21from context_config import CONTEXT_LENGTH
def get_text_column(dataset_stream):
    """Inspect the first row of a stream and return the name of a string column.

    A fixed list of preferred column names is tried first; otherwise the first
    string-valued key encountered in the row's iteration order is used.

    Raises:
        ValueError: if the sample row contains no string-valued column.
    """
    row = next(iter(dataset_stream))
    preferred = ('text', 'content', 'trajectory', 'prompt', 'response', 'question')
    for candidate in preferred:
        if isinstance(row.get(candidate), str):
            return candidate
    for column, value in row.items():
        if isinstance(value, str):
            return column
    raise ValueError("Could not detect text column.")
def create_infinite_stream(hf_dataset):
    """Yield rows from *hf_dataset* forever, restarting from the top each time
    the underlying iterable is exhausted."""
    while True:
        yield from hf_dataset
def extract_text_from_row(row):
    """Dynamically parses complex schemas into a clean training string."""
    pieces = []

    # Chat-style schema: one "role: content" line per message.
    messages = row.get('messages')
    if isinstance(messages, list):
        pieces.extend(
            f"{m.get('role', 'user')}: {m['content']}"
            for m in messages
            if isinstance(m, dict) and 'content' in m
        )

    # Prompt-like fields, labelled with their capitalized key.
    for field in ('problem', 'question', 'prompt', 'facts', 'hypothesis'):
        value = row.get(field)
        if isinstance(value, str):
            pieces.append(f"{field.capitalize()}: {value}")

    proofs = row.get('proofs')
    if isinstance(proofs, list):
        pieces.append(f"Proof: {' '.join(proofs)}")

    # Answer-like fields.
    for field in ('solution', 'answer', 'response'):
        value = row.get(field)
        if isinstance(value, str):
            pieces.append(f"{field.capitalize()}: {value}")

    # Plain-text columns are only consulted when nothing structured matched,
    # and only the first present one is used.
    for field in ('text', 'content', 'trajectory'):
        value = row.get(field)
        if isinstance(value, str) and not pieces:
            pieces.append(value)

    if pieces:
        return "\n".join(pieces)

    # Absolute fallback: the longest string value anywhere in the row
    # (first one wins on ties), or "" when the row has no strings at all.
    string_values = [v for v in row.values() if isinstance(v, str)]
    return max(string_values, key=len, default="")
def packed_token_stream(dataset_stream, tokenizer, text_column, max_seq_len):
    """Yield packed (x, y, cu_seqlens) training chunks from a row stream.

    Each row's text is tokenized, terminated with EOS, and appended to a
    rolling buffer; whenever the buffer holds max_seq_len + 1 tokens, one
    chunk is emitted:

      x          -- LongTensor, the first max_seq_len tokens of the window
      y          -- LongTensor, the same window shifted right by one token
      cu_seqlens -- Int32Tensor of cumulative document boundaries inside the
                    window (presumably for block-diagonal attention masking
                    downstream -- TODO confirm against the consumer)

    ``text_column`` is unused (kept for interface compatibility); the text is
    pulled from each row via extract_text_from_row instead.
    """
    buffer = []        # flat token buffer spanning document boundaries
    doc_lengths = []   # tokens remaining in buffer for each pending document

    for row in dataset_stream:
        # Dynamically extract all available text to handle complex schemas
        text = extract_text_from_row(row)
        if not text: continue

        tokens = tokenizer(text, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id]
        buffer.extend(tokens)
        doc_lengths.append(len(tokens))

        while len(buffer) >= max_seq_len + 1:
            chunk = buffer[:max_seq_len + 1]
            cu_seqlens = [0]
            current_len = 0

            # Consume whole documents that fit entirely inside this window.
            while len(doc_lengths) > 0 and current_len + doc_lengths[0] <= max_seq_len:
                current_len += doc_lengths.pop(0)
                cu_seqlens.append(current_len)

            if current_len < max_seq_len and len(doc_lengths) > 0:
                # Partial document fills the remainder of the max_seq_len window
                cu_seqlens.append(max_seq_len)
                remainder = max_seq_len - current_len
                doc_lengths[0] -= remainder

            # Fix #6: The chunk consumes max_seq_len + 1 tokens from the buffer
            # (the extra +1 is the overlap token for x/y pair construction).
            # cu_seqlens only tracks max_seq_len tokens, so we must account for
            # the 1 extra token consumed from the trailing document.
            if len(doc_lengths) > 0:
                doc_lengths[0] -= 1
                if doc_lengths[0] <= 0:
                    doc_lengths.pop(0)

            buffer = buffer[max_seq_len + 1:]
            yield (
                torch.tensor(chunk[:-1], dtype=torch.long),
                torch.tensor(chunk[1:], dtype=torch.long),
                torch.tensor(cu_seqlens, dtype=torch.int32)
            )
class AgenticDataMixture(IterableDataset):
    """Infinite weighted mixture over named, already-packed sample streams.

    Every iteration step draws one stream name at random according to the
    target proportions and yields that stream's next sample.
    """

    def __init__(self, streams_dict, target_proportions):
        super().__init__()
        self.stream_names = list(streams_dict.keys())
        self.streams = streams_dict
        self.weights = target_proportions

    def __iter__(self):
        # Infinite iterator: each draw selects a single source stream by weight.
        while True:
            chosen = random.choices(self.stream_names, weights=self.weights, k=1)
            yield next(self.streams[chosen[0]])
def packed_collate_fn(batch):
    """Custom collate that handles variable-length cu_seqlens across batch elements.

    Each sample is (x, y, cu_seqlens) where x and y have fixed length but
    cu_seqlens varies. We pad cu_seqlens to the longest in the batch with -1
    sentinel values and return a lengths tensor so consumers know where the
    real values end.

    Returns:
        x: (batch_size, seq_len) — LongTensor
        y: (batch_size, seq_len) — LongTensor
        cu_seqlens: (batch_size, max_n_segs) — Int32Tensor, padded with -1
        n_segs: (batch_size,) — Int32Tensor, real lengths
    """
    xs, ys, cu_list = zip(*batch)

    seg_counts = [int(cu.numel()) for cu in cu_list]
    widest = max(seg_counts)

    cu_padded = torch.full((len(cu_list), widest), -1, dtype=torch.int32)
    for row_idx, cu in enumerate(cu_list):
        cu_padded[row_idx, :cu.numel()] = cu

    return (
        torch.stack(xs, dim=0),
        torch.stack(ys, dim=0),
        cu_padded,
        torch.tensor(seg_counts, dtype=torch.int32),
    )
def create_dataloaders(datasets_config, tokenizer_path="custom_agentic_tokenizer", max_seq_len=CONTEXT_LENGTH, batch_size=2):
    """Build an infinite mixed-dataset DataLoader plus its tokenizer.

    Args:
        datasets_config: list of dicts with keys "name", "path", "format"
            (a Hugging Face loader name, e.g. "json") and "weight"
            (relative sampling proportion in the mixture).
        tokenizer_path: local path / hub id passed to AutoTokenizer.
        max_seq_len: packed window length for each training sample.
        batch_size: DataLoader batch size.

    Returns:
        (DataLoader, tokenizer) — the loader yields the
        (x, y, cu_seqlens, n_segs) batches produced by packed_collate_fn.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    # Suppress Hugging Face max-length warnings: we pack and slice long docs
    # manually into max_seq_len windows in packed_token_stream.
    tokenizer.model_max_length = int(1e9)
    packed_streams, weights = {}, []

    # Flatten config to handle subdirectories individually and prevent mixed schema CastErrors
    flat_config = []
    for config in datasets_config:
        name, target_path, fmt, weight = config["name"], config["path"], config["format"], config["weight"]
        if os.path.isdir(target_path):
            # Only include subdirectories that actually contain files of the target format
            subdirs = [
                d for d in os.listdir(target_path)
                if os.path.isdir(os.path.join(target_path, d))
                and any(f.endswith(f'.{fmt}') for _, _, files in os.walk(os.path.join(target_path, d)) for f in files)
            ]
            if subdirs:
                # Split the dataset's weight evenly across its subdirectories.
                sub_weight = weight / len(subdirs)
                for d in subdirs:
                    flat_config.append({
                        "name": f"{name}_{d}",
                        "path": os.path.join(target_path, d),
                        "format": fmt,
                        "weight": sub_weight
                    })
                continue
        # NOTE(review): directories with no qualifying subdirectories fall
        # through here and are globbed as a whole below — confirm intended.
        flat_config.append(config)

    for config in flat_config:
        name, target_path, fmt, weight = config["name"], config["path"], config["format"], config["weight"]
        # Directories are globbed recursively for files of the target format;
        # plain file paths are passed through as-is.
        data_files = os.path.join(target_path, f"**/*.{fmt}") if os.path.isdir(target_path) else target_path

        raw_dataset = load_dataset(fmt, data_files=data_files, split='train', streaming=True)
        infinite_raw_stream = create_infinite_stream(raw_dataset)
        # Pass None for text_column since packed_token_stream now uses extract_text_from_row
        packed_streams[name] = iter(packed_token_stream(infinite_raw_stream, tokenizer, None, max_seq_len))
        weights.append(weight)

    mixture_dataset = AgenticDataMixture(packed_streams, weights)
    return DataLoader(
        mixture_dataset, batch_size=batch_size, num_workers=0,
        pin_memory=True, collate_fn=packed_collate_fn,
    ), tokenizer