Coverage for rl_train.py: 73%
203 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 23:06 +0000
1# Copyright 2026 venim1103
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15import torch
16import torch.nn.functional as F
17import os
18import wandb
19from datasets import load_dataset
20from transformers import AutoTokenizer
21from model import BitMambaLLM, maybe_autocast
22from optim import setup_mamba_optimizers
# Runtime / checkpoint configuration.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SFT_CKPT = "checkpoints/sft/sft_final.pt"  # SFT checkpoint RL fine-tuning starts from
LOCAL_RL_DATA = "local_data/rl/reasoning"  # root scanned for .json/.jsonl/.parquet shards
CHECKPOINT_DIR = "checkpoints/rl"  # RL checkpoints are written here
# NOTE(review): must match the architecture saved in SFT_CKPT — confirm against the SFT config.
MODEL_CONFIG = dict(vocab_size=64000, dim=1024, n_layers=40, d_state=128, expand=2, use_checkpoint=True)

# GRPO hyperparameters.
BATCH_SIZE = 1
GROUP_SIZE = 8 # Increased from 4 (Nanbeige4-3B / Llama-Nemotron recommend 8-16)
TOTAL_STEPS = 10_000
PEAK_LR = 1e-6
FILTER_LOW = 0.10 # On-policy filtering: discard if pass_rate < 10%
FILTER_HIGH = 0.90 # On-policy filtering: discard if pass_rate > 90%
FILTER_BATCH = 64 # Number of problems to evaluate per filtering round
MAX_GEN_TOKENS = 512  # cap on generated tokens per completion
def collect_data_files(root_dir):
    """Recursively gather dataset files (.jsonl/.json/.parquet) under root_dir.

    Symlinked directories are followed; returns a flat list of full paths.
    """
    matched = []
    for dirpath, _, names in os.walk(root_dir, followlinks=True):
        matched.extend(
            os.path.join(dirpath, name)
            for name in names
            if name.endswith((".jsonl", ".json", ".parquet"))
        )
    return matched
48# ---------------------------------------------------------------------------
49# Reward functions — separated into format and accuracy (Llama-Nemotron §5.1)
50# ---------------------------------------------------------------------------
def compute_format_reward(completion):
    """Structural-format reward.

    1.0 when a full <think>...</think> block is followed by an answer,
    0.5 when the tags are present but nothing follows, 0.0 otherwise.
    """
    thought, answer = _extract_thought_and_answer(completion)
    if thought is None:
        return 0.0
    return 1.0 if answer else 0.5  # tags present but no answer after -> 0.5
def compute_accuracy_reward(completion, ground_truth):
    """Accuracy reward: 2.0 when the ground truth appears (case-insensitively)
    inside the extracted final answer, else 0.0."""
    _, answer = _extract_thought_and_answer(completion)
    if answer and ground_truth.lower() in answer.lower():
        return 2.0
    return 0.0
def compute_conciseness_penalty(completion):
    """Return -0.5 when the thought is more than 10x the answer's length, else 0.0.

    No penalty is applied if either the thought or the answer is missing/empty.
    """
    thought, answer = _extract_thought_and_answer(completion)
    if not thought or not answer:
        return 0.0
    ratio = len(thought) / max(1, len(answer))
    return -0.5 if ratio > 10.0 else 0.0
83def _extract_thought_and_answer(completion):
84 """Extract thought and final answer from strict <think>...</think> format."""
85 if "<think>" not in completion:
86 return None, None
88 after_open = completion.split("<think>", 1)[1]
90 if "</think>" not in after_open:
91 return None, None
93 thought_text, answer_part = after_open.split("</think>", 1)
94 thought_text = thought_text.strip()
95 final_answer = answer_part.strip()
96 return thought_text, final_answer
def compute_rewards(completions, ground_truth):
    """Score each completion and return a float32 tensor on DEVICE.

    Combined reward = format + accuracy + conciseness - 0.0001 per character
    (a mild length penalty).
    """
    def _score(text):
        return (compute_format_reward(text)
                + compute_accuracy_reward(text, ground_truth)
                + compute_conciseness_penalty(text)
                - len(text) * 0.0001)

    scores = [_score(text) for text in completions]
    return torch.tensor(scores, dtype=torch.float32).to(DEVICE)
110# ---------------------------------------------------------------------------
111# On-policy difficulty filtering (Nanbeige4-3B §3.4)
112# ---------------------------------------------------------------------------
@torch.no_grad()
def filter_problems_on_policy(model, tokenizer, problems, eos_id):
    """Pre-pass: keep only problems whose pass rate lies in [FILTER_LOW, FILTER_HIGH].

    A precomputed dataset pass rate ('pass_rate_72b_tir', OpenMathReasoning) is
    used when present; otherwise the rate is estimated with GROUP_SIZE sampled
    generations per problem. Kept samples get a '_pass_rate' key attached.

    Args:
        model: policy model exposing .generate().
        tokenizer: HF tokenizer for prompt encoding / completion decoding.
        problems: iterable of sample dicts (question/answer fields vary by dataset).
        eos_id: token id terminating generation.

    Returns:
        List of the kept sample dicts (each annotated with '_pass_rate').
    """
    filtered = []
    model.eval()

    for sample in problems:
        question = sample.get('problem', sample.get('question', ''))
        ground_truth = str(sample.get('expected_answer', sample.get('answer', sample.get('solution', ''))))

        # Use pre-computed pass_rate from dataset if available (OpenMathReasoning)
        precomputed = sample.get('pass_rate_72b_tir', 'n/a')
        if precomputed not in ('n/a', None, ''):
            try:
                pass_rate = float(precomputed)
            except (ValueError, TypeError):
                pass  # unparseable -> fall through to the generation estimate
            else:
                # Fix: trust the precomputed rate either way. Previously an
                # out-of-range precomputed rate fell through to an expensive
                # GROUP_SIZE-generation re-estimate instead of being skipped.
                if FILTER_LOW <= pass_rate <= FILTER_HIGH:
                    sample['_pass_rate'] = pass_rate
                    filtered.append(sample)
                continue

        sys_prompt = "You are a deductive reasoning agent. You must analyze the user's request step-by-step within <think> tags before acting."
        prompt = f"<|im_start|>system\n{sys_prompt}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n<think>\n"
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)

        # Quick pass-rate estimate with GROUP_SIZE generations
        n_correct = 0
        for _ in range(GROUP_SIZE):
            generated = model.generate(input_ids, max_new_tokens=MAX_GEN_TOKENS, temperature=0.8,
                                       do_sample=True, eos_token_id=eos_id)
            gen_text = tokenizer.decode(generated[0][input_ids.shape[1]:], skip_special_tokens=True)
            if compute_accuracy_reward(gen_text, ground_truth) > 0:
                n_correct += 1

        pass_rate = n_correct / GROUP_SIZE
        if FILTER_LOW <= pass_rate <= FILTER_HIGH:
            sample['_pass_rate'] = pass_rate
            filtered.append(sample)

    return filtered
def run_rl_steps(model, tokenizer, dataset, optimizers, eos_id,
                 total_steps, checkpoint_dir, device):
    """Inner GRPO training loop, decoupled from model loading and wandb init.

    Per step: sample a problem, generate GROUP_SIZE completions under
    torch.no_grad() (recording "old" per-token log-probs), compute
    group-normalized advantages from rewards, then take one clipped
    policy-gradient step. Checkpoints every 1000 steps and at the end.

    Args:
        model: policy model exposing generate(), forward_hidden() and output().
        tokenizer: HF tokenizer for prompt encoding / completion decoding.
        dataset: (streaming) iterable of problem dicts; restarted on exhaustion.
        optimizers: (muon_opt, adam_opt, mamba_opt) triple.
        eos_id: token id terminating generation.
        total_steps: number of policy-update steps to run.
        checkpoint_dir: directory for periodic and final checkpoints.
        device: device string, e.g. "cuda" or "cpu".
    """
    muon_opt, adam_opt, mamba_opt = optimizers
    data_iter = iter(dataset)
    step = 0

    while step < total_steps:
        # Draw FILTER_BATCH raw problems, restarting the stream on exhaustion.
        raw_batch = []
        for _ in range(FILTER_BATCH):
            try: raw_batch.append(next(data_iter))
            except StopIteration:
                data_iter = iter(dataset)
                raw_batch.append(next(data_iter))

        # On-policy difficulty filtering: drop too-easy / too-hard problems.
        filtered = filter_problems_on_policy(model, tokenizer, raw_batch, eos_id)
        if not filtered:
            print(f" Step {step}: filter returned 0 problems, retrying...")
            continue

        # Train easiest-first (highest pass rate first).
        filtered.sort(key=lambda s: -s['_pass_rate'])
        print(f" Step {step}: filtered {len(filtered)}/{FILTER_BATCH} problems (pass_rate range: "
              f"{filtered[-1]['_pass_rate']:.0%}–{filtered[0]['_pass_rate']:.0%})")

        for sample in filtered:
            if step >= total_steps:
                break

            # Dataset schemas vary: fall back through known question/answer keys.
            question = sample.get('problem', sample.get('question', ''))
            ground_truth = str(sample.get('expected_answer', sample.get('answer', sample.get('solution', ''))))

            sys_prompt = "You are a deductive reasoning agent. You must analyze the user's request step-by-step within <think> tags before acting."
            prompt = f"<|im_start|>system\n{sys_prompt}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n<think>\n"
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

            # Offload optimizer state to CPU during generation to free VRAM.
            if device.startswith("cuda"):
                for opt in [muon_opt, adam_opt, mamba_opt]:
                    for state in opt.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor): state[k] = v.cpu()
                torch.cuda.empty_cache()

            # --- Rollout phase: GROUP_SIZE sampled completions + old log-probs ---
            model.eval()
            completions_ids, completions_text, old_log_probs_list = [], [], []
            with torch.no_grad():
                for _ in range(GROUP_SIZE):
                    generated = model.generate(input_ids, max_new_tokens=MAX_GEN_TOKENS, temperature=0.8,
                                               do_sample=True, eos_token_id=eos_id)
                    gen_only = generated[0][input_ids.shape[1]:]
                    completions_ids.append(gen_only.cpu())
                    completions_text.append(tokenizer.decode(gen_only, skip_special_tokens=True))

                    # Re-score the full sequence; slice logits so position t
                    # predicts generated token t (shift-by-one alignment).
                    full_seq = torch.cat([input_ids[0], gen_only]).unsqueeze(0)
                    with maybe_autocast(device):
                        hidden = model.forward_hidden(full_seq, seq_idx=None)
                        hidden_slice = hidden[0:1, input_ids.shape[1]-1:-1, :]
                        logits = model.output(hidden_slice)
                    # Per-token log-prob = -cross_entropy with reduction='none'.
                    log_probs = -F.cross_entropy(
                        logits[0, :, :].contiguous(),
                        gen_only.contiguous(), reduction='none'
                    )
                    old_log_probs_list.append(log_probs.detach())

            # Group-relative advantages (GRPO): normalize within the group.
            rewards = compute_rewards(completions_text, ground_truth)
            advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

            # Restore optimizer state to the GPU before the update.
            if device.startswith("cuda"):
                for opt in [muon_opt, adam_opt, mamba_opt]:
                    for state in opt.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor): state[k] = v.to(device)

            # --- Update phase: one clipped policy-gradient step over the group ---
            model.train()
            for opt in [muon_opt, adam_opt, mamba_opt]: opt.zero_grad()
            policy_loss = 0.0
            EPS = 0.2  # PPO clip range

            for i in range(GROUP_SIZE):
                comp_ids = completions_ids[i].to(device)
                old_log_probs = old_log_probs_list[i].to(device)

                full_seq = torch.cat([input_ids[0], comp_ids]).unsqueeze(0)
                with maybe_autocast(device):
                    hidden = model.forward_hidden(full_seq, seq_idx=None)
                    hidden_slice = hidden[0:1, input_ids.shape[1]-1:-1, :]
                    logits = model.output(hidden_slice)
                log_probs = -F.cross_entropy(
                    logits[0, :, :].contiguous(),
                    comp_ids.contiguous(), reduction='none'
                )

                # Clipped surrogate objective; gradients accumulate across the
                # group, so divide each member's loss by GROUP_SIZE.
                ratio = torch.exp(log_probs - old_log_probs)
                surr1 = ratio * advantages[i]
                surr2 = torch.clamp(ratio, 1.0 - EPS, 1.0 + EPS) * advantages[i]
                loss = -torch.min(surr1, surr2).mean() / GROUP_SIZE
                loss.backward()
                policy_loss += loss.item()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            for opt in [muon_opt, adam_opt, mamba_opt]: opt.step()

            if step % 10 == 0:
                wandb.log({
                    "RL/Mean_Reward": rewards.mean().item(),
                    "RL/Policy_Loss": policy_loss,
                    "RL/Format_Reward": sum(compute_format_reward(c) for c in completions_text) / len(completions_text),
                    "RL/Pass_Rate": sample['_pass_rate'],
                }, step=step)
            if step > 0 and step % 1000 == 0:
                torch.save({'step': step, 'model_state_dict': model.state_dict()},
                           os.path.join(checkpoint_dir, f"rl_step_{step:06d}.pt"))
            step += 1

    torch.save({'step': step, 'model_state_dict': model.state_dict()},
               os.path.join(checkpoint_dir, "rl_final.pt"))
def main():
    """Entry point: single-GPU GRPO RL fine-tuning starting from the SFT checkpoint.

    Loads the tokenizer, locates local RL data (parquet preferred over json),
    restores the SFT model weights, then hands off to run_rl_steps().
    """
    # RL is single-GPU only (GRPO generation is inherently sequential).
    # Skip DDP init entirely to avoid collective-operation deadlocks.
    rank = int(os.environ.get("RANK", "0"))
    if rank != 0:
        print(f"[rl_train] Rank {rank}: RL is single-GPU only, exiting.")
        return

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    wandb.init(project="Agentic-1.58b-Model", name="run-rl-grpo")

    tokenizer = AutoTokenizer.from_pretrained("custom_agentic_tokenizer")
    eos_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]

    files = collect_data_files(LOCAL_RL_DATA)
    if not files:
        raise FileNotFoundError(f"No .json/.jsonl/.parquet files found under {LOCAL_RL_DATA}")

    # Prefer parquet shards when both formats are present.
    parquet_files = [f for f in files if f.endswith(".parquet")]
    json_files = [f for f in files if f.endswith((".json", ".jsonl"))]
    if parquet_files:
        dataset = load_dataset("parquet", data_files=parquet_files, split="train", streaming=True)
    else:
        dataset = load_dataset("json", data_files=json_files, split="train", streaming=True)
    # (Fix: removed a dead `data_iter = iter(dataset)` — run_rl_steps builds
    # its own iterator — and an unnecessary `global DEVICE` declaration.)

    model = BitMambaLLM(**MODEL_CONFIG).to(DEVICE)
    model.load_state_dict(torch.load(SFT_CKPT, map_location=DEVICE)['model_state_dict'])
    muon_opt, adam_opt, mamba_opt = setup_mamba_optimizers(model, {"peak_lr": PEAK_LR, "end_lr": 1e-6})

    run_rl_steps(
        model=model,
        tokenizer=tokenizer,
        dataset=dataset,
        optimizers=(muon_opt, adam_opt, mamba_opt),
        eos_id=eos_id,
        total_steps=TOTAL_STEPS,
        checkpoint_dir=CHECKPOINT_DIR,
        device=DEVICE,
    )

    wandb.finish()
# Script entry point: only rank 0 (see main) actually trains.
if __name__ == "__main__":
    main()