Coverage for inference.py: 100%
42 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 23:06 +0000
1# Copyright 2026 venim1103
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15import os
16import torch
17from transformers import AutoTokenizer
18from context_config import CONTEXT_LENGTH
19from model import BitMambaLLM, maybe_autocast
# Run on GPU when available; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def resolve_model_settings(mode: str):
    """Map a run mode to (model hyperparameters, checkpoint path).

    Args:
        mode: Variant name, case-insensitive; "scout" or "parent".

    Returns:
        Tuple of (keyword-argument dict for ``BitMambaLLM``, checkpoint path).

    Raises:
        ValueError: If *mode* names no known variant.
    """
    # Lookup table keeps both variants' settings side by side for easy diffing.
    settings = {
        "scout": (
            dict(vocab_size=64000, dim=512, n_layers=24, d_state=64, expand=2),
            "checkpoints/bitmamba_scout/step_100000.pt",
        ),
        "parent": (
            dict(vocab_size=64000, dim=1024, n_layers=40, d_state=128, expand=2),
            "checkpoints/bitmamba_parent/step_1000000.pt",
        ),
    }
    key = mode.lower()
    if key in settings:
        return settings[key]
    raise ValueError(f"Unsupported MODE '{key}'. Expected 'scout' or 'parent'.")
# Model variant is chosen via the MODE environment variable; "scout" is the default.
MODE = os.environ.get("MODE", "scout")
# Resolve hyperparameters and checkpoint path once at import time.
MODEL_CONFIG, CKPT_PATH = resolve_model_settings(MODE)
def generate(model, tokenizer, prompt, max_new_tokens=150, temperature=0.7):
    """Sample a completion for *prompt*, print it, and return the text.

    Args:
        model: BitMambaLLM exposing an HF-style ``.generate`` method.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        prompt: Text fed to the model verbatim.
        max_new_tokens: Upper bound on newly sampled tokens.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        The decoded completion (special tokens kept). The original returned
        None, so returning the text is backward-compatible for callers.
    """
    model.eval()
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)
    # Stop at the chat end-of-turn marker.
    # NOTE(review): assumes "<|im_end|>" encodes to exactly one token id —
    # confirm against the custom tokenizer's special-token table.
    eos_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]

    print(f"\nPrompt: {prompt}")
    print("Generating...", flush=True)

    # FIX: the original ran generation with autograd enabled — model.eval()
    # does not disable gradient tracking, so every sampled step built graph
    # state for nothing. inference_mode() removes that overhead.
    with torch.inference_mode(), maybe_autocast(DEVICE):
        output_ids = model.generate(
            input_ids, max_new_tokens=max_new_tokens, temperature=temperature,
            do_sample=True, eos_token_id=eos_id
        )

    # Drop the prompt tokens; keep special tokens so <think> markup stays visible.
    generated_text = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=False)
    print(generated_text)
    print()
    return generated_text
def main() -> int:
    """Load the tokenizer and checkpoint, then run one smoke-test generation.

    Returns:
        0 on success, 1 when the checkpoint file is missing.
    """
    print("Loading custom tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("custom_agentic_tokenizer")
    tokenizer.model_max_length = CONTEXT_LENGTH

    print(f"Loading {MODE.upper()} BitMamba model from {CKPT_PATH}...")
    model = BitMambaLLM(**MODEL_CONFIG).to(DEVICE)

    # Only the load itself can raise FileNotFoundError; keep the try minimal.
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    try:
        checkpoint = torch.load(CKPT_PATH, map_location=DEVICE)
    except FileNotFoundError:
        print(f"Error: Could not find {CKPT_PATH}.")
        return 1
    model.load_state_dict(checkpoint['model_state_dict'])
    model.prepare_for_inference()

    print("\n--- Testing Model Logic ---")
    # ChatML-style prompt ending inside an open <think> block so the model
    # continues with its reasoning trace.
    chat_prompt = (
        "<|im_start|>system\n"
        "You are a deductive reasoning agent. You must analyze the user's "
        "request step-by-step within <think> tags before acting.<|im_end|>\n"
        "<|im_start|>user\n"
        "If I have 3 apples and eat 1, how many are left?<|im_end|>\n"
        "<|im_start|>assistant\n<think>\n"
    )
    generate(model, tokenizer, prompt=chat_prompt, max_new_tokens=150)
    return 0
# Script entry point: propagate main()'s status code as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())