Coverage for optim.py: 90%

102 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

1# Copyright 2026 venim1103 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import torch 

16import math 

17 

def _resolve_adam_class(use_8bit=True):
    """Pick an AdamW-compatible optimizer class, importing bitsandbytes lazily.

    Importing bitsandbytes on a CPU-only host can abort the whole process, so
    the import is attempted only when 8-bit Adam was requested AND CUDA is
    actually available.  Any import/attribute failure falls back to the stock
    ``torch.optim.AdamW`` so CPU imports stay safe in CI/tests.
    """
    fallback = torch.optim.AdamW

    # Either the caller opted out of 8-bit, or this runtime cannot use it.
    if not use_8bit or not torch.cuda.is_available():
        return fallback

    try:
        import bitsandbytes as bnb
        return bnb.optim.Adam8bit
    except Exception:
        # bitsandbytes missing or broken on this system — degrade gracefully.
        return fallback

32 

class Muon(torch.optim.Optimizer):
    """Momentum SGD whose 2D updates are orthogonalized via Newton-Schulz (Muon).

    Intended for 2D weight matrices.  Per-(shape, device, dtype) scratch
    tensors are cached in ``ns_workspaces`` so the Newton-Schulz matmuls
    allocate no new memory on steady-state steps.
    """

    def __init__(self, params, lr=1e-3, momentum=0.95, ns_steps=3):
        # ns_steps: number of Newton-Schulz iterations used to approximately
        # orthogonalize the normalized momentum matrix.
        defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps)
        super().__init__(params, defaults)
        self.ns_workspaces = {}

    @staticmethod
    def _get_ns_workspace(workspaces, shape, device, dtype):
        """Return (creating on first use) the scratch tensors for one key.

        The 'a', 'aa' and 'b' buffers are rows x rows; 'update' matches the
        parameter shape.
        """
        key = (shape, device, dtype)
        if key not in workspaces:
            rows, cols = shape
            workspaces[key] = {
                'a': torch.empty((rows, rows), device=device, dtype=dtype),
                'aa': torch.empty((rows, rows), device=device, dtype=dtype),
                'b': torch.empty((rows, rows), device=device, dtype=dtype),
                'update': torch.empty((rows, cols), device=device, dtype=dtype),
            }
        return workspaces[key]

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss, per the standard ``torch.optim.Optimizer.step``
                contract.  Previously ``step()`` accepted no arguments; the
                default keeps old call sites working.

        Returns:
            The closure's loss, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Bound the workspace cache so shape churn cannot grow it without limit.
        if len(self.ns_workspaces) > 16:
            self.ns_workspaces.clear()

        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            ns_steps = group['ns_steps']
            for p in group['params']:
                if p.grad is None:
                    continue

                with torch.no_grad():
                    g = p.grad.to(torch.float32)
                    p.grad = None  # Free BF16 gradient memory (G5)

                    state = self.state[p]
                    if len(state) == 0:
                        state['momentum_buffer'] = torch.zeros_like(g)

                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(g)

                    # New tensor; buf stays intact (G4).
                    X = buf / (buf.norm(keepdim=True) + 1e-8)

                    # --- SAFETY GUARD ---
                    # Newton-Schulz scratch is rows x rows, so a very TALL matrix
                    # (rows > 4096) would need huge workspaces.  Fall back to
                    # normalized momentum SGD so a large layer slipping through
                    # cannot trigger a catastrophic OOM.
                    if X.size(0) > 4096:
                        p.data.add_(X.type_as(p.data), alpha=-lr)
                        continue
                    # --------------------

                    # Quintic Newton-Schulz coefficients (Muon reference values).
                    a, b, c = (3.4445, -4.7750, 2.0315)
                    workspace = self._get_ns_workspace(
                        self.ns_workspaces, X.shape, X.device, X.dtype
                    )
                    A = workspace['a']
                    AA = workspace['aa']
                    B = workspace['b']
                    update = workspace['update']
                    for _ in range(ns_steps):
                        # X <- a*X + (b*A + c*A@A) @ X, where A = X @ X^T.
                        torch.matmul(X, X.T, out=A)
                        torch.matmul(A, A, out=AA)
                        B.copy_(A).mul_(b).add_(AA, alpha=c)
                        torch.matmul(B, X, out=update)
                        X.mul_(a).add_(update)

                    p.data.add_(X.type_as(p.data), alpha=-lr)

        return loss

97 

def setup_mamba_optimizers(model, config, use_8bit=True):
    """Partition the model's trainable parameters into three optimizers.

    - Muon: 2D BitLinear-style weight matrices.
    - AdamW (8-bit when available): biases, norms, embeddings, output head.
    - A dedicated AdamW for the sensitive continuous Mamba core parameters,
      run at 10% of ``config['peak_lr']`` with zero weight decay.

    Returns (muon_opt, adam_opt, mamba_core_opt).
    """
    sensitive_keys = ('A_log', 'D', 'dt_bias', 'dt_proj')

    muon_params = []
    adam_params = []
    mamba_sensitive_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # ISOLATION: the sensitive continuous Mamba parameters get their own
        # low-LR optimizer.
        if any(key in name for key in sensitive_keys):
            mamba_sensitive_params.append(param)
            continue

        # Muon handles the 2D BitLinear weights (ndim == 2 excludes 3D conv1d
        # weights).  CRITICAL: 'output.weight' is explicitly excluded to
        # prevent massive 15GB workspace allocations.
        is_muon_weight = (
            param.ndim == 2
            and 'weight' in name
            and 'norm' not in name
            and 'tok_embeddings' not in name
            and 'output.weight' not in name
        )
        if is_muon_weight:
            muon_params.append(param)
        else:
            # AdamW handles biases, norms, and embeddings.
            adam_params.append(param)

    # G11: use 8-bit Adam when available and safe for this runtime.
    AdamCls = _resolve_adam_class(use_8bit=use_8bit)

    muon_opt = Muon(muon_params, lr=config['peak_lr'])
    adam_opt = AdamCls(adam_params, lr=config['peak_lr'], weight_decay=0.01)

    # The dedicated optimizer for the State Space Core (fixed low LR, 0 weight decay).
    mamba_core_opt = AdamCls(
        mamba_sensitive_params, lr=config['peak_lr'] * 0.1, weight_decay=0.0
    )

    return muon_opt, adam_opt, mamba_core_opt

127 

class FGWSD_Scheduler:
    """Phase-based Warmup-Stable-Decay LR schedule over three optimizers.

    Phase 1 warms up linearly from ``end_lr`` to ``peak_lr``, phases 2 and 3
    hold ``peak_lr``, and any later phase decays with a cosine back toward
    ``end_lr``.  The third optimizer (the Mamba core) always tracks the main
    LR at a fixed 10x reduction.
    """

    def __init__(self, muon_opt, adam_opt, mamba_opt, total_steps, config):
        self.opts = [muon_opt, adam_opt, mamba_opt]
        self.peak_lr = config['peak_lr']
        self.end_lr = config['end_lr']
        self.phases = config['phases']
        self.total_steps = total_steps

        # Cumulative step index at which each phase ends.
        self.step_boundaries = []
        cumulative = 0
        for phase in self.phases:
            cumulative += int(self.total_steps * phase['pct'])
            self.step_boundaries.append(cumulative)

    def get_lr_and_ctx(self, step):
        """Return (lr, phase context length, phase name) for a global step."""
        for idx, boundary in enumerate(self.step_boundaries):
            if step >= boundary:
                continue
            phase = self.phases[idx]
            start = self.step_boundaries[idx - 1] if idx else 0
            progress = (step - start) / (boundary - start)

            if idx == 0:
                # Linear warmup from end_lr up to peak_lr.
                lr = self.end_lr + progress * (self.peak_lr - self.end_lr)
            elif idx in (1, 2):
                # Stable plateau at peak_lr.
                lr = self.peak_lr
            else:
                # Cosine decay back down toward end_lr.
                cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
                lr = self.end_lr + (self.peak_lr - self.end_lr) * cosine_decay

            return lr, phase['ctx'], f"Phase_{idx + 1}"

        # Past the final boundary (int() rounding can leave trailing steps).
        return self.end_lr, self.phases[-1]['ctx'], "Complete"

    def step(self, current_step):
        """Apply the schedule for ``current_step`` to all three optimizers."""
        lr, ctx, phase_name = self.get_lr_and_ctx(current_step)

        # Muon and the main AdamW share the full LR.
        for opt in self.opts[:2]:
            for group in opt.param_groups:
                group['lr'] = lr

        # The Mamba core is kept 10x lower than the main LR.
        for group in self.opts[2].param_groups:
            group['lr'] = lr * 0.1

        return lr, ctx, phase_name

169