Coverage for inference.py: 100%

42 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

1# Copyright 2026 venim1103 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import os 

16import torch 

17from transformers import AutoTokenizer 

18from context_config import CONTEXT_LENGTH 

19from model import BitMambaLLM, maybe_autocast 

20 

21DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 

22 

def resolve_model_settings(mode: str):
    """Translate a run mode into its (model config, checkpoint path) pair.

    Args:
        mode: "scout" or "parent" (case-insensitive).

    Returns:
        Tuple of (kwargs dict for BitMambaLLM, checkpoint file path).

    Raises:
        ValueError: if `mode` is not a recognized mode name.
    """
    presets = {
        "scout": (
            dict(vocab_size=64000, dim=512, n_layers=24, d_state=64, expand=2),
            "checkpoints/bitmamba_scout/step_100000.pt",
        ),
        "parent": (
            dict(vocab_size=64000, dim=1024, n_layers=40, d_state=128, expand=2),
            "checkpoints/bitmamba_parent/step_1000000.pt",
        ),
    }
    key = mode.lower()
    if key not in presets:
        raise ValueError(f"Unsupported MODE '{key}'. Expected 'scout' or 'parent'.")
    return presets[key]

36 

37 

# Model size is selected at runtime via the MODE environment variable
# (default: "scout"); this resolves the matching config and checkpoint path
# once at import time so both `generate` callers and `main` agree on them.
MODE = os.environ.get("MODE", "scout")
MODEL_CONFIG, CKPT_PATH = resolve_model_settings(MODE)

40 

def generate(model, tokenizer, prompt, max_new_tokens=150, temperature=0.7):
    """Sample a completion for `prompt` with `model` and print it to stdout.

    Args:
        model: model exposing an HF-style `.generate()` method.
        tokenizer: tokenizer used both to encode `prompt` and decode the output;
            must be able to encode the "<|im_end|>" stop marker.
        prompt: text to continue.
        max_new_tokens: upper bound on the number of generated tokens.
        temperature: sampling temperature passed through to `.generate()`.
    """
    model.eval()
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)
    # NOTE(review): assumes "<|im_end|>" encodes to exactly one token; taking
    # [0] would silently use only the first piece otherwise — confirm against
    # the custom tokenizer's vocabulary.
    eos_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]

    print(f"\nPrompt: {prompt}")
    print("Generating...", flush=True)

    # Fix: model.eval() does not disable gradient tracking, so sampling was
    # building an autograd graph for every generated token. inference_mode()
    # skips that bookkeeping entirely, reducing memory use during generation.
    with torch.inference_mode(), maybe_autocast(DEVICE):
        output_ids = model.generate(
            input_ids, max_new_tokens=max_new_tokens, temperature=temperature,
            do_sample=True, eos_token_id=eos_id
        )

    # Decode only the newly generated tail, keeping special tokens visible so
    # chat-template markers (e.g. <|im_end|>) show up in the printed output.
    generated_text = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=False)
    print(generated_text)
    print()

58 

def main() -> int:
    """Load the tokenizer and MODE-selected checkpoint, then run a smoke-test prompt.

    Returns:
        Process exit code: 0 on success, 1 if the checkpoint file is missing.
    """
    print("Loading custom tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("custom_agentic_tokenizer")
    tokenizer.model_max_length = CONTEXT_LENGTH

    print(f"Loading {MODE.upper()} BitMamba model from {CKPT_PATH}...")
    model = BitMambaLLM(**MODEL_CONFIG).to(DEVICE)

    try:
        # Security fix: restrict unpickling to tensors/primitive containers so a
        # tampered checkpoint cannot execute arbitrary code on load
        # (weights_only is available since torch 1.13 and is the default from 2.6).
        checkpoint = torch.load(CKPT_PATH, map_location=DEVICE, weights_only=True)
    except FileNotFoundError:
        print(f"Error: Could not find {CKPT_PATH}.")
        return 1
    # Only torch.load can raise FileNotFoundError, so state-dict application
    # now sits outside the try block.
    model.load_state_dict(checkpoint['model_state_dict'])
    model.prepare_for_inference()

    print("\n--- Testing Model Logic ---")
    chat_prompt = "<|im_start|>system\nYou are a deductive reasoning agent. You must analyze the user's request step-by-step within <think> tags before acting.<|im_end|>\n<|im_start|>user\nIf I have 3 apples and eat 1, how many are left?<|im_end|>\n<|im_start|>assistant\n<think>\n"
    generate(model, tokenizer, prompt=chat_prompt, max_new_tokens=150)
    return 0

78 

79 

# Script entry point: propagate main()'s integer exit code to the shell.
if __name__ == "__main__":
    raise SystemExit(main())

82