Coverage for gpu_smoke_test.py: 100%

17 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

# Copyright 2026 venim1103
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Quick GPU smoke test for BitMamba/Triton path.

Usage:
    python gpu_smoke_test.py
"""

20 

import torch

from model import BitMambaLLM

24 

25 

def main() -> int:
    """Build a small BitMambaLLM on the GPU and run one forward pass.

    Returns:
        0 when the forward pass completes, 1 when no CUDA device is available.
    """
    print(f"torch={torch.__version__}, cuda_compiled={torch.version.cuda}")

    # Guard clause: nothing to smoke-test without a CUDA device.
    if not torch.cuda.is_available():
        print("FAIL: torch.cuda.is_available() is False")
        return 1

    device = "cuda"
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")

    # Single shared vocab size: used for both the model and the random token draw.
    vocab_size = 64000
    net = BitMambaLLM(
        vocab_size=vocab_size,
        dim=512,
        n_layers=24,
        d_state=64,
        expand=2,
        use_attn=True,
    ).to(device)
    net.eval()

    # One tiny batch of random token ids: batch=1, seq_len=32.
    x = torch.randint(0, vocab_size, (1, 32), device=device)

    # Inference only — no autograd bookkeeping needed for a smoke test.
    with torch.no_grad():
        y = net(x)

    print("PASS")
    print(f"input_shape={tuple(x.shape)} output_shape={tuple(y.shape)} device={y.device}")
    return 0

54 

55 

if __name__ == "__main__":
    # Hand main()'s status code to the shell via SystemExit.
    status = main()
    raise SystemExit(status)