Coverage for model.py: 86%
424 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 23:06 +0000
1# Copyright 2026 venim1103
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15import torch
16import torch.nn as nn
17import torch.nn.functional as F
18import math
19from contextlib import nullcontext
20try:
21 import triton
22 import triton.language as tl
23 HAS_TRITON = True
24except ImportError:
25 triton = None
26 tl = None
27 HAS_TRITON = False
29try:
30 from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
31except ImportError:
32 mamba_chunk_scan_combined = None
34try:
35 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
36except ImportError:
37 causal_conv1d_fn = None
38 causal_conv1d_update = None
40try:
41 from mamba_ssm.ops.triton.selective_state_update import selective_state_update
42except ImportError:
43 selective_state_update = None
45from einops import rearrange
def maybe_autocast(device=None, amp_dtype=None):
    """Build an AMP autocast context for CUDA targets, a no-op context otherwise.

    Args:
        device: ``None`` (auto-detect CUDA), a ``torch.device``, or a device
            string such as ``"cuda"`` / ``"cpu"``.
        amp_dtype: Preferred autocast dtype. ``None`` picks bf16 when the GPU
            supports it, else fp16; an explicit bf16 request silently degrades
            to fp16 on GPUs without bf16 support.

    Returns:
        ``torch.autocast`` for CUDA, otherwise ``contextlib.nullcontext``.
    """
    if isinstance(device, torch.device):
        device_type = device.type
    elif device is not None:
        device_type = str(device)
    else:
        device_type = "cuda" if torch.cuda.is_available() else "cpu"

    # Mixed precision only pays off on CUDA; everything else runs as-is.
    if device_type != "cuda":
        return nullcontext()

    if amp_dtype is None:
        amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    elif amp_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
        amp_dtype = torch.float16

    return torch.autocast(device_type="cuda", dtype=amp_dtype)
def _mamba_chunk_scan_combined_fallback(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    D,
    z=None,
    dt_bias=None,
    dt_softplus=True,
    seq_idx=None,
    return_final_states=False,
):
    """Pure-PyTorch sequential reference for ``mamba_chunk_scan_combined``.

    Implements the Mamba-2 SSD recurrence one timestep at a time, in float32:

        state_t = exp(dt_t * A) * state_{t-1} + dt_t * B_t * x_t
        y_t     = C_t · state_t + D * x_t        (gated by silu(z_t) if z given)

    Args:
        x: (batch, seqlen, nheads, headdim) input sequence.
        dt: (batch, seqlen, nheads) raw per-head timestep values.
        A: (nheads,) scalar-per-head decay (negative values expected).
        B, C: (batch, seqlen, ngroups, d_state) input/output projections
            shared across the heads of each group (GQA-style).
        chunk_size: ignored — this fallback is strictly sequential.
        D: (nheads,) skip-connection coefficients.
        z: optional (batch, seqlen, nheads, headdim) gate tensor.
        dt_bias: optional (nheads,) bias added to dt before the softplus.
        dt_softplus: apply softplus to dt (after the bias) when True.
        seq_idx: optional (batch, seqlen) document ids; the SSM state is reset
            at each id change so documents cannot leak into one another.
        return_final_states: when True, also return the final SSM state.

    Returns:
        y of shape (batch, seqlen, nheads, headdim) in x's dtype, or the tuple
        ``(y, final_state)`` when ``return_final_states`` is set.
    """
    del chunk_size
    bsz, seqlen, nheads, headdim = x.shape
    ngroups = B.shape[2]
    d_state = B.shape[3]

    # Broadcast grouped B/C to one copy per head (GQA-style sharing).
    heads_per_group = nheads // ngroups
    Bh = B.repeat_interleave(heads_per_group, dim=2).to(torch.float32)
    Ch = C.repeat_interleave(heads_per_group, dim=2).to(torch.float32)

    # Run the whole recurrence in float32 for numerical stability.
    x_f = x.to(torch.float32)
    dt_f = dt.to(torch.float32)
    A_f = A.to(torch.float32)
    D_f = D.to(torch.float32)

    if dt_bias is not None:
        dt_f = dt_f + dt_bias.to(torch.float32).view(1, 1, -1)
    if dt_softplus:
        dt_f = F.softplus(dt_f)

    state = x_f.new_zeros((bsz, nheads, headdim, d_state))
    outputs = []

    for t in range(seqlen):
        # Zero the state at document boundaries (prevents cross-doc leakage).
        if seq_idx is not None and t > 0:
            reset = (seq_idx[:, t] != seq_idx[:, t - 1]).view(bsz, 1, 1, 1)
            state = torch.where(reset, torch.zeros_like(state), state)

        x_t = x_f[:, t, :, :]
        dt_t = dt_f[:, t, :]
        B_t = Bh[:, t, :, :]
        C_t = Ch[:, t, :, :]

        # Discretized transition: dA decays the state, dB injects the input.
        dA = torch.exp(dt_t * A_f.view(1, -1))
        dB = dt_t.unsqueeze(-1) * B_t

        state = dA.unsqueeze(-1).unsqueeze(-1) * state + torch.einsum("bhn,bhp->bhpn", dB, x_t)
        y_t = torch.einsum("bhpn,bhn->bhp", state, C_t) + D_f.view(1, -1, 1) * x_t

        if z is not None:
            y_t = y_t * F.silu(z[:, t, :, :].to(torch.float32))

        outputs.append(y_t)

    y = torch.stack(outputs, dim=1).to(x.dtype)
    if return_final_states:
        return y, state.to(x.dtype)
    return y
# Fall back to the sequential PyTorch reference when the fused mamba-ssm
# Triton kernel could not be imported above.
if mamba_chunk_scan_combined is None:
    mamba_chunk_scan_combined = _mamba_chunk_scan_combined_fallback
133# ==========================================
134# 1. Custom Triton Kernel for 1.58-bit Weights
135# ==========================================
136if HAS_TRITON: 136 ↛ 152line 136 didn't jump to line 152 because the condition on line 136 was always true
137 @triton.jit
138 def _ternary_quant_kernel(w_ptr, output_ptr, n_elements, scale_ptr, BLOCK_SIZE: tl.constexpr):
139 pid = tl.program_id(axis=0)
140 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
141 mask = offsets < n_elements
142 w = tl.load(w_ptr + offsets, mask=mask).to(tl.float32)
143 # Load scalar scale from tensor pointer to avoid CPU sync
144 scale = tl.load(scale_ptr).to(tl.float32)
145 w_scaled = w / scale
146 # Triton does not expose tl.math.round on all versions; emulate nearest-int rounding.
147 w_quant = tl.where(w_scaled >= 0, tl.floor(w_scaled + 0.5), tl.ceil(w_scaled - 0.5))
148 w_quant = tl.maximum(w_quant, -1.0)
149 w_quant = tl.minimum(w_quant, 1.0)
150 tl.store(output_ptr + offsets, w_quant, mask=mask)
class TernaryQuantizeSTE(torch.autograd.Function):
    """Ternary (1.58-bit) weight quantization with a straight-through backward.

    Forward maps weights to ``scale * {-1, 0, +1}`` levels (the scale itself is
    applied by the caller); backward passes gradients through unchanged.
    """

    @staticmethod
    def forward(ctx, weight):
        # Per-tensor AbsMean scale, floored to avoid division by ~zero.
        scale = weight.abs().mean().clamp(min=1e-5)
        if HAS_TRITON and weight.is_cuda:
            quantized = torch.empty_like(weight)
            numel = weight.numel()
            grid = lambda meta: (triton.cdiv(numel, meta['BLOCK_SIZE']),)
            # Pass the scale tensor directly as a pointer! No .item() CPU sync means no graph breaks!
            _ternary_quant_kernel[grid](weight, quantized, numel, scale, BLOCK_SIZE=1024)
            return quantized
        # Eager fallback: round half away from zero, then clip to {-1, 0, +1}.
        scaled = weight / scale
        rounded = torch.where(scaled >= 0, (scaled + 0.5).floor(), (scaled - 0.5).ceil())
        return rounded.clamp(-1.0, 1.0)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: identity gradient.
        return grad_output
def weight_quant(weight):
    """Quantize *weight* to ternary {-1, 0, +1} levels with an STE backward."""
    return TernaryQuantizeSTE.apply(weight)
def activation_quant(x):
    """Per-token 8-bit fake quantization with a straight-through estimator.

    Scales each row (last dim) so its absolute maximum maps to 127, rounds to
    integers clamped to [-128, 127], then rescales back. The forward value is
    the quantized one; gradients flow through the unquantized path.
    """
    max_abs = x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    scale = 127.0 / max_abs
    scaled = x * scale
    q = torch.clamp(torch.round(scaled), -128, 127)
    # STE: forward sees q, backward sees scaled (the rounding error is detached).
    return (scaled + (q - scaled).detach()) / scale
class BitLinear(nn.Module):
    """Linear layer with ternary (1.58-bit) weights and 8-bit activations.

    Input is LayerNorm-ed (no affine), fake-quantized to 8 bits per token,
    then multiplied by the STE-quantized ternary weight. In eval mode the
    quantized weight is cached and reused until the parameter changes
    (tracked via the tensor's in-place ``_version`` counter, plus device and
    dtype checks in case the module was moved or cast).
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Plain Kaiming-style scaling: std = 1/sqrt(fan_in).
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * (1.0 / math.sqrt(in_features)))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)
        self.norm = nn.LayerNorm(in_features, elementwise_affine=False)
        # Non-persistent so the cache never leaks into checkpoints.
        self.register_buffer('_cached_quant_weight', None, persistent=False)
        self._cached_quant_weight_version = -1

    def _clear_inference_cache(self):
        """Drop the cached quantized weight (called when re-entering training)."""
        self._cached_quant_weight = None
        self._cached_quant_weight_version = -1

    def _get_quantized_weight(self):
        """Return the ternary weight; quantize fresh in training, cached in eval."""
        if self.training:
            return weight_quant(self.weight)

        # ``_version`` increments on in-place mutation, so optimizer updates
        # or manual edits invalidate the cache automatically.
        weight_version = getattr(self.weight, '_version', None)
        cache_is_stale = (
            self._cached_quant_weight is None
            or self._cached_quant_weight.device != self.weight.device
            or self._cached_quant_weight.dtype != self.weight.dtype
            or self._cached_quant_weight_version != weight_version
        )
        if cache_is_stale:
            with torch.no_grad():
                self._cached_quant_weight = weight_quant(self.weight).detach()
                self._cached_quant_weight_version = weight_version
        return self._cached_quant_weight

    def prepare_for_inference(self):
        """Switch to eval mode and eagerly populate the quantized-weight cache."""
        self.eval()
        self._get_quantized_weight()
        return self

    def train(self, mode=True):
        # Entering training invalidates the cache so gradients see fresh quantization.
        if mode:
            self._clear_inference_cache()
        return super().train(mode)

    def forward(self, x):
        x_norm = self.norm(x)
        x_quant = activation_quant(x_norm)
        w_quant = self._get_quantized_weight()
        return F.linear(x_quant, w_quant, self.bias)
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-channel gain."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize by the RMS of the last dimension, then apply the gain.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return (x * torch.rsqrt(mean_sq + self.eps)) * self.weight
class AttentionBlock(nn.Module):
    """Lightweight attention block with GQA (4 KV heads) for hybrid architecture.

    As recommended in Nemotron-H §2.1, a small percentage of attention layers
    (evenly dispersed) alongside Mamba-2 layers improves retrieval-heavy tasks.

    All projections are BitLinear (ternary). The block is residual:
    ``out = hidden_states + o_proj(attention(norm(hidden_states)))``.
    """

    def __init__(self, dim, n_kv_heads=4):
        super().__init__()
        self.dim = dim
        self.n_heads = dim // 64 # head_dim = 64
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // self.n_heads

        # GQA: fewer KV heads than Q heads
        self.q_proj = BitLinear(dim, dim)
        self.k_proj = BitLinear(dim, self.n_kv_heads * self.head_dim)
        self.v_proj = BitLinear(dim, self.n_kv_heads * self.head_dim)
        self.o_proj = BitLinear(dim, dim)

        self.norm = RMSNorm(dim)

    def forward(self, hidden_states, seq_idx=None):
        """Causal self-attention over the full sequence (training path).

        Args:
            hidden_states: (batch, seqlen, dim).
            seq_idx: optional (batch, seqlen) document ids; when given,
                attention is restricted to within each contiguous segment.
        """
        x = self.norm(hidden_states)

        # Project to Q, K, V
        q = self.q_proj(x).view(x.shape[0], x.shape[1], self.n_heads, self.head_dim)
        k = self.k_proj(x).view(x.shape[0], x.shape[1], self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(x.shape[0], x.shape[1], self.n_kv_heads, self.head_dim)

        # Expand K, V to all heads (GQA)
        if self.n_kv_heads < self.n_heads:
            repeat_factor = self.n_heads // self.n_kv_heads
            k = k.repeat_interleave(repeat_factor, dim=2)
            v = v.repeat_interleave(repeat_factor, dim=2)

        # Use PyTorch SDPA to avoid explicitly materializing a full causal mask.
        # Shapes for SDPA are (B, H, L, D).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if seq_idx is None:
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
            out = out.transpose(1, 2)
        else:
            # Enforce document boundaries from seq_idx by attending only within
            # each contiguous segment. This prevents cross-document leakage.
            batch_size, _, seqlen, _ = q.shape
            out = hidden_states.new_empty(batch_size, seqlen, self.n_heads, self.head_dim)
            for b in range(batch_size):
                seg = seq_idx[b]
                # Segment boundaries = positions where the document id changes,
                # bracketed by 0 and seqlen.
                changes = torch.nonzero(seg[1:] != seg[:-1], as_tuple=False).flatten() + 1
                boundaries = torch.cat([
                    torch.tensor([0], device=seg.device, dtype=torch.long),
                    changes.to(torch.long),
                    torch.tensor([seqlen], device=seg.device, dtype=torch.long),
                ])

                for i in range(boundaries.numel() - 1):
                    start = boundaries[i].item()
                    end = boundaries[i + 1].item()
                    if end <= start:
                        continue
                    q_seg = q[b:b + 1, :, start:end, :]
                    k_seg = k[b:b + 1, :, start:end, :]
                    v_seg = v[b:b + 1, :, start:end, :]
                    out_seg = F.scaled_dot_product_attention(
                        q_seg, k_seg, v_seg, attn_mask=None, dropout_p=0.0, is_causal=True
                    )
                    out[b:b + 1, start:end, :, :] = out_seg.transpose(1, 2)

        out = out.contiguous().view(x.shape[0], x.shape[1], self.dim)

        return hidden_states + self.o_proj(out)

    def prefill(self, hidden_states):
        """Full-sequence attention, returns output and KV cache for decoding."""
        x = self.norm(hidden_states)
        bsz, seqlen, _ = x.shape

        q = self.q_proj(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim)

        if self.n_kv_heads < self.n_heads:
            repeat_factor = self.n_heads // self.n_kv_heads
            k_exp = k.repeat_interleave(repeat_factor, dim=2)
            v_exp = v.repeat_interleave(repeat_factor, dim=2)
        else:
            k_exp, v_exp = k, v

        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k_exp.transpose(1, 2), v_exp.transpose(1, 2),
            is_causal=True,
        ).transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        # Cache the un-expanded KV heads; step() re-expands on each decode.
        cache = {"k": k, "v": v}
        return hidden_states + self.o_proj(out), cache

    def step(self, hidden_states, cache):
        """Single-token decode step with KV cache."""
        x = self.norm(hidden_states)
        bsz = x.shape[0]

        q = self.q_proj(x).view(bsz, 1, self.n_heads, self.head_dim)
        k_new = self.k_proj(x).view(bsz, 1, self.n_kv_heads, self.head_dim)
        v_new = self.v_proj(x).view(bsz, 1, self.n_kv_heads, self.head_dim)

        # Append the new token's KV to the cache (cache grows by 1 per step).
        cache["k"] = torch.cat([cache["k"], k_new], dim=1)
        cache["v"] = torch.cat([cache["v"], v_new], dim=1)

        k, v = cache["k"], cache["v"]
        if self.n_kv_heads < self.n_heads:
            repeat_factor = self.n_heads // self.n_kv_heads
            k = k.repeat_interleave(repeat_factor, dim=2)
            v = v.repeat_interleave(repeat_factor, dim=2)

        # Single query attending to all cached keys — no causal mask needed
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            is_causal=False,
        ).transpose(1, 2).contiguous().view(bsz, 1, self.dim)

        return hidden_states + self.o_proj(out)
367# ==========================================
368# 2. BitMamba Block & Architecture
369# ==========================================
class BitMambaBlock(nn.Module):
    """Mamba-2 SSD block with 1.58-bit (ternary) heavy projections.

    Architecture follows the Mamba-2 head-based formulation from the SSD paper
    (Dao & Gu, 2024) and the Nemotron-H design (§2.1). Key differences from
    Mamba-1:
    - A is a scalar per head (not a diagonal matrix per state)
    - B, C are shared across heads via ngroups (analogous to GQA)
    - No dt_rank decomposition — dt is produced directly per head
    - Uses the mamba_chunk_scan_combined Triton kernel with native seq_idx

    The heavy in/out projections are BitLinear (ternary), while the conv1d and
    SSM parameters (A_log, D, dt_bias) stay full precision for stability.
    """

    def __init__(self, dim, d_state=128, d_conv=4, expand=2, headdim=64, ngroups=1, chunk_size=256):
        super().__init__()
        self.dim = dim
        self.d_state = d_state
        self.d_conv = d_conv
        self.headdim = headdim
        self.d_inner = int(expand * dim)
        self.nheads = self.d_inner // headdim
        self.ngroups = ngroups
        self.chunk_size = chunk_size

        assert self.d_inner % headdim == 0, f"d_inner ({self.d_inner}) must be divisible by headdim ({headdim})"

        self.norm = RMSNorm(dim)

        # 1.58-bit Heavy Projections (Saves VRAM!)
        # in_proj: produces [z, x, B, C, dt] in one shot
        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
        self.in_proj = BitLinear(self.dim, d_in_proj, bias=False)
        self.out_proj = BitLinear(self.d_inner, self.dim, bias=False)

        # FP16/FP32 Recurrent Core (Maintains Stability!)
        # Depthwise causal conv over [x, B, C]; left padding keeps causality.
        conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim, out_channels=conv_dim, bias=True,
            kernel_size=d_conv, groups=conv_dim, padding=d_conv - 1
        )

        # Mamba-2: A is a scalar per head (stored in log-space for stability)
        A = torch.arange(1, self.nheads + 1, dtype=torch.float32)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.nheads, dtype=torch.float32))
        self.dt_bias = nn.Parameter(torch.empty(self.nheads))

        # Initialize dt_bias so initial dt after softplus is in [dt_min, dt_max]
        dt_min, dt_max = 0.001, 0.1
        dt_init_floor = 1e-4
        inv_dt = torch.exp(torch.rand(self.nheads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
        inv_dt = torch.clamp(inv_dt, min=dt_init_floor)
        # Inverse of softplus: x = log(exp(val) - 1)
        self.dt_bias.data = inv_dt + torch.log(-torch.expm1(-inv_dt))

        # Output gate normalization (as used in Mamba-2 reference)
        self.out_norm = RMSNorm(self.d_inner)

    def forward(self, hidden_states, seq_idx=None):
        """Residual Mamba-2 forward over a full sequence.

        Args:
            hidden_states: (batch, seqlen, dim).
            seq_idx: optional (batch, seqlen) document ids, forwarded to the
                scan kernel so the SSM state resets at document boundaries.
        """
        batch, seqlen, _ = hidden_states.shape
        h = self.norm(hidden_states)

        # Single projection produces z, x, B, C, dt
        zxbcdt = self.in_proj(h)

        # Split: z (gate) | xBC (conv input) | dt (timestep)
        z, xBC, dt = torch.split(
            zxbcdt,
            [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads],
            dim=-1
        )

        # Causal conv1d on x, B, C (not z — z is the gate)
        if causal_conv1d_fn is not None and xBC.is_cuda:
            xBC = causal_conv1d_fn(
                xBC.transpose(1, 2), # (B, conv_dim, L)
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                activation="silu",
            ).transpose(1, 2) # back to (B, L, conv_dim)
        else:
            # nn.Conv1d pads both sides; truncate to seqlen to stay causal.
            xBC = self.conv1d(xBC.transpose(1, 2))[..., :seqlen].transpose(1, 2)
            xBC = F.silu(xBC)

        # Split convolved result into x, B, C
        x, B, C = torch.split(
            xBC,
            [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state],
            dim=-1
        )

        A = -torch.exp(self.A_log.float()) # (nheads,)

        # Reshape for Mamba-2 SSD kernel
        # x: (B, L, nheads, headdim), dt: (B, L, nheads)
        # B: (B, L, ngroups, d_state), C: (B, L, ngroups, d_state)
        y = mamba_chunk_scan_combined(
            rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
            dt,
            A,
            rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
            rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
            chunk_size=self.chunk_size,
            D=self.D,
            z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim),
            dt_bias=self.dt_bias,
            dt_softplus=True,
            seq_idx=seq_idx,
        )
        y = rearrange(y, "b l h p -> b l (h p)")

        # Output gate normalization + projection
        y = self.out_norm(y)
        out = self.out_proj(y)
        return hidden_states + out

    def prefill(self, hidden_states):
        """Full-sequence Mamba-2 scan, returns output and (conv_state, ssm_state) for decoding."""
        batch, seqlen, _ = hidden_states.shape
        h = self.norm(hidden_states)

        zxbcdt = self.in_proj(h)
        z, xBC, dt = torch.split(
            zxbcdt,
            [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads],
            dim=-1
        )

        # Save conv state: last d_conv inputs before convolution
        xBC_t = xBC.transpose(1, 2) # (batch, conv_dim, seqlen)
        if seqlen >= self.d_conv:
            conv_state = xBC_t[:, :, -self.d_conv:].clone()
        else:
            # Prompt shorter than the kernel: left-pad with zeros.
            conv_state = F.pad(xBC_t, (self.d_conv - seqlen, 0)).clone()

        # Apply conv1d
        if causal_conv1d_fn is not None and xBC_t.is_cuda:
            xBC = causal_conv1d_fn(
                xBC_t,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                activation="silu",
            ).transpose(1, 2)
        else:
            xBC = self.conv1d(xBC_t)[..., :seqlen].transpose(1, 2)
            xBC = F.silu(xBC)

        x, B, C = torch.split(
            xBC,
            [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state],
            dim=-1
        )

        A = -torch.exp(self.A_log.float())

        # Same scan as forward(), but also return the final SSM state.
        y, ssm_state = mamba_chunk_scan_combined(
            rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
            dt, A,
            rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
            rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
            chunk_size=self.chunk_size,
            D=self.D,
            z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim),
            dt_bias=self.dt_bias,
            dt_softplus=True,
            return_final_states=True,
        )
        y = rearrange(y, "b l h p -> b l (h p)")
        y = self.out_norm(y)
        out = self.out_proj(y)

        cache = {"conv_state": conv_state, "ssm_state": ssm_state}
        return hidden_states + out, cache

    def step(self, hidden_states, cache):
        """Single-token Mamba-2 step with cached conv + SSM state.

        Uses official causal_conv1d_update and selective_state_update Triton
        kernels when available (matching the Mamba-2 reference implementation).
        The SSD chunk-scan kernel and sequential step use algebraically
        equivalent but numerically distinct accumulation orders, so step
        outputs will have small relative differences (~0.5%) vs a full
        forward pass. This is inherent to the SSD formulation and matches
        the upstream mamba-ssm behaviour.
        """
        dtype = hidden_states.dtype
        h = self.norm(hidden_states) # (bsz, 1, dim)

        zxbcdt = self.in_proj(h).squeeze(1) # (bsz, d_in_proj)
        z, xBC, dt = torch.split(
            zxbcdt,
            [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads],
            dim=-1
        )

        conv_state = cache["conv_state"]

        if causal_conv1d_update is not None and xBC.is_cuda:
            # NOTE(review): the fused kernel appears to advance conv_state in
            # place (the fallback branch below writes it back explicitly) —
            # confirm against the installed causal_conv1d version.
            xBC = causal_conv1d_update(
                xBC,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                activation="silu",
            )
        else:
            # Shift the rolling window left by one and append the new input.
            conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
            conv_state[:, :, -1] = xBC
            cache["conv_state"] = conv_state
            # Depthwise conv as a dot product over the cached window.
            xBC = (conv_state * self.conv1d.weight.squeeze(1)).sum(-1) + self.conv1d.bias
            xBC = F.silu(xBC)

        x, B, C = torch.split(
            xBC,
            [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state],
            dim=-1
        )

        A = -torch.exp(self.A_log.float())
        ssm_state = cache["ssm_state"] # (bsz, nheads, headdim, d_state)

        if selective_state_update is not None and ssm_state.is_cuda:
            x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
            z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
            B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
            C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
            # selective_state_update expects per-element A/dt/D with stride-0
            # broadcasting along headdim (detected via tie_hdim optimisation).
            A_ssm = A.view(self.nheads, 1, 1).expand(self.nheads, self.headdim, self.d_state)
            dt_ssm = dt.unsqueeze(-1).expand(-1, -1, self.headdim)
            dt_bias_ssm = self.dt_bias.unsqueeze(-1).expand(-1, self.headdim)
            D_ssm = self.D.unsqueeze(-1).expand(-1, self.headdim)
            y = selective_state_update(
                ssm_state, x, dt_ssm, A_ssm, B, C,
                D=D_ssm, z=z,
                dt_bias=dt_bias_ssm, dt_softplus=True,
            )
            y = rearrange(y, "b h p -> b (h p)")
        else:
            # Pure-PyTorch single-step recurrence (float32 for stability),
            # mirroring _mamba_chunk_scan_combined_fallback for one timestep.
            x = rearrange(x, "b (h p) -> b h p", p=self.headdim).float()
            B = rearrange(B, "b (g n) -> b g n", g=self.ngroups).float()
            C = rearrange(C, "b (g n) -> b g n", g=self.ngroups).float()
            heads_per_group = self.nheads // self.ngroups
            B = B.repeat_interleave(heads_per_group, dim=1)
            C = C.repeat_interleave(heads_per_group, dim=1)

            dt_act = F.softplus(dt.float() + self.dt_bias.float())
            dA = torch.exp(dt_act * A)
            dB = dt_act.unsqueeze(-1) * B

            ssm_state = dA.unsqueeze(-1).unsqueeze(-1) * ssm_state + torch.einsum("bhn,bhp->bhpn", dB, x)
            cache["ssm_state"] = ssm_state

            y = torch.einsum("bhpn,bhn->bhp", ssm_state, C)
            y = y + self.D.float().unsqueeze(-1) * x
            z = rearrange(z, "b (h p) -> b h p", p=self.headdim).float()
            y = y * F.silu(z)
            y = rearrange(y, "b h p -> b (h p)")

        y = y.unsqueeze(1).to(dtype) # (bsz, 1, d_inner)
        y = self.out_norm(y)
        out = self.out_proj(y)

        return hidden_states + out
class BitMambaLLM(nn.Module):
    """Hybrid BitMamba language model: mostly Mamba-2 blocks, optionally with
    a small fraction of GQA attention blocks evenly dispersed (Nemotron-H style).

    All heavy projections (including the LM head) are ternary BitLinear layers.
    """

    def __init__(self, vocab_size=64000, dim=1024, n_layers=40, d_state=128, expand=2,
                 headdim=64, ngroups=1, chunk_size=256, use_checkpoint=False,
                 use_attn=False, attn_pct=0.08):
        super().__init__()
        self.vocab_size = vocab_size
        self.use_checkpoint = use_checkpoint
        self.use_attn = use_attn
        self.attn_pct = attn_pct

        # Compute attention layer indices (evenly dispersed as in Nemotron-H)
        if use_attn and attn_pct > 0:
            n_attn = max(1, int(n_layers * attn_pct))
            step = n_layers // n_attn
            self.attn_indices = set(range(step // 2, n_layers, step))
        else:
            self.attn_indices = set()

        self.tok_embeddings = nn.Embedding(vocab_size, dim)

        self.layers = nn.ModuleList()
        for i in range(n_layers):
            if i in self.attn_indices:
                # Lightweight attention layer (GQA with 4 KV heads)
                self.layers.append(AttentionBlock(dim, n_kv_heads=4))
            else:
                self.layers.append(BitMambaBlock(
                    dim, d_state=d_state, expand=expand,
                    headdim=headdim, ngroups=ngroups, chunk_size=chunk_size,
                ))

        self.norm = RMSNorm(dim)
        self.output = BitLinear(dim, vocab_size, bias=False)

    def _backbone(self, input_ids, seq_idx=None):
        """Shared backbone: embedding → layers → norm. Used by both forward and forward_hidden."""
        from torch.utils.checkpoint import checkpoint
        x = self.tok_embeddings(input_ids)
        for layer in self.layers:
            # Gradient checkpointing only during training (no-op in eval).
            if self.use_checkpoint and self.training:
                x = checkpoint(layer, x, seq_idx, use_reentrant=False)
            else:
                x = layer(x, seq_idx=seq_idx)
        return self.norm(x)

    def forward(self, input_ids, seq_idx=None, targets=None):
        """Return logits, or ``(loss_sum, valid_tokens)`` when targets are given.

        With targets, the full logits tensor is never materialized — the loss
        is computed chunk-wise via chunked_cross_entropy.
        """
        hidden_states = self._backbone(input_ids, seq_idx)

        if targets is not None:
            return chunked_cross_entropy(hidden_states, self.output, targets, return_stats=True)

        return self.output(hidden_states)

    def forward_hidden(self, input_ids, seq_idx=None):
        """Return pre-logit hidden states (G2: for chunked cross-entropy)."""
        return self._backbone(input_ids, seq_idx)

    def prepare_for_inference(self):
        """Switch to eval mode and pre-quantize every BitLinear weight."""
        self.eval()
        for module in self.modules():
            if isinstance(module, BitLinear):
                module.prepare_for_inference()
        return self

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=512, temperature=0.7,
                 do_sample=True, eos_token_id=None):
        """O(n) autoregressive generation with cached SSM/KV states.

        Prefills the prompt in one pass, then decodes one token at a time
        using per-layer conv/SSM state (Mamba blocks) and KV cache (attention blocks).

        NOTE(review): the decode loop indexes ``logits[0, -1, :]`` and
        re-batches with ``unsqueeze(0)``, so generation effectively assumes
        batch size 1 even though prefill accepts a batch — confirm callers
        only pass a single prompt.
        """
        self.prepare_for_inference()

        # --- Prefill: process the full prompt ---
        x = self.tok_embeddings(input_ids)
        caches = []
        for layer in self.layers:
            x, cache = layer.prefill(x)
            caches.append(cache)
        x = self.norm(x)
        logits = self.output(x[:, -1:, :]) # (B, 1, vocab)

        # --- Decode: one token at a time ---
        generated = input_ids
        for _ in range(max_new_tokens):
            if do_sample and temperature > 0:
                next_token = torch.multinomial(
                    F.softmax(logits[0, -1, :] / temperature, dim=-1), num_samples=1
                )
            else:
                next_token = logits[0, -1, :].argmax(dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token.unsqueeze(0)], dim=-1)

            if eos_token_id is not None and next_token.item() == eos_token_id:
                break

            # Single-token forward through cached layers
            x = self.tok_embeddings(next_token.unsqueeze(0)) # (1, 1, dim)
            for i, layer in enumerate(self.layers):
                x = layer.step(x, caches[i])
            x = self.norm(x)
            logits = self.output(x) # (1, 1, vocab)

        return generated
def chunked_cross_entropy(hidden, output_proj, targets, chunk_size=1024, ignore_index=-100, return_stats=False):
    """Compute cross-entropy without materializing the full logits tensor.

    Instead of projecting the entire sequence to vocabulary logits at once
    (which for [BS, 16384, 64000] costs ~3.9 GB), the output projection is
    applied to ``chunk_size``-token slices along the sequence dimension and
    the per-token losses are summed as we go.

    Args:
        hidden: (BS, seq_len, dim) — pre-logit hidden states from forward_hidden
        output_proj: nn.Module — the output head (model.output)
        targets: (BS, seq_len) — target token ids
        chunk_size: int — number of tokens per chunk (default 1024)
        ignore_index: int — label to ignore in loss (default -100)

    Returns:
        Scalar loss (mean cross-entropy over all tokens), or a tuple of
        ``(total_loss_sum, valid_tokens)`` when ``return_stats=True``.
    """
    _, seq_len, _ = hidden.shape
    # clamp(min=1) guards against division by zero when every label is ignored.
    valid_tokens = (targets != ignore_index).sum().clamp(min=1)

    loss_sum = 0.0
    for start in range(0, seq_len, chunk_size):
        stop = min(start + chunk_size, seq_len)
        chunk_logits = output_proj(hidden[:, start:stop, :])  # (BS, chunk, vocab)
        flat_targets = targets[:, start:stop].reshape(-1)     # (BS * chunk,)
        loss_sum = loss_sum + F.cross_entropy(
            chunk_logits.reshape(-1, chunk_logits.size(-1)),
            flat_targets,
            reduction='sum',
            ignore_index=ignore_index,
        )

    return (loss_sum, valid_tokens) if return_stats else loss_sum / valid_tokens