Coverage for upscale.py: 93%

45 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

# Copyright 2026 venim1103
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

14 

15import torch 

16import os 

17from model import BitMambaLLM 

18 

def upscaler(small_ckpt_path, output_ckpt_path, *, split_at=32, layer_offset=24):
    """Upscale a trained 40-layer BitMamba checkpoint to 64 layers (SOLAR-style).

    Depth up-scaling duplicates a band of parent layers: target layers
    ``[0, split_at)`` copy weights from the same source index, and target
    layers ``[split_at, 64)`` copy from ``target_idx - layer_offset`` (so with
    the defaults, targets 32..63 reuse source layers 8..39). Target layers
    whose type differs from their source (attention vs. Mamba) keep their
    random initialization. Non-layer tensors (token embeddings, final norm,
    output head) transfer verbatim.

    Args:
        small_ckpt_path: Path to the trained 40-layer parent checkpoint
            (must contain a ``'model_state_dict'`` entry).
        output_ckpt_path: Destination path for the upscaled checkpoint.
        split_at: First target layer index mapped to a shifted source layer.
        layer_offset: Shift applied to target indices >= ``split_at``.

    The saved checkpoint is flagged ``requires_continued_pretraining=True`` —
    continued pre-training is required so the duplicated layers differentiate.
    """
    print(f"Loading trained 40-Layer BitMamba Parent from {small_ckpt_path}...")
    # NOTE(review): torch.load on an untrusted file can execute arbitrary
    # code via pickle; these checkpoints are assumed to be locally produced.
    checkpoint = torch.load(small_ckpt_path, map_location="cpu")
    small_state_dict = checkpoint['model_state_dict']

    print("Initializing new 64-Layer Upscaled BitMamba Model...")
    big_model = BitMambaLLM(
        vocab_size=64000, dim=1024, n_layers=64, d_state=128, expand=2,
        use_attn=True, attn_pct=0.08,
    )
    big_state_dict = big_model.state_dict()

    # Identify which layers are attention vs mamba in each model.
    # Parent was trained WITHOUT attention (pure Mamba), so use_attn=False.
    small_model_tmp = BitMambaLLM(
        vocab_size=64000, dim=1024, n_layers=40, d_state=128, expand=2,
        use_attn=False,
    )
    small_attn_indices = small_model_tmp.attn_indices
    big_attn_indices = big_model.attn_indices
    del small_model_tmp  # free the throwaway model before the copy loop

    print(f"Small model attention layers: {sorted(small_attn_indices)}")
    print(f"Big model attention layers: {sorted(big_attn_indices)}")

    print("Transplanting weights via SOLAR duplication...")
    for key in big_state_dict:
        if key.startswith(("tok_embeddings.", "norm.", "output.")):
            # Non-layer tensors have identical shapes in both models.
            big_state_dict[key] = small_state_dict[key]
        elif key.startswith("layers."):
            # Keys look like "layers.<idx>.<rest>"; parts[1] is the index.
            parts = key.split('.')
            target_layer_idx = int(parts[1])

            if target_layer_idx < split_at:
                source_layer_idx = target_layer_idx
            else:
                source_layer_idx = target_layer_idx - layer_offset

            # If layer types differ (attn vs mamba), keep random init.
            target_is_attn = target_layer_idx in big_attn_indices
            source_is_attn = source_layer_idx in small_attn_indices
            if target_is_attn != source_is_attn:
                continue  # skip — incompatible layer type, keep random init

            parts[1] = str(source_layer_idx)
            source_key = '.'.join(parts)
            if source_key in small_state_dict:
                big_state_dict[key] = small_state_dict[source_key]

    print("Loading mapped weights into the new model...")
    big_model.load_state_dict(big_state_dict)

    print(f"Saving Upscaled Checkpoint to {output_ckpt_path}...")
    torch.save({
        'step': 0,
        'model_state_dict': big_state_dict,
        'source_checkpoint': small_ckpt_path,
        'requires_continued_pretraining': True,
        'recommended_mode': 'upscaled',
    }, output_ckpt_path)
    print("Done! Continued pre-training is REQUIRED after upscaling.")
    print("")
    print("To continue pre-training the upscaled model:")
    print("  1. Set MODE='upscaled' in train.py (line 24)")
    print("  2. Run: MODE=upscaled python train.py")
    print("  3. This will train for 20k steps at lower LR (1e-4) to let")
    print("     duplicated layers differentiate (MiniPuzzle-inspired)")

86 

if __name__ == "__main__":
    # Create the destination directory up front so torch.save cannot fail
    # on a missing path.
    os.makedirs("checkpoints/upscaled", exist_ok=True)
    upscaler(
        "checkpoints/bitmamba_parent/step_1000000.pt",
        "checkpoints/upscaled/step_000000_1B_mamba.pt",
    )

90