Coverage for optim.py: 90%
102 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 23:06 +0000
1# Copyright 2026 venim1103
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15import torch
16import math
18def _resolve_adam_class(use_8bit=True):
19 """Resolve AdamW implementation lazily to keep CPU imports safe in CI/tests."""
20 if not use_8bit: 20 ↛ 24line 20 didn't jump to line 24 because the condition on line 20 was always true
21 return torch.optim.AdamW
23 # bitsandbytes can abort process import on CPU-only systems; only try on CUDA.
24 if not torch.cuda.is_available():
25 return torch.optim.AdamW
27 try:
28 import bitsandbytes as bnb
29 return bnb.optim.Adam8bit
30 except Exception:
31 return torch.optim.AdamW
class Muon(torch.optim.Optimizer):
    """Momentum-orthogonalized optimizer for 2D weight matrices.

    Accumulates plain momentum, normalizes it, then runs a few Newton-Schulz
    iterations (quintic coefficients) to approximately orthogonalize the
    update before applying it. Scratch buffers for the matmuls are cached per
    (shape, device, dtype) so the hot loop allocates nothing.

    Fix vs. previous revision: ``step`` now honors the standard
    ``torch.optim.Optimizer.step(closure=None)`` contract (backward
    compatible — calling ``step()`` with no argument behaves exactly as
    before), so LR schedulers and wrappers that pass a closure work.
    """

    def __init__(self, params, lr=1e-3, momentum=0.95, ns_steps=3):
        """
        Args:
            params: iterable of 2D parameters (rows x cols matrices).
            lr: learning rate applied to the orthogonalized update.
            momentum: exponential momentum factor for the gradient buffer.
            ns_steps: number of Newton-Schulz iterations.
        """
        defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps)
        super().__init__(params, defaults)
        # Per-(shape, device, dtype) scratch tensors for the NS loop.
        self.ns_workspaces = {}

    @staticmethod
    def _get_ns_workspace(workspaces, shape, device, dtype):
        """Return (and lazily create) the scratch buffers for one matrix shape."""
        key = (shape, device, dtype)
        if key not in workspaces:
            rows, cols = shape
            workspaces[key] = {
                'a': torch.empty((rows, rows), device=device, dtype=dtype),
                'aa': torch.empty((rows, rows), device=device, dtype=dtype),
                'b': torch.empty((rows, rows), device=device, dtype=dtype),
                'update': torch.empty((rows, cols), device=device, dtype=dtype),
            }
        return workspaces[key]

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss (standard ``torch.optim.Optimizer`` contract).

        Returns:
            The loss returned by ``closure`` if one was given, else ``None``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Bound the workspace cache so exotic shape churn cannot leak memory.
        if len(self.ns_workspaces) > 16:
            self.ns_workspaces.clear()
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            ns_steps = group['ns_steps']
            for p in group['params']:
                if p.grad is None:
                    continue
                with torch.no_grad():
                    g = p.grad.to(torch.float32)
                    p.grad = None  # Free BF16 gradient memory (G5)

                    state = self.state[p]
                    if len(state) == 0:
                        state['momentum_buffer'] = torch.zeros_like(g)
                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(g)

                    X = buf / (buf.norm(keepdim=True) + 1e-8)  # New tensor, buf stays intact (G4)

                    # --- SAFETY GUARD ---
                    # If the matrix is massively wide (e.g., rows > 4096), fall back
                    # to normalized momentum SGD. This prevents catastrophic OOM if
                    # a large layer slips through.
                    if X.size(0) > 4096:
                        p.data.add_(X.type_as(p.data), alpha=-lr)
                        continue
                    # --------------------

                    # Quintic Newton-Schulz coefficients.
                    a, b, c = (3.4445, -4.7750, 2.0315)
                    workspace = self._get_ns_workspace(self.ns_workspaces, X.shape, X.device, X.dtype)
                    A = workspace['a']
                    AA = workspace['aa']
                    B = workspace['b']
                    update = workspace['update']
                    for _ in range(ns_steps):
                        torch.matmul(X, X.T, out=A)           # A  = X X^T
                        torch.matmul(A, A, out=AA)            # AA = A^2
                        B.copy_(A).mul_(b).add_(AA, alpha=c)  # B  = b*A + c*AA
                        torch.matmul(B, X, out=update)
                        X.mul_(a).add_(update)                # X  = a*X + B X
                    p.data.add_(X.type_as(p.data), alpha=-lr)
        return loss
def setup_mamba_optimizers(model, config, use_8bit=True):
    """Partition trainable parameters into three dedicated optimizers.

    - Muon: the 2D BitLinear weight matrices.
    - AdamW (optionally 8-bit): biases, norms, embeddings, output head.
    - A low-LR, zero-weight-decay AdamW for the sensitive Mamba parameters.

    Returns:
        (muon_opt, adam_opt, mamba_core_opt) tuple.
    """
    sensitive_keys = ('A_log', 'D', 'dt_bias', 'dt_proj')
    muon_params = []
    adam_params = []
    mamba_sensitive_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # ISOLATION: the sensitive continuous Mamba parameters get their own optimizer.
        if any(key in name for key in sensitive_keys):
            mamba_sensitive_params.append(param)
            continue
        # Muon handles the 2D BitLinear weights (ndim == 2 excludes 3D conv1d weights).
        # CRITICAL: 'output.weight' is excluded to prevent massive 15GB workspace allocations.
        is_muon_weight = (
            param.ndim == 2
            and 'weight' in name
            and 'norm' not in name
            and 'tok_embeddings' not in name
            and 'output.weight' not in name
        )
        if is_muon_weight:
            muon_params.append(param)
        else:
            # AdamW handles biases, norms, and embeddings.
            adam_params.append(param)

    # G11: Use 8-bit Adam when available and safe for this runtime.
    AdamCls = _resolve_adam_class(use_8bit=use_8bit)

    muon_opt = Muon(muon_params, lr=config['peak_lr'])
    adam_opt = AdamCls(adam_params, lr=config['peak_lr'], weight_decay=0.01)
    # The dedicated optimizer for the State Space Core (fixed low LR, zero weight decay).
    mamba_core_opt = AdamCls(
        mamba_sensitive_params, lr=config['peak_lr'] * 0.1, weight_decay=0.0
    )
    return muon_opt, adam_opt, mamba_core_opt
class FGWSD_Scheduler:
    """Phase-based LR + context-length schedule driving three optimizers.

    The first phase is a linear warmup from ``end_lr`` to ``peak_lr``; phases
    two and three hold the peak; any later phase cosine-decays back down to
    ``end_lr``. Muon and AdamW share the scheduled LR, while the Mamba-core
    optimizer is always kept at 10% of it. Each phase also carries its own
    context length (``ctx``).
    """

    def __init__(self, muon_opt, adam_opt, mamba_opt, total_steps, config):
        """
        Args:
            muon_opt, adam_opt, mamba_opt: the three optimizers to drive.
            total_steps: total number of training steps the phases span.
            config: dict with 'peak_lr', 'end_lr' and 'phases' — each phase a
                dict holding 'pct' (fraction of total steps) and 'ctx'.
        """
        self.opts = [muon_opt, adam_opt, mamba_opt]
        self.peak_lr = config['peak_lr']
        self.end_lr = config['end_lr']
        self.phases = config['phases']
        self.total_steps = total_steps

        # Cumulative step index at which each phase ends.
        self.step_boundaries = []
        acc = 0
        for phase in self.phases:
            acc += int(self.total_steps * phase['pct'])
            self.step_boundaries.append(acc)

    def get_lr_and_ctx(self, step):
        """Return ``(lr, ctx, phase_name)`` for the given global step."""
        for idx, end in enumerate(self.step_boundaries):
            if step >= end:
                continue
            phase = self.phases[idx]
            begin = 0 if idx == 0 else self.step_boundaries[idx - 1]
            frac = (step - begin) / (end - begin)

            if idx == 0:
                # Linear warmup from end_lr up to peak_lr.
                lr = self.end_lr + frac * (self.peak_lr - self.end_lr)
            elif idx in (1, 2):
                # Stable plateau at the peak rate.
                lr = self.peak_lr
            else:
                # Cosine decay back down to end_lr.
                cosine_decay = 0.5 * (1 + math.cos(math.pi * frac))
                lr = self.end_lr + (self.peak_lr - self.end_lr) * cosine_decay
            return lr, phase['ctx'], f"Phase_{idx + 1}"
        # Past every boundary: training is effectively done.
        return self.end_lr, self.phases[-1]['ctx'], "Complete"

    def step(self, current_step):
        """Push the scheduled LR into all three optimizers; return ``(lr, ctx, name)``."""
        lr, ctx, phase_name = self.get_lr_and_ctx(current_step)

        # Muon and AdamW share the main schedule.
        for opt in self.opts[:2]:
            for group in opt.param_groups:
                group['lr'] = lr

        # The Mamba core stays 10x lower than the main LR.
        for group in self.opts[2].param_groups:
            group['lr'] = lr * 0.1

        return lr, ctx, phase_name