Coverage for sft_train.py: 73%
106 statements
« prev ^ index » next — coverage.py v7.13.5, created at 2026-04-21 23:06 +0000
# Copyright 2026 venim1103
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15import torch
16import torch.nn.functional as F
17import os
18import wandb
19from model import BitMambaLLM, chunked_cross_entropy, maybe_autocast
20from optim import setup_mamba_optimizers
21from sft_data import SFT_STAGES, create_sft_dataloader
22from context_config import CONTEXT_LENGTH
23from transformers import AutoTokenizer
24from dist_utils import (
25 setup_distributed, cleanup_distributed, is_main_process,
26 wrap_model_ddp, unwrap_model, barrier, get_world_size,
27)
# Default compute device; reassigned by setup_distributed() inside main().
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # overridden by setup_distributed()
# Pretrained parent checkpoint to fine-tune from, and where SFT checkpoints go.
PRETRAINED_CKPT = "checkpoints/bitmamba_parent/step_1000000.pt"
CHECKPOINT_DIR = "checkpoints/sft"
# Architecture hyperparameters — must match the pretrained checkpoint above.
MODEL_CONFIG = dict(vocab_size=64000, dim=1024, n_layers=40, d_state=128, expand=2, use_checkpoint=True)

# Per-device micro-batch size and gradient-accumulation factor
# (effective batch = BATCH_SIZE * GRAD_ACCUM_STEPS * world_size).
BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 8
# Use FP16 on Ampere (RTX 3090). Change to torch.bfloat16 on Ada Lovelace (RTX 4090).
AMP_DTYPE = torch.float16
# Fixed constant to keep loss_sum in a safe range for the FP16 GradScaler.
# Uses the shared project max context so it is safe across all stages.
# Must match the constant used in the grad_scale calculation below.
SAFE_DIVISOR = BATCH_SIZE * float(CONTEXT_LENGTH)
def run_sft_stage(model, tokenizer, stage_cfg, stage_num, global_step):
    """Run one SFT stage: create dataloader, optimizer, and train for N epochs.

    Args:
        model: the (possibly DDP-wrapped) BitMambaLLM; its forward returns
            (loss_sum, valid_tokens) when targets are given.
        tokenizer: tokenizer passed through to the SFT dataloader.
        stage_cfg: dict with keys "name", "lr", "epochs", "paths",
            "max_seq_len", "reasoning_off_prob" (one entry of SFT_STAGES).
        stage_num: 1-based stage index, used for logging and checkpoint names.
        global_step: running optimizer-step counter carried across stages.

    Returns:
        The updated global_step after this stage finishes.
    """
    raw_model = unwrap_model(model)
    name = stage_cfg["name"]
    lr = stage_cfg["lr"]
    epochs = stage_cfg["epochs"]

    train_loader = create_sft_dataloader(
        stage_cfg["paths"], tokenizer,
        max_seq_len=stage_cfg["max_seq_len"],
        batch_size=BATCH_SIZE,
        reasoning_off_prob=stage_cfg["reasoning_off_prob"],
    )

    # Fresh optimizers per stage; 8-bit states only make sense on CUDA.
    use_8bit = DEVICE.startswith("cuda")
    muon_opt, adam_opt, mamba_opt = setup_mamba_optimizers(
        raw_model,
        {"peak_lr": lr, "end_lr": lr * 0.1},
        use_8bit=use_8bit,
    )
    # Scheduler is attached to adam_opt only; the other optimizers' LRs are
    # synced to it by hand after each step (see below).
    total_optim_steps = len(train_loader) * epochs // GRAD_ACCUM_STEPS
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(adam_opt, T_max=max(total_optim_steps, 1), eta_min=lr * 0.1)

    if is_main_process():
        print(f"\n{'='*60}")
        print(f"SFT Stage {stage_num}: {name} | epochs={epochs} lr={lr} samples={len(train_loader.dataset)}")
        print(f"{'='*60}")

    # GradScaler is only needed for FP16 autocast on CUDA; it is a no-op otherwise.
    scaler = torch.amp.GradScaler(enabled=(DEVICE.startswith("cuda") and AMP_DTYPE == torch.float16))
    model.train()
    for epoch in range(epochs):
        # DistributedSampler must be told the epoch so it shuffles differently each time
        if hasattr(train_loader.sampler, "set_epoch"):
            train_loader.sampler.set_epoch(epoch)
        for opt in [muon_opt, adam_opt, mamba_opt]: opt.zero_grad()
        accumulated_loss = 0.0
        accumulated_loss_sum = 0.0
        accumulated_valid_tokens = 0
        n_batches = len(train_loader)

        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(DEVICE), y.to(DEVICE)

            # Teacher-forced next-token objective: feed x[:-1], score against y[1:].
            # NOTE(review): assumes the model returns an UN-reduced loss sum plus
            # the count of non-masked target tokens — confirm against model.py.
            with maybe_autocast(DEVICE, amp_dtype=AMP_DTYPE):
                loss_sum, valid_tokens = model(x[:, :-1], seq_idx=None, targets=y[..., 1:])
            # Divide by a FIXED constant (not valid_tokens) so the backward value
            # stays in FP16-safe range; the exact token normalization is restored
            # on the gradients via grad_scale below.
            safe_loss = loss_sum / SAFE_DIVISOR
            scaler.scale(safe_loss).backward()
            # Accumulate the true (un-divided) loss sum for correct metric logging.
            accumulated_loss_sum += loss_sum.detach().item()
            accumulated_valid_tokens += valid_tokens.detach().item()

            # Step on every GRAD_ACCUM_STEPS-th micro-batch, and on the final
            # (possibly short) remainder of the epoch.
            should_step = ((batch_idx + 1) % GRAD_ACCUM_STEPS == 0) or (batch_idx + 1 == n_batches)
            if should_step:
                world_size = get_world_size()
                global_valid_tokens = accumulated_valid_tokens
                if world_size > 1:
                    # Sum valid-token counts across ranks so every rank applies
                    # the same gradient normalization.
                    valid_token_tensor = torch.tensor(accumulated_valid_tokens, device=DEVICE, dtype=torch.float32)
                    torch.distributed.all_reduce(valid_token_tensor, op=torch.distributed.ReduceOp.SUM)
                    global_valid_tokens = valid_token_tensor.item()
                # Unscale first so custom grad_scale and clip_grad_norm operate
                # on true gradient magnitudes, not GradScaler-inflated ones.
                for opt in [muon_opt, adam_opt, mamba_opt]: scaler.unscale_(opt)
                # Multiply SAFE_DIVISOR back in to restore exact accumulation math.
                # max(..., 1.0) guards against an all-masked accumulation window.
                grad_scale = (world_size * SAFE_DIVISOR) / max(global_valid_tokens, 1.0)
                for param in model.parameters():
                    if param.grad is not None:
                        param.grad.mul_(grad_scale)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                # 1. Optimizer steps FIRST (via GradScaler)
                scale_before = scaler.get_scale()
                for opt in [muon_opt, adam_opt, mamba_opt]: scaler.step(opt)
                scaler.update()

                # 2. Scheduler steps SECOND
                # PyTorch GradScaler rule: only step scheduler if scaler didn't skip
                # (update() shrinks the scale when any step was skipped for inf/NaN).
                if scale_before <= scaler.get_scale():
                    scheduler.step()

                # Sync LRs across optimizer groups
                current_lr = scheduler.get_last_lr()[0]
                for opt in [muon_opt, adam_opt]:
                    for group in opt.param_groups: group['lr'] = current_lr
                # Mamba-specific params run at 10% of the shared LR.
                for group in mamba_opt.param_groups: group['lr'] = current_lr * 0.1

                for opt in [muon_opt, adam_opt, mamba_opt]: opt.zero_grad()
                # NOTE(review): this logs loss_sum * world_size * SAFE_DIVISOR /
                # valid_tokens, i.e. the per-token mean inflated by a constant
                # factor — verify whether a plain loss_sum / valid_tokens was
                # intended for the wandb metric.
                accumulated_loss = accumulated_loss_sum * grad_scale

                if global_step % 10 == 0 and is_main_process():
                    wandb.log({
                        f"SFT/{name}_Loss": accumulated_loss,
                        "SFT/Stage": stage_num,
                        "SFT/LR": current_lr,
                    }, step=global_step)
                accumulated_loss = 0.0
                accumulated_loss_sum = 0.0
                accumulated_valid_tokens = 0
                global_step += 1

        # Rank 0 saves a per-epoch checkpoint of the unwrapped weights; all
        # ranks then synchronize before starting the next epoch.
        if is_main_process():
            ckpt_path = os.path.join(CHECKPOINT_DIR, f"sft_stage{stage_num}_{name}_epoch{epoch+1}.pt")
            torch.save({'stage': stage_num, 'epoch': epoch, 'model_state_dict': raw_model.state_dict()}, ckpt_path)
            print(f" Epoch {epoch+1}/{epochs} done → {ckpt_path}")
        barrier()

    return global_step
def main():
    """Entry point: set up DDP, load the pretrained model, run all SFT stages."""
    global DEVICE
    # setup_distributed() picks the device for this rank; rank/world_size are
    # not needed directly here, only the local rank (for DDP) and the device.
    _rank, local_rank, _world_size, dev = setup_distributed()
    DEVICE = dev

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    if is_main_process():
        wandb.init(project="Agentic-1.58b-Model", name="run-sft-bitmamba-3stage")

    tokenizer = AutoTokenizer.from_pretrained("custom_agentic_tokenizer")

    # Build the model and restore the pretrained parent weights before wrapping.
    net = BitMambaLLM(**MODEL_CONFIG).to(DEVICE)
    checkpoint = torch.load(PRETRAINED_CKPT, map_location=DEVICE)
    net.load_state_dict(checkpoint['model_state_dict'])

    net = wrap_model_ddp(net, local_rank)
    bare_net = unwrap_model(net)

    # Run each configured SFT stage in order, carrying the step counter across.
    step = 0
    for idx, cfg in enumerate(SFT_STAGES, start=1):
        step = run_sft_stage(net, tokenizer, cfg, idx, step)

    # Save final checkpoint (rank 0 only), then tear down the process group.
    if is_main_process():
        final_path = os.path.join(CHECKPOINT_DIR, "sft_final.pt")
        torch.save({'model_state_dict': bare_net.state_dict()}, final_path)
        print(f"\nSFT complete. Final checkpoint → {CHECKPOINT_DIR}/sft_final.pt")
        wandb.finish()
    cleanup_distributed()
# Script entry point: run SFT only when executed directly, not on import.
if __name__ == "__main__":
    main()