Coverage for data.py: 91%

112 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

1# Copyright 2026 venim1103 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import os 

16import torch 

17from torch.utils.data import IterableDataset, DataLoader 

18from datasets import load_dataset 

19from transformers import AutoTokenizer 

20import random 

21from context_config import CONTEXT_LENGTH 

22 

def get_text_column(dataset_stream):
    """Return the name of the text-bearing column in a dataset stream.

    Inspects a single row: well-known column names are preferred in order;
    otherwise the first string-valued field wins.

    Raises:
        ValueError: if the sampled row has no string-valued field.

    NOTE: consumes one row from a fresh iterator over ``dataset_stream``.
    """
    first_row = next(iter(dataset_stream))
    for candidate in ('text', 'content', 'trajectory', 'prompt', 'response', 'question'):
        if isinstance(first_row.get(candidate), str):
            return candidate
    for field, value in first_row.items():
        if isinstance(value, str):
            return field
    raise ValueError("Could not detect text column.")

30 

def create_infinite_stream(hf_dataset):
    """Yield items from ``hf_dataset`` forever, restarting after each full pass."""
    while True:
        yield from hf_dataset

def extract_text_from_row(row):
    """Dynamically parses complex schemas into a clean training string."""
    segments = []

    # Chat-style rows: render each message as "role: content" ('user' default).
    messages = row.get('messages')
    if isinstance(messages, list):
        for message in messages:
            if isinstance(message, dict) and 'content' in message:
                segments.append(f"{message.get('role', 'user')}: {message['content']}")

    # Question-like fields, each labelled with a capitalized prefix.
    for field in ('problem', 'question', 'prompt', 'facts', 'hypothesis'):
        value = row.get(field)
        if isinstance(value, str):
            segments.append(f"{field.capitalize()}: {value}")

    proofs = row.get('proofs')
    if isinstance(proofs, list):
        segments.append(f"Proof: {' '.join(proofs)}")

    # Answer-like fields, same labelled form.
    for field in ('solution', 'answer', 'response'):
        value = row.get(field)
        if isinstance(value, str):
            segments.append(f"{field.capitalize()}: {value}")

    # Plain-text columns are used verbatim, but only when nothing structured
    # was found; the `not segments` guard keeps at most the first match.
    for field in ('text', 'content', 'trajectory'):
        value = row.get(field)
        if isinstance(value, str) and not segments:
            segments.append(value)

    if segments:
        return "\n".join(segments)

    # Absolute fallback: longest string value anywhere in the row.
    longest = ""
    for value in row.values():
        if isinstance(value, str) and len(value) > len(longest):
            longest = value
    return longest

67 

def packed_token_stream(dataset_stream, tokenizer, text_column, max_seq_len):
    """Tokenize rows from ``dataset_stream`` and yield fixed-size packed examples.

    Each row's text (one EOS token appended per document) is pushed into a
    flat token buffer. Whenever the buffer holds at least ``max_seq_len + 1``
    tokens, one training example is emitted as a triple:

      x: LongTensor (max_seq_len,) — input tokens
      y: LongTensor (max_seq_len,) — x shifted left by one token
      cu_seqlens: Int32Tensor — cumulative document boundaries inside the
        max_seq_len window, starting at 0 (for per-document attention resets)

    NOTE: ``text_column`` is accepted for interface compatibility but unused;
    text is extracted dynamically via extract_text_from_row. The generator
    runs for as long as dataset_stream yields rows (forever for an infinite
    stream).
    """
    buffer = []        # flat token queue spanning document boundaries
    doc_lengths = []   # remaining token count of each document still in buffer

    for row in dataset_stream:
        # Dynamically extract all available text to handle complex schemas
        text = extract_text_from_row(row)
        if not text: continue

        tokens = tokenizer(text, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id]
        buffer.extend(tokens)
        doc_lengths.append(len(tokens))

        while len(buffer) >= max_seq_len + 1:
            chunk = buffer[:max_seq_len + 1]
            cu_seqlens = [0]
            current_len = 0

            # Consume whole documents that fit entirely inside this window.
            while len(doc_lengths) > 0 and current_len + doc_lengths[0] <= max_seq_len:
                current_len += doc_lengths.pop(0)
                cu_seqlens.append(current_len)

            if current_len < max_seq_len and len(doc_lengths) > 0:
                # Partial document fills the remainder of the max_seq_len window
                cu_seqlens.append(max_seq_len)
                remainder = max_seq_len - current_len
                doc_lengths[0] -= remainder

            # Fix #6: The chunk consumes max_seq_len + 1 tokens from the buffer
            # (the extra +1 is the overlap token for x/y pair construction).
            # cu_seqlens only tracks max_seq_len tokens, so we must account for
            # the 1 extra token consumed from the trailing document.
            if len(doc_lengths) > 0:
                doc_lengths[0] -= 1
                if doc_lengths[0] <= 0:
                    doc_lengths.pop(0)

            # Drop the consumed window (including the overlap token).
            buffer = buffer[max_seq_len + 1:]
            yield (
                torch.tensor(chunk[:-1], dtype=torch.long),
                torch.tensor(chunk[1:], dtype=torch.long),
                torch.tensor(cu_seqlens, dtype=torch.int32)
            )

111 

class AgenticDataMixture(IterableDataset):
    """Infinite weighted mixture over several named packed-example streams.

    Each draw selects a stream name with probability proportional to its
    weight and yields that stream's next item.
    """

    def __init__(self, streams_dict, target_proportions):
        super().__init__()
        self.stream_names = list(streams_dict.keys())
        self.streams = streams_dict
        self.weights = target_proportions

    def __iter__(self):
        while True:
            chosen, = random.choices(self.stream_names, weights=self.weights, k=1)
            yield next(self.streams[chosen])

122 

def packed_collate_fn(batch):
    """Collate (x, y, cu_seqlens) samples whose cu_seqlens lengths differ.

    x and y have fixed length and are simply stacked. Each cu_seqlens row is
    right-padded with the sentinel -1 up to the longest row in the batch, and
    a per-sample length tensor is returned so consumers know where the real
    values end.

    Returns:
        x: (batch_size, seq_len) — LongTensor
        y: (batch_size, seq_len) — LongTensor
        cu_seqlens: (batch_size, max_n_segs) — Int32Tensor, padded with -1
        n_segs: (batch_size,) — Int32Tensor, real lengths
    """
    inputs, targets, seg_tensors = zip(*batch)
    stacked_x = torch.stack(inputs, dim=0)
    stacked_y = torch.stack(targets, dim=0)

    seg_counts = [t.shape[0] for t in seg_tensors]
    widest = max(seg_counts)

    cu_padded = torch.full((len(seg_tensors), widest), -1, dtype=torch.int32)
    for row_idx, t in enumerate(seg_tensors):
        cu_padded[row_idx, :t.shape[0]] = t

    n_segs = torch.tensor(seg_counts, dtype=torch.int32)
    return stacked_x, stacked_y, cu_padded, n_segs

150 

151 

def create_dataloaders(datasets_config, tokenizer_path="custom_agentic_tokenizer", max_seq_len=CONTEXT_LENGTH, batch_size=2):
    """Build a weighted-mixture DataLoader over all configured datasets.

    Each entry of ``datasets_config`` is a dict with keys "name", "path",
    "format" and "weight". A directory path containing subdirectories with
    matching-format files is split into one stream per subdirectory (the
    weight divided evenly) so that mixed schemas never share one stream.

    Returns:
        (DataLoader, tokenizer): the loader yields packed_collate_fn batches.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    # Suppress Hugging Face max-length warnings: we pack and slice long docs
    # manually into max_seq_len windows in packed_token_stream.
    tokenizer.model_max_length = int(1e9)

    def _contains_format(root, fmt):
        # True when any file anywhere under `root` has the target extension.
        suffix = f'.{fmt}'
        return any(
            f.endswith(suffix)
            for _, _, files in os.walk(root)
            for f in files
        )

    # Flatten config to handle subdirectories individually and prevent mixed
    # schema CastErrors.
    flat_config = []
    for entry in datasets_config:
        name, target_path, fmt, weight = entry["name"], entry["path"], entry["format"], entry["weight"]
        if os.path.isdir(target_path):
            # Only include subdirectories that actually contain files of the
            # target format.
            matching_subdirs = [
                child for child in os.listdir(target_path)
                if os.path.isdir(os.path.join(target_path, child))
                and _contains_format(os.path.join(target_path, child), fmt)
            ]
            if matching_subdirs:
                per_dir_weight = weight / len(matching_subdirs)
                flat_config.extend(
                    {
                        "name": f"{name}_{child}",
                        "path": os.path.join(target_path, child),
                        "format": fmt,
                        "weight": per_dir_weight,
                    }
                    for child in matching_subdirs
                )
                continue
        flat_config.append(entry)

    packed_streams, weights = {}, []
    for entry in flat_config:
        name, target_path, fmt, weight = entry["name"], entry["path"], entry["format"], entry["weight"]
        if os.path.isdir(target_path):
            data_files = os.path.join(target_path, f"**/*.{fmt}")
        else:
            data_files = target_path

        raw_dataset = load_dataset(fmt, data_files=data_files, split='train', streaming=True)
        infinite_raw_stream = create_infinite_stream(raw_dataset)
        # Pass None for text_column since packed_token_stream now uses extract_text_from_row
        packed_streams[name] = iter(packed_token_stream(infinite_raw_stream, tokenizer, None, max_seq_len))
        weights.append(weight)

    mixture_dataset = AgenticDataMixture(packed_streams, weights)
    loader = DataLoader(
        mixture_dataset, batch_size=batch_size, num_workers=0,
        pin_memory=True, collate_fn=packed_collate_fn,
    )
    return loader, tokenizer