Coverage for upscale.py: 93%
45 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 23:06 +0000
# Copyright 2026 venim1103
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import os
from model import BitMambaLLM
19def upscaler(small_ckpt_path, output_ckpt_path):
20 print(f"Loading trained 40-Layer BitMamba Parent from {small_ckpt_path}...")
21 checkpoint = torch.load(small_ckpt_path, map_location="cpu")
22 small_state_dict = checkpoint['model_state_dict']
24 print("Initializing new 64-Layer Upscaled BitMamba Model...")
25 big_model = BitMambaLLM(
26 vocab_size=64000, dim=1024, n_layers=64, d_state=128, expand=2,
27 use_attn=True, attn_pct=0.08,
28 )
29 big_state_dict = big_model.state_dict()
31 # Identify which layers are attention vs mamba in each model
32 # Parent was trained WITHOUT attention (pure Mamba), so use_attn=False
33 small_model_tmp = BitMambaLLM(
34 vocab_size=64000, dim=1024, n_layers=40, d_state=128, expand=2,
35 use_attn=False,
36 )
37 small_attn_indices = small_model_tmp.attn_indices
38 big_attn_indices = big_model.attn_indices
39 del small_model_tmp
41 print(f"Small model attention layers: {sorted(small_attn_indices)}")
42 print(f"Big model attention layers: {sorted(big_attn_indices)}")
44 print("Transplanting weights via SOLAR duplication...")
45 for key in big_state_dict.keys():
46 if key.startswith("tok_embeddings.") or key.startswith("norm.") or key.startswith("output."):
47 big_state_dict[key] = small_state_dict[key]
48 elif key.startswith("layers."): 48 ↛ 45line 48 didn't jump to line 45 because the condition on line 48 was always true
49 parts = key.split('.')
50 target_layer_idx = int(parts[1])
52 if target_layer_idx < 32: 52 ↛ 53line 52 didn't jump to line 53 because the condition on line 52 was never true
53 source_layer_idx = target_layer_idx
54 else:
55 source_layer_idx = target_layer_idx - 24
57 # If layer types differ (attn vs mamba), keep random init
58 target_is_attn = target_layer_idx in big_attn_indices
59 source_is_attn = source_layer_idx in small_attn_indices
60 if target_is_attn != source_is_attn:
61 continue # skip — incompatible layer type, keep random init
63 parts[1] = str(source_layer_idx)
64 source_key = '.'.join(parts)
65 if source_key in small_state_dict: 65 ↛ 45line 65 didn't jump to line 45 because the condition on line 65 was always true
66 big_state_dict[key] = small_state_dict[source_key]
68 print("Loading mapped weights into the new model...")
69 big_model.load_state_dict(big_state_dict)
71 print(f"Saving Upscaled Checkpoint to {output_ckpt_path}...")
72 torch.save({
73 'step': 0,
74 'model_state_dict': big_state_dict,
75 'source_checkpoint': small_ckpt_path,
76 'requires_continued_pretraining': True,
77 'recommended_mode': 'upscaled',
78 }, output_ckpt_path)
79 print("Done! Continued pre-training is REQUIRED after upscaling.")
80 print("")
81 print("To continue pre-training the upscaled model:")
82 print(" 1. Set MODE='upscaled' in train.py (line 24)")
83 print(" 2. Run: MODE=upscaled python train.py")
84 print(" 3. This will train for 20k steps at lower LR (1e-4) to let")
85 print(" duplicated layers differentiate (MiniPuzzle-inspired)")
if __name__ == "__main__":
    # Ensure the destination directory exists before the upscaled
    # checkpoint is written there.
    os.makedirs("checkpoints/upscaled", exist_ok=True)
    upscaler(
        "checkpoints/bitmamba_parent/step_1000000.pt",
        "checkpoints/upscaled/step_000000_1B_mamba.pt",
    )