Coverage for rl_train.py: 73%

203 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

1# Copyright 2026 venim1103 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import torch 

16import torch.nn.functional as F 

17import os 

18import wandb 

19from datasets import load_dataset 

20from transformers import AutoTokenizer 

21from model import BitMambaLLM, maybe_autocast 

22from optim import setup_mamba_optimizers 

23 

# --- Runtime / path configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SFT_CKPT = "checkpoints/sft/sft_final.pt"  # SFT checkpoint that seeds the RL policy
LOCAL_RL_DATA = "local_data/rl/reasoning"  # root directory scanned for RL problem files
CHECKPOINT_DIR = "checkpoints/rl"  # destination for RL checkpoints
# NOTE(review): MODEL_CONFIG must match the architecture of the SFT checkpoint
# being loaded below — confirm against the SFT training config.
MODEL_CONFIG = dict(vocab_size=64000, dim=1024, n_layers=40, d_state=128, expand=2, use_checkpoint=True)

# --- GRPO hyperparameters ---
BATCH_SIZE = 1
GROUP_SIZE = 8 # Increased from 4 (Nanbeige4-3B / Llama-Nemotron recommend 8-16)
TOTAL_STEPS = 10_000
PEAK_LR = 1e-6
FILTER_LOW = 0.10 # On-policy filtering: discard if pass_rate < 10%
FILTER_HIGH = 0.90 # On-policy filtering: discard if pass_rate > 90%
FILTER_BATCH = 64 # Number of problems to evaluate per filtering round
MAX_GEN_TOKENS = 512  # cap on generated tokens per completion

38 

39 

def collect_data_files(root_dir):
    """Recursively collect every .jsonl/.json/.parquet file under *root_dir*.

    Symlinked directories are followed. Returns full paths in os.walk order.
    """
    wanted = (".jsonl", ".json", ".parquet")
    return [
        os.path.join(dirpath, fname)
        for dirpath, _, fnames in os.walk(root_dir, followlinks=True)
        for fname in fnames
        if fname.endswith(wanted)
    ]

47 

48# --------------------------------------------------------------------------- 

49# Reward functions — separated into format and accuracy (Llama-Nemotron §5.1) 

50# --------------------------------------------------------------------------- 

51 

def compute_format_reward(completion):
    """Score structural format: 1.0 for <think>...</think> plus an answer,
    0.5 for tags with no answer, 0.0 when the tags are missing/unclosed."""
    thought, answer = _extract_thought_and_answer(completion)
    if thought is None:
        return 0.0
    return 1.0 if answer else 0.5

60 

61 

def compute_accuracy_reward(completion, ground_truth):
    """Return 2.0 when the extracted final answer contains *ground_truth*
    (case-insensitive substring match), else 0.0."""
    _, answer = _extract_thought_and_answer(completion)
    if answer and ground_truth.lower() in answer.lower():
        return 2.0
    return 0.0

70 

71 

def compute_conciseness_penalty(completion):
    """Return -0.5 when the thought is more than 10x longer (in characters)
    than the final answer; 0.0 otherwise or when either part is missing."""
    thought, answer = _extract_thought_and_answer(completion)
    if not thought or not answer:
        return 0.0
    ratio = len(thought) / max(1, len(answer))
    return -0.5 if ratio > 10.0 else 0.0

81 

82 

83def _extract_thought_and_answer(completion): 

84 """Extract thought and final answer from strict <think>...</think> format.""" 

85 if "<think>" not in completion: 

86 return None, None 

87 

88 after_open = completion.split("<think>", 1)[1] 

89 

90 if "</think>" not in after_open: 

91 return None, None 

92 

93 thought_text, answer_part = after_open.split("</think>", 1) 

94 thought_text = thought_text.strip() 

95 final_answer = answer_part.strip() 

96 return thought_text, final_answer 

97 

98 

def compute_rewards(completions, ground_truth):
    """Combined per-completion reward: format + accuracy + conciseness,
    minus a small per-character length penalty. Returns a float32 tensor
    on DEVICE, one scalar per completion."""
    scores = [
        compute_format_reward(text)
        + compute_accuracy_reward(text, ground_truth)
        + compute_conciseness_penalty(text)
        + (-len(text) * 0.0001)
        for text in completions
    ]
    return torch.tensor(scores, dtype=torch.float32).to(DEVICE)

109 

110# --------------------------------------------------------------------------- 

111# On-policy difficulty filtering (Nanbeige4-3B §3.4) 

112# --------------------------------------------------------------------------- 

113 

@torch.no_grad()
def filter_problems_on_policy(model, tokenizer, problems, eos_id):
    """Pre-pass: compute per-problem pass rate, keep only those in [FILTER_LOW, FILTER_HIGH].

    Each kept sample is tagged in place with sample['_pass_rate'] so callers
    can sort/log by difficulty. Runs gradient-free and leaves the model in
    eval mode.
    """
    filtered = []
    model.eval()

    for sample in problems:
        # Datasets disagree on field names; fall through the known aliases.
        question = sample.get('problem', sample.get('question', ''))
        ground_truth = str(sample.get('expected_answer', sample.get('answer', sample.get('solution', ''))))

        # Use pre-computed pass_rate from dataset if available (OpenMathReasoning)
        # NOTE(review): when the precomputed rate parses but falls OUTSIDE the
        # keep-range, we do NOT discard — we fall through to the on-policy
        # estimate below. Confirm this fallback is intentional (the 72B model's
        # pass rate need not match this policy's).
        precomputed = sample.get('pass_rate_72b_tir', 'n/a')
        if precomputed not in ('n/a', None, ''):
            try:
                pass_rate = float(precomputed)
                if FILTER_LOW <= pass_rate <= FILTER_HIGH:
                    sample['_pass_rate'] = pass_rate
                    filtered.append(sample)
                    continue
            except (ValueError, TypeError):
                pass  # unparseable precomputed rate: estimate on-policy instead

        # Same ChatML-style prompt used by the training loop below.
        sys_prompt = "You are a deductive reasoning agent. You must analyze the user's request step-by-step within <think> tags before acting."
        prompt = f"<|im_start|>system\n{sys_prompt}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n<think>\n"
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)

        # Quick pass-rate estimate with GROUP_SIZE generations
        n_correct = 0
        for _ in range(GROUP_SIZE):
            generated = model.generate(input_ids, max_new_tokens=MAX_GEN_TOKENS, temperature=0.8,
                                       do_sample=True, eos_token_id=eos_id)
            # Decode only the freshly generated suffix (strip the prompt tokens).
            gen_text = tokenizer.decode(generated[0][input_ids.shape[1]:], skip_special_tokens=True)
            if compute_accuracy_reward(gen_text, ground_truth) > 0:
                n_correct += 1

        pass_rate = n_correct / GROUP_SIZE
        if FILTER_LOW <= pass_rate <= FILTER_HIGH:
            sample['_pass_rate'] = pass_rate
            filtered.append(sample)

    return filtered

155 

156 

def run_rl_steps(model, tokenizer, dataset, optimizers, eos_id,
                 total_steps, checkpoint_dir, device):
    """Inner GRPO training loop, decoupled from model loading and wandb init.

    For each filtered problem: sample GROUP_SIZE completions, score them with
    compute_rewards, normalize to group-relative advantages, and apply one
    PPO-clipped policy-gradient step. Checkpoints every 1000 steps and writes
    rl_final.pt on completion.
    """
    muon_opt, adam_opt, mamba_opt = optimizers
    data_iter = iter(dataset)
    step = 0

    while step < total_steps:
        # Draw FILTER_BATCH raw problems, restarting the (streaming) iterator
        # when it is exhausted.
        raw_batch = []
        for _ in range(FILTER_BATCH):
            try: raw_batch.append(next(data_iter))
            except StopIteration:
                data_iter = iter(dataset)
                raw_batch.append(next(data_iter))

        filtered = filter_problems_on_policy(model, tokenizer, raw_batch, eos_id)
        if not filtered:
            print(f" Step {step}: filter returned 0 problems, retrying...")
            continue

        # Easiest problems first (descending pass rate).
        filtered.sort(key=lambda s: -s['_pass_rate'])
        print(f" Step {step}: filtered {len(filtered)}/{FILTER_BATCH} problems (pass_rate range: "
              f"{filtered[-1]['_pass_rate']:.0%}–{filtered[0]['_pass_rate']:.0%})")

        for sample in filtered:
            if step >= total_steps:
                break

            question = sample.get('problem', sample.get('question', ''))
            ground_truth = str(sample.get('expected_answer', sample.get('answer', sample.get('solution', ''))))

            # Same ChatML-style prompt as filter_problems_on_policy.
            sys_prompt = "You are a deductive reasoning agent. You must analyze the user's request step-by-step within <think> tags before acting."
            prompt = f"<|im_start|>system\n{sys_prompt}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n<think>\n"
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

            # Offload optimizer state to CPU before generation to free GPU
            # memory for the KV/activation footprint of sampling.
            if device.startswith("cuda"):
                for opt in [muon_opt, adam_opt, mamba_opt]:
                    for state in opt.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor): state[k] = v.cpu()
                torch.cuda.empty_cache()

            # --- Rollout phase: sample GROUP_SIZE completions and record the
            # behavior policy's log-probs (old policy for the PPO ratio).
            model.eval()
            completions_ids, completions_text, old_log_probs_list = [], [], []
            with torch.no_grad():
                for _ in range(GROUP_SIZE):
                    generated = model.generate(input_ids, max_new_tokens=MAX_GEN_TOKENS, temperature=0.8,
                                               do_sample=True, eos_token_id=eos_id)
                    gen_only = generated[0][input_ids.shape[1]:]
                    completions_ids.append(gen_only.cpu())
                    completions_text.append(tokenizer.decode(gen_only, skip_special_tokens=True))

                    # Re-run the full sequence to get per-token log-probs of
                    # the generated suffix. The hidden slice is shifted by one:
                    # position t predicts token t+1.
                    full_seq = torch.cat([input_ids[0], gen_only]).unsqueeze(0)
                    with maybe_autocast(device):
                        hidden = model.forward_hidden(full_seq, seq_idx=None)
                        hidden_slice = hidden[0:1, input_ids.shape[1]-1:-1, :]
                        logits = model.output(hidden_slice)
                        # Negative cross-entropy == per-token log-probability.
                        log_probs = -F.cross_entropy(
                            logits[0, :, :].contiguous(),
                            gen_only.contiguous(), reduction='none'
                        )
                    old_log_probs_list.append(log_probs.detach())

            # Group-relative advantages (GRPO): z-score the rewards.
            rewards = compute_rewards(completions_text, ground_truth)
            advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

            # Restore optimizer state to the GPU for the update phase.
            if device.startswith("cuda"):
                for opt in [muon_opt, adam_opt, mamba_opt]:
                    for state in opt.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor): state[k] = v.to(device)

            # --- Update phase: PPO-clipped surrogate over the group.
            model.train()
            for opt in [muon_opt, adam_opt, mamba_opt]: opt.zero_grad()
            policy_loss = 0.0
            EPS = 0.2  # PPO clip range

            for i in range(GROUP_SIZE):
                comp_ids = completions_ids[i].to(device)
                old_log_probs = old_log_probs_list[i].to(device)

                full_seq = torch.cat([input_ids[0], comp_ids]).unsqueeze(0)
                with maybe_autocast(device):
                    hidden = model.forward_hidden(full_seq, seq_idx=None)
                    hidden_slice = hidden[0:1, input_ids.shape[1]-1:-1, :]
                    logits = model.output(hidden_slice)
                    log_probs = -F.cross_entropy(
                        logits[0, :, :].contiguous(),
                        comp_ids.contiguous(), reduction='none'
                    )

                ratio = torch.exp(log_probs - old_log_probs)
                surr1 = ratio * advantages[i]
                surr2 = torch.clamp(ratio, 1.0 - EPS, 1.0 + EPS) * advantages[i]
                # Divide by GROUP_SIZE so accumulated grads average the group.
                loss = -torch.min(surr1, surr2).mean() / GROUP_SIZE
                loss.backward()
                policy_loss += loss.item()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            for opt in [muon_opt, adam_opt, mamba_opt]: opt.step()

            if step % 10 == 0:
                wandb.log({
                    "RL/Mean_Reward": rewards.mean().item(),
                    "RL/Policy_Loss": policy_loss,
                    "RL/Format_Reward": sum(compute_format_reward(c) for c in completions_text) / len(completions_text),
                    "RL/Pass_Rate": sample['_pass_rate'],
                }, step=step)
            if step > 0 and step % 1000 == 0:
                torch.save({'step': step, 'model_state_dict': model.state_dict()},
                           os.path.join(checkpoint_dir, f"rl_step_{step:06d}.pt"))
            step += 1

    torch.save({'step': step, 'model_state_dict': model.state_dict()},
               os.path.join(checkpoint_dir, "rl_final.pt"))

272 

273 

def main():
    """Entry point: load the SFT checkpoint, build optimizers, and run GRPO RL.

    Raises FileNotFoundError when no data files exist under LOCAL_RL_DATA.
    Non-zero ranks (under a multi-process launcher) exit immediately.
    """
    # RL is single-GPU only (GRPO generation is inherently sequential).
    # Skip DDP init entirely to avoid collective-operation deadlocks.
    rank = int(os.environ.get("RANK", "0"))
    if rank != 0:
        print(f"[rl_train] Rank {rank}: RL is single-GPU only, exiting.")
        return

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    wandb.init(project="Agentic-1.58b-Model", name="run-rl-grpo")

    tokenizer = AutoTokenizer.from_pretrained("custom_agentic_tokenizer")
    # <|im_end|> closes each chat turn; use it as EOS during generation.
    eos_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]

    files = collect_data_files(LOCAL_RL_DATA)
    if not files:
        raise FileNotFoundError(f"No .json/.jsonl/.parquet files found under {LOCAL_RL_DATA}")

    # Prefer parquet shards when present; otherwise stream the JSON files.
    parquet_files = [f for f in files if f.endswith(".parquet")]
    json_files = [f for f in files if f.endswith((".json", ".jsonl"))]
    if parquet_files:
        dataset = load_dataset("parquet", data_files=parquet_files, split="train", streaming=True)
    else:
        dataset = load_dataset("json", data_files=json_files, split="train", streaming=True)

    model = BitMambaLLM(**MODEL_CONFIG).to(DEVICE)
    model.load_state_dict(torch.load(SFT_CKPT, map_location=DEVICE)['model_state_dict'])
    muon_opt, adam_opt, mamba_opt = setup_mamba_optimizers(model, {"peak_lr": PEAK_LR, "end_lr": 1e-6})

    run_rl_steps(
        model=model,
        tokenizer=tokenizer,
        dataset=dataset,
        optimizers=(muon_opt, adam_opt, mamba_opt),
        eos_id=eos_id,
        total_steps=TOTAL_STEPS,
        checkpoint_dir=CHECKPOINT_DIR,
        device=DEVICE,
    )

    wandb.finish()

318 

319 

# Script entry point: run the single-process GRPO RL training pipeline.
if __name__ == "__main__":
    main()

322