Coverage for gpu_smoke_test.py: 100%
17 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 23:06 +0000
1# Copyright 2026 venim1103
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15"""Quick GPU smoke test for BitMamba/Triton path.
17Usage:
18 python gpu_smoke_test.py
19"""
21import torch
23from model import BitMambaLLM
def main() -> int:
    """Run a minimal BitMamba forward pass on the GPU and report the result.

    Prints the torch/CUDA versions, builds a small BitMambaLLM, and pushes a
    random token batch through it once under ``no_grad``.

    Returns:
        int: 0 if the forward pass succeeds, 1 if CUDA is unavailable.
    """
    # Single source of truth: the model's vocab size and the randint upper
    # bound must agree, otherwise the input could contain out-of-range ids.
    vocab_size = 64000

    print(f"torch={torch.__version__}, cuda_compiled={torch.version.cuda}")

    if not torch.cuda.is_available():
        print("FAIL: torch.cuda.is_available() is False")
        return 1

    device = "cuda"
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")

    model = BitMambaLLM(
        vocab_size=vocab_size,
        dim=512,
        n_layers=24,
        d_state=64,
        expand=2,
        use_attn=True,
    ).to(device)
    model.eval()  # inference behavior (e.g. disable dropout, if any)

    # Batch of 1, sequence length 32, token ids in [0, vocab_size).
    x = torch.randint(0, vocab_size, (1, 32), device=device)

    with torch.no_grad():  # smoke test only — no gradients needed
        y = model(x)

    print("PASS")
    print(f"input_shape={tuple(x.shape)} output_shape={tuple(y.shape)} device={y.device}")
    return 0
if __name__ == "__main__":
    # Propagate the smoke test's status code to the shell.
    exit_code = main()
    raise SystemExit(exit_code)