Coverage for sft_train.py: 73%

106 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

1# Copyright 2026 venim1103 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import torch 

16import torch.nn.functional as F 

17import os 

18import wandb 

19from model import BitMambaLLM, chunked_cross_entropy, maybe_autocast 

20from optim import setup_mamba_optimizers 

21from sft_data import SFT_STAGES, create_sft_dataloader 

22from context_config import CONTEXT_LENGTH 

23from transformers import AutoTokenizer 

24from dist_utils import ( 

25 setup_distributed, cleanup_distributed, is_main_process, 

26 wrap_model_ddp, unwrap_model, barrier, get_world_size, 

27) 

28 

# Default device; replaced at runtime by the device returned from setup_distributed().
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # overridden by setup_distributed()
# Pretrained base checkpoint loaded before SFT begins.
PRETRAINED_CKPT = "checkpoints/bitmamba_parent/step_1000000.pt"
# All SFT stage/epoch checkpoints and the final model are written here.
CHECKPOINT_DIR = "checkpoints/sft"
# Must match the architecture of the checkpoint at PRETRAINED_CKPT.
MODEL_CONFIG = dict(vocab_size=64000, dim=1024, n_layers=40, d_state=128, expand=2, use_checkpoint=True)

# Micro-batch size per process; effective batch = BATCH_SIZE * GRAD_ACCUM_STEPS * world_size.
BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 8
# Use FP16 on Ampere (RTX 3090). Change to torch.bfloat16 on Ada Lovelace (RTX 4090).
AMP_DTYPE = torch.float16
# Fixed constant to keep loss_sum in a safe range for the FP16 GradScaler.
# Uses the shared project max context so it is safe across all stages.
# Must match the constant used in the grad_scale calculation below.
SAFE_DIVISOR = BATCH_SIZE * float(CONTEXT_LENGTH)

42 

43 

def run_sft_stage(model, tokenizer, stage_cfg, stage_num, global_step):
    """Run one SFT stage: create dataloader, optimizers, and train for N epochs.

    Args:
        model: (Possibly DDP-wrapped) BitMambaLLM. Its forward is expected to
            return ``(loss_sum, valid_tokens)`` when ``targets`` is given —
            TODO confirm against model.py.
        tokenizer: Tokenizer forwarded to ``create_sft_dataloader``.
        stage_cfg: Stage dict with keys "name", "lr", "epochs", "paths",
            "max_seq_len", "reasoning_off_prob".
        stage_num: 1-based stage index used in logging and checkpoint names.
        global_step: Running optimizer-step counter carried across stages.

    Returns:
        The updated ``global_step`` after this stage finishes.
    """
    raw_model = unwrap_model(model)
    name = stage_cfg["name"]
    lr = stage_cfg["lr"]
    epochs = stage_cfg["epochs"]

    train_loader = create_sft_dataloader(
        stage_cfg["paths"], tokenizer,
        max_seq_len=stage_cfg["max_seq_len"],
        batch_size=BATCH_SIZE,
        reasoning_off_prob=stage_cfg["reasoning_off_prob"],
    )

    use_8bit = DEVICE.startswith("cuda")
    muon_opt, adam_opt, mamba_opt = setup_mamba_optimizers(
        raw_model,
        {"peak_lr": lr, "end_lr": lr * 0.1},
        use_8bit=use_8bit,
    )
    # Scheduler only wraps adam_opt; the other optimizers' LRs are synced
    # manually from it after each step (see below).
    total_optim_steps = len(train_loader) * epochs // GRAD_ACCUM_STEPS
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(adam_opt, T_max=max(total_optim_steps, 1), eta_min=lr * 0.1)

    if is_main_process():
        print(f"\n{'='*60}")
        print(f"SFT Stage {stage_num}: {name} | epochs={epochs} lr={lr} samples={len(train_loader.dataset)}")
        print(f"{'='*60}")

    # GradScaler is only needed for FP16 AMP on CUDA; a disabled scaler is a no-op.
    scaler = torch.amp.GradScaler(enabled=(DEVICE.startswith("cuda") and AMP_DTYPE == torch.float16))
    model.train()
    for epoch in range(epochs):
        # DistributedSampler must be told the epoch so it shuffles differently each time
        if hasattr(train_loader.sampler, "set_epoch"):
            train_loader.sampler.set_epoch(epoch)
        for opt in [muon_opt, adam_opt, mamba_opt]: opt.zero_grad()
        accumulated_loss = 0.0
        accumulated_loss_sum = 0.0
        accumulated_valid_tokens = 0
        n_batches = len(train_loader)

        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(DEVICE), y.to(DEVICE)

            with maybe_autocast(DEVICE, amp_dtype=AMP_DTYPE):
                # Next-token prediction: inputs drop the last token, targets drop the first.
                loss_sum, valid_tokens = model(x[:, :-1], seq_idx=None, targets=y[..., 1:])
            # Divide by a fixed constant (not valid_tokens, which varies per batch)
            # so the FP16 scaled gradients stay in a safe numeric range.
            safe_loss = loss_sum / SAFE_DIVISOR
            scaler.scale(safe_loss).backward()
            # Accumulate the true (un-divided) loss sum for correct metric logging.
            accumulated_loss_sum += loss_sum.detach().item()
            accumulated_valid_tokens += valid_tokens.detach().item()

            should_step = ((batch_idx + 1) % GRAD_ACCUM_STEPS == 0) or (batch_idx + 1 == n_batches)
            if should_step:
                world_size = get_world_size()
                global_valid_tokens = accumulated_valid_tokens
                if world_size > 1:
                    # All ranks need the global valid-token count so the
                    # per-token normalization matches DDP's gradient averaging.
                    valid_token_tensor = torch.tensor(accumulated_valid_tokens, device=DEVICE, dtype=torch.float32)
                    torch.distributed.all_reduce(valid_token_tensor, op=torch.distributed.ReduceOp.SUM)
                    global_valid_tokens = valid_token_tensor.item()
                # Unscale first so custom grad_scale and clip_grad_norm operate
                # on true gradient magnitudes, not GradScaler-inflated ones.
                for opt in [muon_opt, adam_opt, mamba_opt]: scaler.unscale_(opt)
                # Multiply SAFE_DIVISOR back in to restore exact accumulation math.
                grad_scale = (world_size * SAFE_DIVISOR) / max(global_valid_tokens, 1.0)
                for param in model.parameters():
                    if param.grad is not None:
                        param.grad.mul_(grad_scale)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                # 1. Optimizer steps FIRST (via GradScaler)
                scale_before = scaler.get_scale()
                for opt in [muon_opt, adam_opt, mamba_opt]: scaler.step(opt)
                scaler.update()

                # 2. Scheduler steps SECOND
                # PyTorch GradScaler rule: only step scheduler if scaler didn't skip
                if scale_before <= scaler.get_scale():
                    scheduler.step()

                # Sync LRs across optimizer groups
                current_lr = scheduler.get_last_lr()[0]
                for opt in [muon_opt, adam_opt]:
                    for group in opt.param_groups: group['lr'] = current_lr
                # Mamba SSM parameters train at a reduced LR.
                for group in mamba_opt.param_groups: group['lr'] = current_lr * 0.1

                for opt in [muon_opt, adam_opt, mamba_opt]: opt.zero_grad()
                # BUGFIX: the old code multiplied by grad_scale, which contains a
                # SAFE_DIVISOR factor that only belongs to the gradient math
                # (gradients were pre-divided by SAFE_DIVISOR; the logged loss
                # sum never was). That inflated the logged loss by
                # ~BATCH_SIZE * CONTEXT_LENGTH. Log the per-token mean instead:
                # local loss sum scaled by world_size approximates the global
                # sum, normalized by the global valid-token count.
                accumulated_loss = accumulated_loss_sum * world_size / max(global_valid_tokens, 1.0)

                if global_step % 10 == 0 and is_main_process():
                    wandb.log({
                        f"SFT/{name}_Loss": accumulated_loss,
                        "SFT/Stage": stage_num,
                        "SFT/LR": current_lr,
                    }, step=global_step)
                accumulated_loss = 0.0
                accumulated_loss_sum = 0.0
                accumulated_valid_tokens = 0
                global_step += 1

        if is_main_process():
            ckpt_path = os.path.join(CHECKPOINT_DIR, f"sft_stage{stage_num}_{name}_epoch{epoch+1}.pt")
            torch.save({'stage': stage_num, 'epoch': epoch, 'model_state_dict': raw_model.state_dict()}, ckpt_path)
            print(f" Epoch {epoch+1}/{epochs} done → {ckpt_path}")
        barrier()

    return global_step

150 

151 

def main():
    """Entry point: set up distributed training, restore the pretrained base
    model, run every configured SFT stage in order, then save the final model."""
    global DEVICE
    _rank, local_rank, _world_size, DEVICE = setup_distributed()

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    if is_main_process():
        wandb.init(project="Agentic-1.58b-Model", name="run-sft-bitmamba-3stage")

    tokenizer = AutoTokenizer.from_pretrained("custom_agentic_tokenizer")
    model = BitMambaLLM(**MODEL_CONFIG).to(DEVICE)
    checkpoint = torch.load(PRETRAINED_CKPT, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])

    model = wrap_model_ddp(model, local_rank)
    raw_model = unwrap_model(model)

    # Stages run sequentially; the optimizer-step counter carries across them.
    step = 0
    for stage_idx, cfg in enumerate(SFT_STAGES, start=1):
        step = run_sft_stage(model, tokenizer, cfg, stage_idx, step)

    # Save final checkpoint (rank 0 only), then tear down the process group.
    if is_main_process():
        final_path = os.path.join(CHECKPOINT_DIR, "sft_final.pt")
        torch.save({'model_state_dict': raw_model.state_dict()}, final_path)
        print(f"\nSFT complete. Final checkpoint → {CHECKPOINT_DIR}/sft_final.pt")
        wandb.finish()
    cleanup_distributed()

178 

# Script entry point; each distributed rank executes main() independently.
if __name__ == "__main__":
    main()

181