Coverage for model.py: 86%

424 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

1# Copyright 2026 venim1103 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import torch 

16import torch.nn as nn 

17import torch.nn.functional as F 

18import math 

19from contextlib import nullcontext 

20try: 

21 import triton 

22 import triton.language as tl 

23 HAS_TRITON = True 

24except ImportError: 

25 triton = None 

26 tl = None 

27 HAS_TRITON = False 

28 

29try: 

30 from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined 

31except ImportError: 

32 mamba_chunk_scan_combined = None 

33 

34try: 

35 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update 

36except ImportError: 

37 causal_conv1d_fn = None 

38 causal_conv1d_update = None 

39 

40try: 

41 from mamba_ssm.ops.triton.selective_state_update import selective_state_update 

42except ImportError: 

43 selective_state_update = None 

44 

45from einops import rearrange 

46 

47 

def maybe_autocast(device=None, amp_dtype=None):
    """Return an autocast context manager for CUDA, or a no-op context otherwise.

    Args:
        device: None (auto-detect), a ``torch.device``, or a device string.
        amp_dtype: preferred autocast dtype; defaults to bf16 when supported,
            otherwise fp16. A bf16 request on unsupported hardware degrades
            to fp16.

    Returns:
        ``torch.autocast(...)`` on CUDA, ``nullcontext()`` everywhere else.
    """
    # Normalise the device argument down to a plain device-type string.
    if device is None:
        dev_type = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        dev_type = device.type if isinstance(device, torch.device) else str(device)

    # Mixed precision is only applied on CUDA; everything else is a no-op.
    if dev_type != "cuda":
        return nullcontext()

    bf16_ok = torch.cuda.is_bf16_supported()
    if amp_dtype is None:
        amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
    elif amp_dtype == torch.bfloat16 and not bf16_ok:
        # Requested bf16 on hardware without support: fall back to fp16.
        amp_dtype = torch.float16

    return torch.autocast(device_type="cuda", dtype=amp_dtype)

65 

66 

def _mamba_chunk_scan_combined_fallback(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    D,
    z=None,
    dt_bias=None,
    dt_softplus=True,
    seq_idx=None,
    return_final_states=False,
):
    """Pure-PyTorch sequential reference for the Mamba-2 SSD chunk scan.

    Mirrors the call signature of mamba_ssm's ``mamba_chunk_scan_combined``
    but evaluates the recurrence one timestep at a time in float32
    (``chunk_size`` is accepted only for signature compatibility).

    Args:
        x: (bsz, seqlen, nheads, headdim) input stream.
        dt: (bsz, seqlen, nheads) per-head timestep.
        A: (nheads,) per-head scalar decay.
        B, C: (bsz, seqlen, ngroups, d_state) input/output projections.
        chunk_size: ignored by this fallback.
        D: (nheads,) skip-connection weight.
        z: optional gate with the same shape as ``x``; applied as SiLU(z) * y.
        dt_bias: optional (nheads,) bias added to dt before the softplus.
        dt_softplus: apply softplus to dt (after dt_bias) when True.
        seq_idx: optional (bsz, seqlen) document ids; state is zeroed whenever
            the id changes so documents cannot leak into each other.
        return_final_states: additionally return the final SSM state.

    Returns:
        y of shape (bsz, seqlen, nheads, headdim) in ``x.dtype``; plus the
        final state (bsz, nheads, headdim, d_state) when requested.
    """
    del chunk_size  # the sequential scan needs no chunking
    bsz, seqlen, nheads, headdim = x.shape
    ngroups = B.shape[2]
    d_state = B.shape[3]

    # Broadcast the grouped B/C across their heads (GQA-style sharing).
    heads_per_group = nheads // ngroups
    B_heads = B.repeat_interleave(heads_per_group, dim=2).to(torch.float32)
    C_heads = C.repeat_interleave(heads_per_group, dim=2).to(torch.float32)

    # The recurrence is run entirely in float32 for numerical stability.
    x32 = x.to(torch.float32)
    dt32 = dt.to(torch.float32)
    A32 = A.to(torch.float32)
    D32 = D.to(torch.float32)

    if dt_bias is not None:
        dt32 = dt32 + dt_bias.to(torch.float32).view(1, 1, -1)
    if dt_softplus:
        dt32 = F.softplus(dt32)

    state = x32.new_zeros((bsz, nheads, headdim, d_state))
    ys = []

    for t in range(seqlen):
        # Zero the state at document boundaries.
        if seq_idx is not None and t > 0:
            boundary = (seq_idx[:, t] != seq_idx[:, t - 1]).view(bsz, 1, 1, 1)
            state = torch.where(boundary, torch.zeros_like(state), state)

        xt = x32[:, t, :, :]
        dtt = dt32[:, t, :]
        Bt = B_heads[:, t, :, :]
        Ct = C_heads[:, t, :, :]

        # Discretized update: state <- exp(dt*A) * state + (dt*B) x_t
        dA = torch.exp(dtt * A32.view(1, -1))
        dB = dtt.unsqueeze(-1) * Bt

        state = dA.unsqueeze(-1).unsqueeze(-1) * state + torch.einsum("bhn,bhp->bhpn", dB, xt)
        # Readout through C plus the D skip connection.
        yt = torch.einsum("bhpn,bhn->bhp", state, Ct) + D32.view(1, -1, 1) * xt

        if z is not None:
            # SiLU-gated output, matching the fused kernel's behaviour.
            yt = yt * F.silu(z[:, t, :, :].to(torch.float32))

        ys.append(yt)

    y = torch.stack(ys, dim=1).to(x.dtype)
    if return_final_states:
        return y, state.to(x.dtype)
    return y

128 

129 

# If the fused Triton kernel from mamba-ssm could not be imported, bind the
# name to the pure-PyTorch sequential reference implementation defined above
# so the rest of the file can call it unconditionally.
if mamba_chunk_scan_combined is None:
    mamba_chunk_scan_combined = _mamba_chunk_scan_combined_fallback

132 

133# ========================================== 

134# 1. Custom Triton Kernel for 1.58-bit Weights 

135# ========================================== 

if HAS_TRITON:
    @triton.jit
    def _ternary_quant_kernel(w_ptr, output_ptr, n_elements, scale_ptr, BLOCK_SIZE: tl.constexpr):
        # Elementwise ternary quantization: out = clip(round(w / scale), -1, 1).
        # Each program instance handles one BLOCK_SIZE-wide slice of the flat tensor.
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        w = tl.load(w_ptr + offsets, mask=mask).to(tl.float32)
        # Load scalar scale from tensor pointer to avoid CPU sync
        scale = tl.load(scale_ptr).to(tl.float32)
        w_scaled = w / scale
        # Triton does not expose tl.math.round on all versions; emulate nearest-int rounding.
        # (Round half away from zero — matches the eager fallback in TernaryQuantizeSTE.)
        w_quant = tl.where(w_scaled >= 0, tl.floor(w_scaled + 0.5), tl.ceil(w_scaled - 0.5))
        w_quant = tl.maximum(w_quant, -1.0)
        w_quant = tl.minimum(w_quant, 1.0)
        tl.store(output_ptr + offsets, w_quant, mask=mask)

151 

class TernaryQuantizeSTE(torch.autograd.Function):
    """Ternary weight quantization with a straight-through estimator (STE).

    Forward maps weights to {-1, 0, +1} after dividing by the absmean scale;
    backward passes the incoming gradient through unchanged so the
    full-precision master weights keep training.
    """

    @staticmethod
    def forward(ctx, weight):
        # Absmean scale, floored to avoid division by zero on all-zero weights.
        scale = weight.abs().mean().clamp(min=1e-5)

        if HAS_TRITON and weight.is_cuda:
            # Fast path: fused Triton kernel. The scale tensor is handed over
            # as a pointer, so there is no .item() host sync (and therefore no
            # graph break under torch.compile).
            result = torch.empty_like(weight)
            numel = weight.numel()
            grid = lambda meta: (triton.cdiv(numel, meta['BLOCK_SIZE']),)
            _ternary_quant_kernel[grid](weight, result, numel, scale, BLOCK_SIZE=1024)
            return result

        # Portable path: round half away from zero, then clip to [-1, 1].
        scaled = weight / scale
        rounded = torch.where(scaled >= 0, torch.floor(scaled + 0.5), torch.ceil(scaled - 0.5))
        return torch.clamp(rounded, -1.0, 1.0)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: identity gradient w.r.t. the full-precision weight.
        return grad_output

170 

def weight_quant(weight):
    """Quantize *weight* to ternary {-1, 0, +1} with a straight-through gradient."""
    return TernaryQuantizeSTE.apply(weight)

173 

def activation_quant(x):
    """Fake-quantize activations to signed 8-bit, per-token (last dim), with STE.

    Scales each row so its absolute max maps to 127, rounds, clamps to the
    int8 range, and dequantizes back — so the forward value is the quantized
    activation while the backward pass sees the identity.

    Args:
        x: activations of any shape; quantization is per-slice along dim -1.

    Returns:
        Tensor of the same shape/dtype holding dequantized int8-grid values.
    """
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    # Compute x * scale once (the original recomputed it three times).
    x_scaled = x * scale
    quantized_x = torch.clamp(torch.round(x_scaled), -128, 127)
    # Straight-through estimator: detach() removes the rounding from the
    # autograd graph, leaving d(out)/d(x) == 1.
    return ((quantized_x - x_scaled).detach() + x_scaled) / scale

179 

class BitLinear(nn.Module):
    """Linear layer in the BitNet b1.58 style: ternary {-1, 0, +1} weights and
    8-bit fake-quantized activations, with a parameter-free LayerNorm applied
    to the input first.

    In eval mode the quantized weight is cached (non-persistent buffer) and
    reused until the parameter changes — staleness is detected via the cached
    tensor's device/dtype and the parameter's ``_version`` counter, which
    PyTorch increments on every in-place update (e.g. an optimizer step).
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Scaled-normal init (~ He/Glorot style 1/sqrt(fan_in)).
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * (1.0 / math.sqrt(in_features)))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)
        self.norm = nn.LayerNorm(in_features, elementwise_affine=False)
        # Derived value — excluded from checkpoints via persistent=False.
        self.register_buffer('_cached_quant_weight', None, persistent=False)
        self._cached_quant_weight_version = -1

    def _clear_inference_cache(self):
        # Drop the cached ternary weight (used when re-entering train mode).
        self._cached_quant_weight = None
        self._cached_quant_weight_version = -1

    def _get_quantized_weight(self):
        # Training: always re-quantize so gradients flow through the STE.
        if self.training:
            return weight_quant(self.weight)

        # Eval: rebuild the cache only when the parameter has changed.
        weight_version = getattr(self.weight, '_version', None)
        cache_is_stale = (
            self._cached_quant_weight is None
            or self._cached_quant_weight.device != self.weight.device
            or self._cached_quant_weight.dtype != self.weight.dtype
            or self._cached_quant_weight_version != weight_version
        )
        if cache_is_stale:
            with torch.no_grad():
                self._cached_quant_weight = weight_quant(self.weight).detach()
            self._cached_quant_weight_version = weight_version
        return self._cached_quant_weight

    def prepare_for_inference(self):
        # Switch to eval mode and pre-warm the quantized-weight cache.
        self.eval()
        self._get_quantized_weight()
        return self

    def train(self, mode=True):
        # Entering train mode invalidates the inference cache.
        if mode:
            self._clear_inference_cache()
        return super().train(mode)

    def forward(self, x):
        # Normalize -> 8-bit activation fake-quant -> linear with ternary weights.
        x_norm = self.norm(x)
        x_quant = activation_quant(x_norm)
        w_quant = self._get_quantized_weight()
        return F.linear(x_quant, w_quant, self.bias)

230 

class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Learnable per-channel gain, initialised to the identity.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Mean of squares over the feature (last) dimension.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        normed = x * torch.rsqrt(mean_sq + self.eps)
        return self.weight * normed

239 

240 

class AttentionBlock(nn.Module):
    """Lightweight attention block with GQA (4 KV heads) for hybrid architecture.

    As recommended in Nemotron-H §2.1, a small percentage of attention layers
    (evenly dispersed) alongside Mamba-2 layers improves retrieval-heavy tasks.

    Pre-norm residual block: output is ``hidden_states + o_proj(attn(...))``.
    All four projections are ternary BitLinear layers.
    """

    def __init__(self, dim, n_kv_heads=4):
        super().__init__()
        self.dim = dim
        self.n_heads = dim // 64  # head_dim = 64
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // self.n_heads

        # GQA: fewer KV heads than Q heads
        self.q_proj = BitLinear(dim, dim)
        self.k_proj = BitLinear(dim, self.n_kv_heads * self.head_dim)
        self.v_proj = BitLinear(dim, self.n_kv_heads * self.head_dim)
        self.o_proj = BitLinear(dim, dim)

        self.norm = RMSNorm(dim)

    def forward(self, hidden_states, seq_idx=None):
        """Causal self-attention over the full sequence.

        Args:
            hidden_states: (batch, seqlen, dim).
            seq_idx: optional (batch, seqlen) per-token document ids; when
                given, attention is restricted to each contiguous segment.
        """
        x = self.norm(hidden_states)

        # Project to Q, K, V
        q = self.q_proj(x).view(x.shape[0], x.shape[1], self.n_heads, self.head_dim)
        k = self.k_proj(x).view(x.shape[0], x.shape[1], self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(x.shape[0], x.shape[1], self.n_kv_heads, self.head_dim)

        # Expand K, V to all heads (GQA)
        if self.n_kv_heads < self.n_heads:
            repeat_factor = self.n_heads // self.n_kv_heads
            k = k.repeat_interleave(repeat_factor, dim=2)
            v = v.repeat_interleave(repeat_factor, dim=2)

        # Use PyTorch SDPA to avoid explicitly materializing a full causal mask.
        # Shapes for SDPA are (B, H, L, D).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if seq_idx is None:
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
            out = out.transpose(1, 2)
        else:
            # Enforce document boundaries from seq_idx by attending only within
            # each contiguous segment. This prevents cross-document leakage.
            batch_size, _, seqlen, _ = q.shape
            out = hidden_states.new_empty(batch_size, seqlen, self.n_heads, self.head_dim)
            for b in range(batch_size):
                seg = seq_idx[b]
                # Indices where the document id changes mark segment starts.
                changes = torch.nonzero(seg[1:] != seg[:-1], as_tuple=False).flatten() + 1
                boundaries = torch.cat([
                    torch.tensor([0], device=seg.device, dtype=torch.long),
                    changes.to(torch.long),
                    torch.tensor([seqlen], device=seg.device, dtype=torch.long),
                ])

                # Run SDPA independently on each [start, end) segment.
                for i in range(boundaries.numel() - 1):
                    start = boundaries[i].item()
                    end = boundaries[i + 1].item()
                    if end <= start:
                        continue
                    q_seg = q[b:b + 1, :, start:end, :]
                    k_seg = k[b:b + 1, :, start:end, :]
                    v_seg = v[b:b + 1, :, start:end, :]
                    out_seg = F.scaled_dot_product_attention(
                        q_seg, k_seg, v_seg, attn_mask=None, dropout_p=0.0, is_causal=True
                    )
                    out[b:b + 1, start:end, :, :] = out_seg.transpose(1, 2)

        out = out.contiguous().view(x.shape[0], x.shape[1], self.dim)

        return hidden_states + self.o_proj(out)

    def prefill(self, hidden_states):
        """Full-sequence attention, returns output and KV cache for decoding."""
        x = self.norm(hidden_states)
        bsz, seqlen, _ = x.shape

        q = self.q_proj(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim)

        if self.n_kv_heads < self.n_heads:
            repeat_factor = self.n_heads // self.n_kv_heads
            k_exp = k.repeat_interleave(repeat_factor, dim=2)
            v_exp = v.repeat_interleave(repeat_factor, dim=2)
        else:
            k_exp, v_exp = k, v

        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k_exp.transpose(1, 2), v_exp.transpose(1, 2),
            is_causal=True,
        ).transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        # Cache the *unexpanded* (n_kv_heads) K/V to keep the cache small.
        cache = {"k": k, "v": v}
        return hidden_states + self.o_proj(out), cache

    def step(self, hidden_states, cache):
        """Single-token decode step with KV cache."""
        x = self.norm(hidden_states)
        bsz = x.shape[0]

        q = self.q_proj(x).view(bsz, 1, self.n_heads, self.head_dim)
        k_new = self.k_proj(x).view(bsz, 1, self.n_kv_heads, self.head_dim)
        v_new = self.v_proj(x).view(bsz, 1, self.n_kv_heads, self.head_dim)

        # Append this token's K/V to the cache (grows along the length dim).
        cache["k"] = torch.cat([cache["k"], k_new], dim=1)
        cache["v"] = torch.cat([cache["v"], v_new], dim=1)

        k, v = cache["k"], cache["v"]
        if self.n_kv_heads < self.n_heads:
            repeat_factor = self.n_heads // self.n_kv_heads
            k = k.repeat_interleave(repeat_factor, dim=2)
            v = v.repeat_interleave(repeat_factor, dim=2)

        # Single query attending to all cached keys — no causal mask needed
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            is_causal=False,
        ).transpose(1, 2).contiguous().view(bsz, 1, self.dim)

        return hidden_states + self.o_proj(out)

366 

367# ========================================== 

368# 2. BitMamba Block & Architecture 

369# ========================================== 

class BitMambaBlock(nn.Module):
    """Mamba-2 SSD block with 1.58-bit (ternary) heavy projections.

    Architecture follows the Mamba-2 head-based formulation from the SSD paper
    (Dao & Gu, 2024) and the Nemotron-H design (§2.1). Key differences from
    Mamba-1:
    - A is a scalar per head (not a diagonal matrix per state)
    - B, C are shared across heads via ngroups (analogous to GQA)
    - No dt_rank decomposition — dt is produced directly per head
    - Uses the mamba_chunk_scan_combined Triton kernel with native seq_idx
    """

    def __init__(self, dim, d_state=128, d_conv=4, expand=2, headdim=64, ngroups=1, chunk_size=256):
        super().__init__()
        self.dim = dim
        self.d_state = d_state
        self.d_conv = d_conv
        self.headdim = headdim
        self.d_inner = int(expand * dim)
        self.nheads = self.d_inner // headdim
        self.ngroups = ngroups
        self.chunk_size = chunk_size

        assert self.d_inner % headdim == 0, f"d_inner ({self.d_inner}) must be divisible by headdim ({headdim})"

        self.norm = RMSNorm(dim)

        # 1.58-bit Heavy Projections (Saves VRAM!)
        # in_proj: produces [z, x, B, C, dt] in one shot
        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
        self.in_proj = BitLinear(self.dim, d_in_proj, bias=False)
        self.out_proj = BitLinear(self.d_inner, self.dim, bias=False)

        # FP16/FP32 Recurrent Core (Maintains Stability!)
        # Depthwise causal conv over the concatenated [x, B, C] channels.
        conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim, out_channels=conv_dim, bias=True,
            kernel_size=d_conv, groups=conv_dim, padding=d_conv - 1
        )

        # Mamba-2: A is a scalar per head (stored in log-space for stability)
        A = torch.arange(1, self.nheads + 1, dtype=torch.float32)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.nheads, dtype=torch.float32))
        self.dt_bias = nn.Parameter(torch.empty(self.nheads))

        # Initialize dt_bias so initial dt after softplus is in [dt_min, dt_max]
        dt_min, dt_max = 0.001, 0.1
        dt_init_floor = 1e-4
        # Log-uniform sample of the target dt in [dt_min, dt_max].
        inv_dt = torch.exp(torch.rand(self.nheads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
        inv_dt = torch.clamp(inv_dt, min=dt_init_floor)
        # Inverse of softplus: x = log(exp(val) - 1)
        self.dt_bias.data = inv_dt + torch.log(-torch.expm1(-inv_dt))

        # Output gate normalization (as used in Mamba-2 reference)
        self.out_norm = RMSNorm(self.d_inner)

    def forward(self, hidden_states, seq_idx=None):
        """Parallel (chunk-scan) forward over the full sequence.

        Args:
            hidden_states: (batch, seqlen, dim).
            seq_idx: optional (batch, seqlen) document ids, forwarded to the
                scan so the SSM state resets at document boundaries.
        """
        batch, seqlen, _ = hidden_states.shape
        h = self.norm(hidden_states)

        # Single projection produces z, x, B, C, dt
        zxbcdt = self.in_proj(h)

        # Split: z (gate) | xBC (conv input) | dt (timestep)
        z, xBC, dt = torch.split(
            zxbcdt,
            [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads],
            dim=-1
        )

        # Causal conv1d on x, B, C (not z — z is the gate)
        if causal_conv1d_fn is not None and xBC.is_cuda:
            xBC = causal_conv1d_fn(
                xBC.transpose(1, 2),  # (B, conv_dim, L)
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                activation="silu",
            ).transpose(1, 2)  # back to (B, L, conv_dim)
        else:
            # Padded Conv1d emits seqlen + d_conv - 1 steps; keep the causal prefix.
            xBC = self.conv1d(xBC.transpose(1, 2))[..., :seqlen].transpose(1, 2)
            xBC = F.silu(xBC)

        # Split convolved result into x, B, C
        x, B, C = torch.split(
            xBC,
            [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state],
            dim=-1
        )

        A = -torch.exp(self.A_log.float())  # (nheads,)

        # Reshape for Mamba-2 SSD kernel
        # x: (B, L, nheads, headdim), dt: (B, L, nheads)
        # B: (B, L, ngroups, d_state), C: (B, L, ngroups, d_state)
        y = mamba_chunk_scan_combined(
            rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
            dt,
            A,
            rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
            rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
            chunk_size=self.chunk_size,
            D=self.D,
            z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim),
            dt_bias=self.dt_bias,
            dt_softplus=True,
            seq_idx=seq_idx,
        )
        y = rearrange(y, "b l h p -> b l (h p)")

        # Output gate normalization + projection
        y = self.out_norm(y)
        out = self.out_proj(y)
        return hidden_states + out

    def prefill(self, hidden_states):
        """Full-sequence Mamba-2 scan, returns output and (conv_state, ssm_state) for decoding."""
        batch, seqlen, _ = hidden_states.shape
        h = self.norm(hidden_states)

        zxbcdt = self.in_proj(h)
        z, xBC, dt = torch.split(
            zxbcdt,
            [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads],
            dim=-1
        )

        # Save conv state: last d_conv inputs before convolution
        xBC_t = xBC.transpose(1, 2)  # (batch, conv_dim, seqlen)
        if seqlen >= self.d_conv:
            conv_state = xBC_t[:, :, -self.d_conv:].clone()
        else:
            # Left-pad with zeros so the state is always exactly d_conv wide.
            conv_state = F.pad(xBC_t, (self.d_conv - seqlen, 0)).clone()

        # Apply conv1d
        if causal_conv1d_fn is not None and xBC_t.is_cuda:
            xBC = causal_conv1d_fn(
                xBC_t,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                activation="silu",
            ).transpose(1, 2)
        else:
            xBC = self.conv1d(xBC_t)[..., :seqlen].transpose(1, 2)
            xBC = F.silu(xBC)

        x, B, C = torch.split(
            xBC,
            [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state],
            dim=-1
        )

        A = -torch.exp(self.A_log.float())

        y, ssm_state = mamba_chunk_scan_combined(
            rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
            dt, A,
            rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
            rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
            chunk_size=self.chunk_size,
            D=self.D,
            z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim),
            dt_bias=self.dt_bias,
            dt_softplus=True,
            return_final_states=True,
        )
        y = rearrange(y, "b l h p -> b l (h p)")
        y = self.out_norm(y)
        out = self.out_proj(y)

        cache = {"conv_state": conv_state, "ssm_state": ssm_state}
        return hidden_states + out, cache

    def step(self, hidden_states, cache):
        """Single-token Mamba-2 step with cached conv + SSM state.

        Uses official causal_conv1d_update and selective_state_update Triton
        kernels when available (matching the Mamba-2 reference implementation).
        The SSD chunk-scan kernel and sequential step use algebraically
        equivalent but numerically distinct accumulation orders, so step
        outputs will have small relative differences (~0.5%) vs a full
        forward pass. This is inherent to the SSD formulation and matches
        the upstream mamba-ssm behaviour.
        """
        dtype = hidden_states.dtype
        h = self.norm(hidden_states)  # (bsz, 1, dim)

        zxbcdt = self.in_proj(h).squeeze(1)  # (bsz, d_in_proj)
        z, xBC, dt = torch.split(
            zxbcdt,
            [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads],
            dim=-1
        )

        conv_state = cache["conv_state"]

        if causal_conv1d_update is not None and xBC.is_cuda:
            # Kernel updates conv_state in place and returns the conv output.
            xBC = causal_conv1d_update(
                xBC,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                activation="silu",
            )
        else:
            # Shift the rolling window left by one and append the new input.
            conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
            conv_state[:, :, -1] = xBC
            cache["conv_state"] = conv_state
            # Depthwise conv at a single position = dot product over the window.
            xBC = (conv_state * self.conv1d.weight.squeeze(1)).sum(-1) + self.conv1d.bias
            xBC = F.silu(xBC)

        x, B, C = torch.split(
            xBC,
            [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state],
            dim=-1
        )

        A = -torch.exp(self.A_log.float())
        ssm_state = cache["ssm_state"]  # (bsz, nheads, headdim, d_state)

        if selective_state_update is not None and ssm_state.is_cuda:
            x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
            z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
            B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
            C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
            # selective_state_update expects per-element A/dt/D with stride-0
            # broadcasting along headdim (detected via tie_hdim optimisation).
            A_ssm = A.view(self.nheads, 1, 1).expand(self.nheads, self.headdim, self.d_state)
            dt_ssm = dt.unsqueeze(-1).expand(-1, -1, self.headdim)
            dt_bias_ssm = self.dt_bias.unsqueeze(-1).expand(-1, self.headdim)
            D_ssm = self.D.unsqueeze(-1).expand(-1, self.headdim)
            y = selective_state_update(
                ssm_state, x, dt_ssm, A_ssm, B, C,
                D=D_ssm, z=z,
                dt_bias=dt_bias_ssm, dt_softplus=True,
            )
            y = rearrange(y, "b h p -> b (h p)")
        else:
            # Pure-PyTorch single-step SSM recurrence, computed in float32.
            x = rearrange(x, "b (h p) -> b h p", p=self.headdim).float()
            B = rearrange(B, "b (g n) -> b g n", g=self.ngroups).float()
            C = rearrange(C, "b (g n) -> b g n", g=self.ngroups).float()
            heads_per_group = self.nheads // self.ngroups
            B = B.repeat_interleave(heads_per_group, dim=1)
            C = C.repeat_interleave(heads_per_group, dim=1)

            dt_act = F.softplus(dt.float() + self.dt_bias.float())
            dA = torch.exp(dt_act * A)
            dB = dt_act.unsqueeze(-1) * B

            # state <- exp(dt*A) * state + (dt*B) x_t
            ssm_state = dA.unsqueeze(-1).unsqueeze(-1) * ssm_state + torch.einsum("bhn,bhp->bhpn", dB, x)
            cache["ssm_state"] = ssm_state

            # Readout through C, D skip connection, then the SiLU(z) gate.
            y = torch.einsum("bhpn,bhn->bhp", ssm_state, C)
            y = y + self.D.float().unsqueeze(-1) * x
            z = rearrange(z, "b (h p) -> b h p", p=self.headdim).float()
            y = y * F.silu(z)
            y = rearrange(y, "b h p -> b (h p)")

        y = y.unsqueeze(1).to(dtype)  # (bsz, 1, d_inner)
        y = self.out_norm(y)
        out = self.out_proj(y)

        return hidden_states + out

633 

class BitMambaLLM(nn.Module):
    """Hybrid BitMamba-2 language model.

    A stack of BitMambaBlock layers with, optionally, a small fraction of
    AttentionBlock layers evenly dispersed through the stack (Nemotron-H
    style), plus tied-free embedding, final RMSNorm, and a BitLinear head.
    """

    def __init__(self, vocab_size=64000, dim=1024, n_layers=40, d_state=128, expand=2,
                 headdim=64, ngroups=1, chunk_size=256, use_checkpoint=False,
                 use_attn=False, attn_pct=0.08):
        super().__init__()
        self.vocab_size = vocab_size
        self.use_checkpoint = use_checkpoint
        self.use_attn = use_attn
        self.attn_pct = attn_pct

        # Compute attention layer indices (evenly dispersed as in Nemotron-H)
        if use_attn and attn_pct > 0:
            n_attn = max(1, int(n_layers * attn_pct))
            step = n_layers // n_attn
            # Offset by step//2 so attention layers sit mid-stride, not at layer 0.
            self.attn_indices = set(range(step // 2, n_layers, step))
        else:
            self.attn_indices = set()

        self.tok_embeddings = nn.Embedding(vocab_size, dim)

        self.layers = nn.ModuleList()
        for i in range(n_layers):
            if i in self.attn_indices:
                # Lightweight attention layer (GQA with 4 KV heads)
                self.layers.append(AttentionBlock(dim, n_kv_heads=4))
            else:
                self.layers.append(BitMambaBlock(
                    dim, d_state=d_state, expand=expand,
                    headdim=headdim, ngroups=ngroups, chunk_size=chunk_size,
                ))

        self.norm = RMSNorm(dim)
        self.output = BitLinear(dim, vocab_size, bias=False)

    def _backbone(self, input_ids, seq_idx=None):
        """Shared backbone: embedding → layers → norm. Used by both forward and forward_hidden."""
        from torch.utils.checkpoint import checkpoint
        x = self.tok_embeddings(input_ids)
        for layer in self.layers:
            if self.use_checkpoint and self.training:
                # Non-reentrant activation checkpointing to trade compute for VRAM.
                x = checkpoint(layer, x, seq_idx, use_reentrant=False)
            else:
                x = layer(x, seq_idx=seq_idx)
        return self.norm(x)

    def forward(self, input_ids, seq_idx=None, targets=None):
        """Return logits, or (loss_sum, valid_tokens) when targets are given."""
        hidden_states = self._backbone(input_ids, seq_idx)

        if targets is not None:
            # Chunked CE avoids materializing the full (B, L, vocab) logits.
            return chunked_cross_entropy(hidden_states, self.output, targets, return_stats=True)

        return self.output(hidden_states)

    def forward_hidden(self, input_ids, seq_idx=None):
        """Return pre-logit hidden states (G2: for chunked cross-entropy)."""
        return self._backbone(input_ids, seq_idx)

    def prepare_for_inference(self):
        # Eval mode plus a warmed quantized-weight cache in every BitLinear.
        self.eval()
        for module in self.modules():
            if isinstance(module, BitLinear):
                module.prepare_for_inference()
        return self

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=512, temperature=0.7,
                 do_sample=True, eos_token_id=None):
        """O(n) autoregressive generation with cached SSM/KV states.

        Prefills the prompt in one pass, then decodes one token at a time
        using per-layer conv/SSM state (Mamba blocks) and KV cache (attention blocks).

        NOTE(review): sampling indexes ``logits[0]`` and calls
        ``next_token.item()``, so generation assumes batch size 1.
        """
        self.prepare_for_inference()

        # --- Prefill: process the full prompt ---
        x = self.tok_embeddings(input_ids)
        caches = []
        for layer in self.layers:
            x, cache = layer.prefill(x)
            caches.append(cache)
        x = self.norm(x)
        logits = self.output(x[:, -1:, :])  # (B, 1, vocab)

        # --- Decode: one token at a time ---
        generated = input_ids
        for _ in range(max_new_tokens):
            if do_sample and temperature > 0:
                next_token = torch.multinomial(
                    F.softmax(logits[0, -1, :] / temperature, dim=-1), num_samples=1
                )
            else:
                next_token = logits[0, -1, :].argmax(dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token.unsqueeze(0)], dim=-1)

            if eos_token_id is not None and next_token.item() == eos_token_id:
                break

            # Single-token forward through cached layers
            x = self.tok_embeddings(next_token.unsqueeze(0))  # (1, 1, dim)
            for i, layer in enumerate(self.layers):
                x = layer.step(x, caches[i])
            x = self.norm(x)
            logits = self.output(x)  # (1, 1, vocab)

        return generated

740 

741 

def chunked_cross_entropy(hidden, output_proj, targets, chunk_size=1024, ignore_index=-100, return_stats=False):
    """Cross-entropy over a long sequence without materializing all logits.

    The output projection is applied to ``chunk_size`` tokens at a time along
    the sequence dimension, so only one chunk of (BS, chunk, vocab) logits is
    live at once instead of the full (BS, seq_len, vocab) tensor.

    Args:
        hidden: (BS, seq_len, dim) pre-logit hidden states.
        output_proj: the output head module (e.g. ``model.output``).
        targets: (BS, seq_len) target token ids.
        chunk_size: tokens per chunk along the sequence dimension.
        ignore_index: label value excluded from the loss.
        return_stats: when True return ``(loss_sum, valid_tokens)`` instead of
            the mean loss.

    Returns:
        Mean cross-entropy over all non-ignored tokens, or the
        ``(loss_sum, valid_tokens)`` pair when ``return_stats`` is set.
    """
    _, seq_len, _ = hidden.shape
    # Clamp to 1 so an all-ignored batch cannot divide by zero.
    valid_tokens = (targets != ignore_index).sum().clamp(min=1)

    loss_sum = 0.0
    for start in range(0, seq_len, chunk_size):
        stop = min(start + chunk_size, seq_len)
        logits = output_proj(hidden[:, start:stop, :])  # (BS, chunk, vocab)
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_targets = targets[:, start:stop].reshape(-1)
        # Sum-reduction per chunk; normalisation happens once at the end.
        loss_sum = loss_sum + F.cross_entropy(
            flat_logits,
            flat_targets,
            reduction='sum',
            ignore_index=ignore_index,
        )

    return (loss_sum, valid_tokens) if return_stats else loss_sum / valid_tokens

778