Spaces:
Running
Running
Update downloads/interactive.py
Browse files- downloads/interactive.py +1011 -156
downloads/interactive.py
CHANGED
|
@@ -25,10 +25,12 @@ import torch.nn as nn
|
|
| 25 |
import torch.nn.functional as F
|
| 26 |
from torch.utils.checkpoint import checkpoint
|
| 27 |
|
|
|
|
| 28 |
HUGGINGFACE_MODELS = {
|
| 29 |
"TMLM-Haiku-1": "CompactAI-O/TMLM-Haiku-1",
|
| 30 |
"TMLM-Haiku-1.3": "CompactAI-O/TMLM-Haiku-1.3",
|
| 31 |
"TMLM-Haiku-2": "CompactAI-O/TMLM-Haiku-2",
|
|
|
|
| 32 |
}
|
| 33 |
|
| 34 |
|
|
@@ -79,6 +81,15 @@ MODEL_SERIES = {
|
|
| 79 |
"engram_table_size": 64,
|
| 80 |
"engram_max_ngram": 2,
|
| 81 |
"mhc_expansion": 2,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
},
|
| 83 |
"sonnet": {
|
| 84 |
"dim": 1024,
|
|
@@ -95,6 +106,15 @@ MODEL_SERIES = {
|
|
| 95 |
"engram_table_size": 4096,
|
| 96 |
"engram_max_ngram": 2,
|
| 97 |
"mhc_expansion": 2,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
},
|
| 99 |
"opus": {
|
| 100 |
"dim": 1536,
|
|
@@ -111,6 +131,15 @@ MODEL_SERIES = {
|
|
| 111 |
"engram_table_size": 8192,
|
| 112 |
"engram_max_ngram": 2,
|
| 113 |
"mhc_expansion": 4,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
},
|
| 115 |
}
|
| 116 |
|
|
@@ -422,6 +451,68 @@ class SwiGLU(nn.Module):
|
|
| 422 |
return out
|
| 423 |
|
| 424 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
class EngramBlock(nn.Module):
|
| 426 |
"""DeepSeek Engram: conditional memory via O(1) hashed N-gram lookup.
|
| 427 |
|
|
@@ -566,6 +657,115 @@ class EngramBlock(nn.Module):
|
|
| 566 |
return gate * value
|
| 567 |
|
| 568 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor:
|
| 570 |
M = torch.exp(logits.clamp(-10, 10))
|
| 571 |
for _ in range(n_iters):
|
|
@@ -733,6 +933,85 @@ class TransformerBlock(nn.Module):
|
|
| 733 |
return x, new_kv
|
| 734 |
|
| 735 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 736 |
class TinyMemoryLM(nn.Module):
|
| 737 |
def __init__(
|
| 738 |
self,
|
|
@@ -754,6 +1033,17 @@ class TinyMemoryLM(nn.Module):
|
|
| 754 |
engram_table_size: int = 8192,
|
| 755 |
engram_max_ngram: int = 3,
|
| 756 |
mhc_expansion: int = 1,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 757 |
) -> None:
|
| 758 |
super().__init__()
|
| 759 |
self.dim = dim
|
|
@@ -766,29 +1056,45 @@ class TinyMemoryLM(nn.Module):
|
|
| 766 |
self.embed_tokens = nn.Embedding(vocab_size, dim)
|
| 767 |
self.head = nn.Linear(dim, vocab_size, bias=False)
|
| 768 |
self.head.weight = self.embed_tokens.weight
|
| 769 |
-
|
| 770 |
self.output_bias = nn.Parameter(torch.zeros(vocab_size))
|
| 771 |
|
| 772 |
-
self.
|
| 773 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 774 |
TransformerBlock(
|
| 775 |
-
dim=dim,
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
dropout=dropout,
|
| 781 |
-
sliding_window=sliding_window,
|
| 782 |
-
rope_fraction=rope_fraction,
|
| 783 |
-
engram_dim=engram_dim,
|
| 784 |
-
engram_heads=engram_heads,
|
| 785 |
-
engram_table_size=engram_table_size,
|
| 786 |
-
engram_max_ngram=engram_max_ngram,
|
| 787 |
-
mhc_expansion=mhc_expansion,
|
| 788 |
)
|
| 789 |
-
for _ in range(
|
| 790 |
-
]
|
| 791 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 792 |
self.norm = RMSNorm(dim)
|
| 793 |
|
| 794 |
self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1})
|
|
@@ -799,10 +1105,37 @@ class TinyMemoryLM(nn.Module):
|
|
| 799 |
{str(h): RMSNorm(dim) for h in self.mtp_horizons}
|
| 800 |
)
|
| 801 |
|
| 802 |
-
res_scale = (2 *
|
| 803 |
-
for
|
| 804 |
-
|
| 805 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 806 |
|
| 807 |
def resize_token_embeddings(self, new_vocab_size: int) -> None:
|
| 808 |
old_vocab_size = self.embed_tokens.num_embeddings
|
|
@@ -810,12 +1143,8 @@ class TinyMemoryLM(nn.Module):
|
|
| 810 |
return
|
| 811 |
device = self.embed_tokens.weight.device
|
| 812 |
old_embed_weight = self.embed_tokens.weight.data.clone()
|
| 813 |
-
self.embed_tokens = nn.Embedding(
|
| 814 |
-
|
| 815 |
-
).to(device)
|
| 816 |
-
self.head = nn.Linear(
|
| 817 |
-
self.embed_tokens.embedding_dim, new_vocab_size, bias=False
|
| 818 |
-
).to(device)
|
| 819 |
self.head.weight = self.embed_tokens.weight
|
| 820 |
old_bias = self.output_bias.data.clone()
|
| 821 |
self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device))
|
|
@@ -824,62 +1153,74 @@ class TinyMemoryLM(nn.Module):
|
|
| 824 |
self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size]
|
| 825 |
|
| 826 |
def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]:
|
| 827 |
-
|
|
|
|
| 828 |
blocks_list = list(self.blocks)
|
| 829 |
full_sequence = blocks_list + blocks_list
|
| 830 |
-
for
|
| 831 |
-
logical.append((block, logical_idx))
|
| 832 |
-
return logical
|
| 833 |
|
| 834 |
def forward(
|
| 835 |
self,
|
| 836 |
ids: torch.Tensor,
|
| 837 |
use_cache: bool = False,
|
| 838 |
-
past_key_values: Optional[
|
| 839 |
-
List[Optional[Tuple[torch.Tensor, torch.Tensor]]]
|
| 840 |
-
] = None,
|
| 841 |
return_hidden: bool = False,
|
| 842 |
-
) -> Tuple[
|
| 843 |
-
torch.Tensor,
|
| 844 |
-
Dict[int, torch.Tensor],
|
| 845 |
-
Optional[torch.Tensor],
|
| 846 |
-
Optional[List[Tuple[torch.Tensor, torch.Tensor]]],
|
| 847 |
-
]:
|
| 848 |
B, T = ids.shape
|
| 849 |
x = self.embed_tokens(ids) * self.embed_scale_factor
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
|
| 863 |
-
|
|
|
|
|
|
|
| 864 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 865 |
|
| 866 |
-
|
| 867 |
-
x, layer_kv = checkpoint(
|
| 868 |
-
block,
|
| 869 |
-
x,
|
| 870 |
-
is_global,
|
| 871 |
-
past_kv,
|
| 872 |
-
use_cache,
|
| 873 |
-
token_ids,
|
| 874 |
-
use_reentrant=True,
|
| 875 |
-
)
|
| 876 |
-
else:
|
| 877 |
-
x, layer_kv = block(x, is_global, past_kv, use_cache, token_ids)
|
| 878 |
|
| 879 |
-
|
| 880 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 881 |
|
| 882 |
-
x = self.norm(x)
|
| 883 |
h_out = x if return_hidden else None
|
| 884 |
logits = self.head(x)
|
| 885 |
if self.embed_scale_factor != 1.0:
|
|
@@ -899,7 +1240,7 @@ class TinyMemoryLM(nn.Module):
|
|
| 899 |
mtp_logits = mtp_logits + self.output_bias
|
| 900 |
mtp[horizon] = mtp_logits
|
| 901 |
|
| 902 |
-
return logits, mtp, h_out, new_past_key_values
|
| 903 |
|
| 904 |
|
| 905 |
# ---------------------------------------------------------------------------
|
|
@@ -1011,7 +1352,7 @@ def generate(
|
|
| 1011 |
ctx_ids = (
|
| 1012 |
input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t
|
| 1013 |
)
|
| 1014 |
-
logits, _
|
| 1015 |
next_logits = logits[0, -1, :].clone()
|
| 1016 |
|
| 1017 |
# Logit soft-capping (Gemma-style) — prevents overconfident collapse
|
|
@@ -1122,8 +1463,16 @@ def discover_models(runs_dir: Path) -> List[dict]:
|
|
| 1122 |
if not tokenizer_path.exists():
|
| 1123 |
continue
|
| 1124 |
name = child.name
|
| 1125 |
-
series =
|
| 1126 |
for ckpt_name in ("model.pt", "pretrain.pt"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1127 |
ckpt_path = child / ckpt_name
|
| 1128 |
if ckpt_path.exists():
|
| 1129 |
models.append(
|
|
@@ -1135,6 +1484,23 @@ def discover_models(runs_dir: Path) -> List[dict]:
|
|
| 1135 |
"tokenizer_path": tokenizer_path,
|
| 1136 |
}
|
| 1137 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1138 |
return models
|
| 1139 |
|
| 1140 |
|
|
@@ -1153,49 +1519,138 @@ def _detect_mhc(state_dict):
|
|
| 1153 |
return 1
|
| 1154 |
|
| 1155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1156 |
def _infer_arch_from_state_dict(state_dict, cfg):
|
| 1157 |
"""Infer architecture hyper-parameters directly from checkpoint weights,
|
| 1158 |
falling back to *cfg* (series config) when a key is not found."""
|
| 1159 |
overrides = {}
|
| 1160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1161 |
# dim from embed_tokens.weight [vocab, dim]
|
| 1162 |
if "embed_tokens.weight" in state_dict:
|
| 1163 |
overrides["dim"] = state_dict["embed_tokens.weight"].shape[1]
|
| 1164 |
|
| 1165 |
-
|
| 1166 |
-
|
| 1167 |
-
|
| 1168 |
-
|
| 1169 |
-
|
| 1170 |
-
|
| 1171 |
-
|
| 1172 |
-
|
| 1173 |
-
|
| 1174 |
-
|
| 1175 |
-
|
| 1176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1177 |
|
| 1178 |
-
# n_heads from wq [n_heads*head_dim, dim] and wk [n_kv*head_dim, dim]
|
| 1179 |
dim = overrides.get("dim", int(cfg.get("dim", model_config.dim)))
|
| 1180 |
-
if "
|
| 1181 |
-
wq_rows = state_dict["
|
| 1182 |
-
if "
|
| 1183 |
-
head_dim = state_dict["
|
| 1184 |
overrides["n_heads"] = wq_rows // head_dim
|
| 1185 |
-
if "
|
| 1186 |
-
wk_rows = state_dict["
|
| 1187 |
-
if "
|
| 1188 |
-
head_dim = state_dict["
|
| 1189 |
overrides["n_kv_heads"] = wk_rows // head_dim
|
| 1190 |
|
| 1191 |
-
# engram params
|
| 1192 |
for key, val in state_dict.items():
|
| 1193 |
if ".engram.embeddings." in key and key.endswith("_0") and val.dim() == 2:
|
| 1194 |
overrides["engram_table_size"] = val.shape[0]
|
| 1195 |
overrides["engram_dim"] = val.shape[1]
|
| 1196 |
break
|
| 1197 |
-
# engram_heads from branch_conv [total_branch_dim, 1, 4]
|
| 1198 |
-
# total_branch_dim = engram_dim * n_heads * (max_ngram - 1)
|
| 1199 |
engram_dim = overrides.get("engram_dim", int(cfg.get("engram_dim", 0)))
|
| 1200 |
engram_max_ngram = int(cfg.get("engram_max_ngram", 2))
|
| 1201 |
if engram_dim > 0:
|
|
@@ -1207,7 +1662,6 @@ def _infer_arch_from_state_dict(state_dict, cfg):
|
|
| 1207 |
overrides["engram_heads"] = total_branch_dim // denom
|
| 1208 |
break
|
| 1209 |
|
| 1210 |
-
# merge: checkpoint values take priority over series config
|
| 1211 |
merged = dict(cfg)
|
| 1212 |
merged.update(overrides)
|
| 1213 |
return merged
|
|
@@ -1221,8 +1675,6 @@ def load_local_model(model_path: Path, tokenizer_path: Path, series: str) -> dic
|
|
| 1221 |
|
| 1222 |
state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt
|
| 1223 |
|
| 1224 |
-
# Infer architecture from checkpoint weights so config mismatches are
|
| 1225 |
-
# handled automatically.
|
| 1226 |
cfg = _infer_arch_from_state_dict(state_dict, cfg)
|
| 1227 |
|
| 1228 |
engram_dim = int(cfg.get("engram_dim", 0))
|
|
@@ -1233,38 +1685,65 @@ def load_local_model(model_path: Path, tokenizer_path: Path, series: str) -> dic
|
|
| 1233 |
if mhc_expansion == 1:
|
| 1234 |
mhc_expansion = int(cfg.get("mhc_expansion", 1))
|
| 1235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1236 |
model = TinyMemoryLM(
|
| 1237 |
vocab_size=vocab_size,
|
| 1238 |
dim=int(cfg.get("dim", model_config.dim)),
|
| 1239 |
-
n_unique_layers=
|
| 1240 |
-
n_logical_layers=int(
|
| 1241 |
-
cfg.get("n_logical_layers", model_config.n_logical_layers)
|
| 1242 |
-
),
|
| 1243 |
n_heads=int(cfg.get("n_heads", model_config.n_heads)),
|
| 1244 |
n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)),
|
| 1245 |
ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)),
|
| 1246 |
dropout=float(cfg.get("dropout", model_config.dropout)),
|
| 1247 |
-
mtp_horizons=tuple(
|
| 1248 |
-
int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)
|
| 1249 |
-
),
|
| 1250 |
grad_checkpoint=False,
|
| 1251 |
-
sliding_window=int(
|
| 1252 |
-
|
| 1253 |
-
|
| 1254 |
-
getattr(model_config, "sliding_window_size", 512),
|
| 1255 |
-
)
|
| 1256 |
-
),
|
| 1257 |
-
rope_fraction=float(
|
| 1258 |
-
cfg.get("rope_fraction", getattr(model_config, "rope_fraction", 0.25))
|
| 1259 |
-
),
|
| 1260 |
-
embed_scale=bool(
|
| 1261 |
-
cfg.get("embed_scale", getattr(model_config, "embed_scale", True))
|
| 1262 |
-
),
|
| 1263 |
engram_dim=engram_dim,
|
| 1264 |
engram_heads=int(cfg.get("engram_heads", 4)),
|
| 1265 |
engram_table_size=int(cfg.get("engram_table_size", 8192)),
|
| 1266 |
engram_max_ngram=int(cfg.get("engram_max_ngram", 3)),
|
| 1267 |
mhc_expansion=mhc_expansion,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1268 |
)
|
| 1269 |
model.load_state_dict(state_dict, strict=False)
|
| 1270 |
model.eval()
|
|
@@ -1277,6 +1756,8 @@ def load_local_model(model_path: Path, tokenizer_path: Path, series: str) -> dic
|
|
| 1277 |
"tokenizer": tokenizer,
|
| 1278 |
"device": device,
|
| 1279 |
"series": series,
|
|
|
|
|
|
|
| 1280 |
}
|
| 1281 |
|
| 1282 |
|
|
@@ -1300,7 +1781,13 @@ def download_huggingface_model(hf_id: str, cache_dir: Path) -> dict:
|
|
| 1300 |
|
| 1301 |
print(f"Using cached {hf_id} from {local_dir}")
|
| 1302 |
|
| 1303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1304 |
model_path = model_dir / "model.pt"
|
| 1305 |
pretrain_path = model_dir / "pretrain.pt"
|
| 1306 |
tokenizer_path = model_dir / "tokenizer.json"
|
|
@@ -1454,6 +1941,305 @@ def compare_all_models(prompt: str, cfg: dict) -> None:
|
|
| 1454 |
print(f"\n{'='*60}")
|
| 1455 |
|
| 1456 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1457 |
# ---------------------------------------------------------------------------
|
| 1458 |
# Interactive CLI
|
| 1459 |
# ---------------------------------------------------------------------------
|
|
@@ -1524,66 +2310,64 @@ def pick_model(runs_dir: Path) -> tuple[dict, str]:
|
|
| 1524 |
# ---------------------------------------------------------------------------
|
| 1525 |
|
| 1526 |
MODES = {
|
| 1527 |
-
# Chat — two flavours
|
| 1528 |
"chat-coherent": {
|
| 1529 |
"label": "Chat — Coherent",
|
| 1530 |
"desc": "structured, consistent, strong repetition control",
|
| 1531 |
"sft_mode": "chat",
|
| 1532 |
-
"temperature": 0.
|
| 1533 |
-
"top_k":
|
| 1534 |
-
"top_p": 0.
|
| 1535 |
-
"min_p": 0.
|
| 1536 |
-
"no_repeat_ngram_size":
|
| 1537 |
-
"repetition_penalty": 1.
|
| 1538 |
-
"logit_soft_cap":
|
| 1539 |
-
"loop_penalty":
|
| 1540 |
-
"max_new_tokens":
|
| 1541 |
"context_window": 2048,
|
| 1542 |
},
|
| 1543 |
"chat-variants": {
|
| 1544 |
"label": "Chat — Variants",
|
| 1545 |
"desc": "creative, diverse, more surprising outputs",
|
| 1546 |
"sft_mode": "chat",
|
| 1547 |
-
"temperature": 0.
|
| 1548 |
-
"top_k":
|
| 1549 |
-
"top_p": 0.
|
| 1550 |
-
"min_p": 0.
|
| 1551 |
-
"no_repeat_ngram_size":
|
| 1552 |
-
"repetition_penalty": 1.
|
| 1553 |
-
"logit_soft_cap":
|
| 1554 |
-
"loop_penalty":
|
| 1555 |
-
"max_new_tokens":
|
| 1556 |
"context_window": 2048,
|
| 1557 |
},
|
| 1558 |
-
# Pretrain — two flavours
|
| 1559 |
"pretrain-coherent": {
|
| 1560 |
"label": "Pretrain — Coherent",
|
| 1561 |
"desc": "grounded continuation, low temperature, tight sampling",
|
| 1562 |
"sft_mode": False,
|
| 1563 |
-
"temperature": 0.
|
| 1564 |
"top_k": 20,
|
| 1565 |
"top_p": 0.85,
|
| 1566 |
"min_p": 0.10,
|
| 1567 |
-
"no_repeat_ngram_size":
|
| 1568 |
"repetition_penalty": 1.2,
|
| 1569 |
-
"logit_soft_cap":
|
| 1570 |
-
"loop_penalty":
|
| 1571 |
-
"max_new_tokens":
|
| 1572 |
"context_window": 2048,
|
| 1573 |
},
|
| 1574 |
"pretrain-variants": {
|
| 1575 |
"label": "Pretrain — Variants",
|
| 1576 |
"desc": "free-form continuation, higher temperature, more exploration",
|
| 1577 |
"sft_mode": False,
|
| 1578 |
-
"temperature": 0.
|
| 1579 |
"top_k": 60,
|
| 1580 |
-
"top_p": 0.
|
| 1581 |
-
"min_p": 0.
|
| 1582 |
"no_repeat_ngram_size": 4,
|
| 1583 |
-
"repetition_penalty": 1.
|
| 1584 |
-
"logit_soft_cap":
|
| 1585 |
"loop_penalty": 12.0,
|
| 1586 |
-
"max_new_tokens":
|
| 1587 |
"context_window": 2048,
|
| 1588 |
},
|
| 1589 |
}
|
|
@@ -1681,8 +2465,11 @@ _FALLBACK_COLLECTION = [
|
|
| 1681 |
{"version": "TMLM-Haiku-2", "hf_id": "CompactAI-O/TMLM-Haiku-2"},
|
| 1682 |
{"version": "TMLM-Haiku-1.3", "hf_id": "CompactAI-O/TMLM-Haiku-1.3"},
|
| 1683 |
{"version": "TMLM-Haiku-1", "hf_id": "CompactAI-O/TMLM-Haiku-1"},
|
|
|
|
| 1684 |
]
|
| 1685 |
|
|
|
|
|
|
|
| 1686 |
|
| 1687 |
def _probe_repo(hf_id: str) -> dict | None:
|
| 1688 |
"""Return entry dict for one repo, or None if no usable checkpoints found."""
|
|
@@ -1710,6 +2497,7 @@ def _probe_repo(hf_id: str) -> dict | None:
|
|
| 1710 |
|
| 1711 |
_LABELS = {
|
| 1712 |
"model.pt": ("Chat (SFT)", False),
|
|
|
|
| 1713 |
"pretrain.pt": ("Pretrain (base)", True),
|
| 1714 |
}
|
| 1715 |
|
|
@@ -1750,6 +2538,7 @@ def fetch_collection() -> list[dict]:
|
|
| 1750 |
infos = [type("M", (), {"id": e["hf_id"]})() for e in _FALLBACK_COLLECTION]
|
| 1751 |
|
| 1752 |
entries = []
|
|
|
|
| 1753 |
for info in infos:
|
| 1754 |
repo_id = info.id
|
| 1755 |
if _SEARCH.lower() not in repo_id.lower():
|
|
@@ -1757,10 +2546,18 @@ def fetch_collection() -> list[dict]:
|
|
| 1757 |
entry = _probe_repo(repo_id)
|
| 1758 |
if entry:
|
| 1759 |
entries.append(entry)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1760 |
|
| 1761 |
if not entries:
|
| 1762 |
print(" No models found; using fallback list.")
|
| 1763 |
-
entries = []
|
| 1764 |
for fb in _FALLBACK_COLLECTION:
|
| 1765 |
e = _probe_repo(fb["hf_id"])
|
| 1766 |
if e:
|
|
@@ -1848,11 +2645,32 @@ def pick_checkpoint(entry: dict) -> tuple[str, bool]:
|
|
| 1848 |
|
| 1849 |
|
| 1850 |
def main() -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1851 |
print("=" * 56)
|
| 1852 |
-
print("
|
| 1853 |
print(" Models: huggingface.co/CompactAI-O")
|
| 1854 |
print("=" * 56)
|
| 1855 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1856 |
collection = fetch_collection()
|
| 1857 |
if not collection:
|
| 1858 |
print("No models found. Check your internet connection.")
|
|
@@ -1861,6 +2679,11 @@ def main() -> None:
|
|
| 1861 |
entry = pick_version(collection)
|
| 1862 |
fname, is_pretrain = pick_checkpoint(entry)
|
| 1863 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1864 |
root = Path(__file__).resolve().parent
|
| 1865 |
cache_dir = root / "cache" / "huggingface"
|
| 1866 |
cache_dir.mkdir(parents=True, exist_ok=True)
|
|
@@ -1880,9 +2703,41 @@ def main() -> None:
|
|
| 1880 |
print(f"Loading {entry['version']} / {fname} ...")
|
| 1881 |
bundle = load_local_model(model_path, tokenizer_path, "Haiku")
|
| 1882 |
|
| 1883 |
-
|
| 1884 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1885 |
|
| 1886 |
|
| 1887 |
if __name__ == "__main__":
|
| 1888 |
main()
|
|
|
|
|
|
| 25 |
import torch.nn.functional as F
|
| 26 |
from torch.utils.checkpoint import checkpoint
|
| 27 |
|
| 28 |
+
|
| 29 |
# Maps a model display name to its Hugging Face repo id ("owner/name").
HUGGINGFACE_MODELS = {
    "TMLM-Haiku-1": "CompactAI-O/TMLM-Haiku-1",
    "TMLM-Haiku-1.3": "CompactAI-O/TMLM-Haiku-1.3",
    "TMLM-Haiku-2": "CompactAI-O/TMLM-Haiku-2",
    "Glint-1": "CompactAI-O/Glint-1",
}
|
| 35 |
|
| 36 |
|
|
|
|
| 81 |
"engram_table_size": 64,
|
| 82 |
"engram_max_ngram": 2,
|
| 83 |
"mhc_expansion": 2,
|
| 84 |
+
"sleep_gate_cap": 0,
|
| 85 |
+
"sleep_gate_heads": 4,
|
| 86 |
+
"latent_think_layers": 0,
|
| 87 |
+
"prelude_layers": 0,
|
| 88 |
+
"coda_layers": 0,
|
| 89 |
+
"recurrent_loops": 0,
|
| 90 |
+
"recurrent_act_threshold": 0.9,
|
| 91 |
+
"recurrent_lora_rank": 0,
|
| 92 |
+
"recurrent_loop_embed_dim": 0,
|
| 93 |
},
|
| 94 |
"sonnet": {
|
| 95 |
"dim": 1024,
|
|
|
|
| 106 |
"engram_table_size": 4096,
|
| 107 |
"engram_max_ngram": 2,
|
| 108 |
"mhc_expansion": 2,
|
| 109 |
+
"sleep_gate_cap": 0,
|
| 110 |
+
"sleep_gate_heads": 8,
|
| 111 |
+
"latent_think_layers": 0,
|
| 112 |
+
"prelude_layers": 0,
|
| 113 |
+
"coda_layers": 0,
|
| 114 |
+
"recurrent_loops": 0,
|
| 115 |
+
"recurrent_act_threshold": 0.99,
|
| 116 |
+
"recurrent_lora_rank": 0,
|
| 117 |
+
"recurrent_loop_embed_dim": 0,
|
| 118 |
},
|
| 119 |
"opus": {
|
| 120 |
"dim": 1536,
|
|
|
|
| 131 |
"engram_table_size": 8192,
|
| 132 |
"engram_max_ngram": 2,
|
| 133 |
"mhc_expansion": 4,
|
| 134 |
+
"sleep_gate_cap": 0,
|
| 135 |
+
"sleep_gate_heads": 8,
|
| 136 |
+
"latent_think_layers": 0,
|
| 137 |
+
"prelude_layers": 0,
|
| 138 |
+
"coda_layers": 0,
|
| 139 |
+
"recurrent_loops": 0,
|
| 140 |
+
"recurrent_act_threshold": 0.99,
|
| 141 |
+
"recurrent_lora_rank": 0,
|
| 142 |
+
"recurrent_loop_embed_dim": 0,
|
| 143 |
},
|
| 144 |
}
|
| 145 |
|
|
|
|
| 451 |
return out
|
| 452 |
|
| 453 |
|
| 454 |
+
def loop_index_embedding(h: torch.Tensor, loop_t: int, loop_dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Add a sinusoidal encoding of the loop index to the leading channels of ``h``.

    Args:
        h: Hidden states; the encoding is added along the last dimension.
            NOTE(review): the ``view(1, 1, width)`` below suggests a
            (batch, seq, dim) layout -- confirm against the caller.
        loop_t: Index of the current loop iteration.
        loop_dim: Number of leading channels that receive the encoding. It is
            clamped to ``h.shape[-1]`` and rounded down to an even number.
        theta: Base of the geometric frequency progression.

    Returns:
        ``h`` itself (same object) when no channel is eligible, otherwise a
        modified clone. ``h`` is never mutated.
    """
    # Clamp to the channel count, then drop one channel if odd so the sin and
    # cos halves have equal size. Non-positive requests fall through to the
    # same early return.
    width = min(loop_dim, h.shape[-1])
    width -= width % 2
    if width <= 0:
        return h

    exponents = torch.arange(0, width, 2, device=h.device, dtype=h.dtype) / width
    inv_freq = 1.0 / (theta ** exponents)
    phase = torch.tensor(float(loop_t), device=h.device, dtype=h.dtype) * inv_freq
    # Layout is [sin(all freqs), cos(all freqs)], not interleaved.
    loop_embed = torch.cat([phase.sin(), phase.cos()], dim=0).view(1, 1, width)

    out = h.clone()
    out[..., :width] = out[..., :width] + loop_embed
    return out
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
class DepthLoRAAdapter(nn.Module):
    """Low-rank adapter whose per-rank scaling is selected by loop index.

    The output is ``(down(x) * scale[loop_t]) @ B``. ``scale`` is
    zero-initialised, so a freshly built adapter contributes exactly zero.
    With ``rank <= 0`` the adapter is disabled and owns no parameters.
    """

    def __init__(self, dim: int, rank: int, max_loops: int) -> None:
        """Build the adapter.

        Args:
            dim: Size of the last dimension of the inputs and outputs.
            rank: Bottleneck width; values <= 0 disable the adapter.
            max_loops: Number of rows in the per-loop scale table (at least 1).
        """
        super().__init__()
        self.rank = max(0, rank)
        if self.rank <= 0:
            # Disabled: keep the attributes defined so forward() can test them.
            self.down = None
            self.B = None
            self.scale = None
            return
        self.down = nn.Linear(dim, self.rank, bias=False)
        self.B = nn.Parameter(torch.randn(self.rank, dim) * 0.02)
        self.scale = nn.Embedding(max(1, max_loops), self.rank)
        nn.init.zeros_(self.scale.weight)

    def forward(self, x: torch.Tensor, loop_t: int) -> torch.Tensor:
        """Return the adapter delta for ``x`` at loop iteration ``loop_t``.

        A ``loop_t`` past the end of the scale table reuses the last row.
        A disabled adapter returns zeros shaped like ``x``.
        """
        if self.rank <= 0 or self.down is None or self.B is None or self.scale is None:
            return torch.zeros_like(x)
        t_idx = min(loop_t, self.scale.num_embeddings - 1)
        scale = self.scale(torch.tensor(t_idx, device=x.device))
        return (self.down(x) * scale) @ self.B
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
class StableRecurrentInjection(nn.Module):
    """Combine the carried state, the injected input and a transformer output.

    Computes ``A * h + B * e + transformer_out`` with learned per-channel
    coefficients. ``A = exp(-exp(log_dt + log_A))`` lies in (0, 1) and
    ``B = sigmoid(input_gate)`` lies in (0, 1).
    """

    def __init__(self, dim: int) -> None:
        """Create the per-channel parameters for ``dim`` channels."""
        super().__init__()
        self.log_A = nn.Parameter(torch.full((dim,), -2.0))
        self.log_dt = nn.Parameter(torch.full((dim,), -2.0))
        self.input_gate = nn.Parameter(torch.zeros(dim))

    def forward(self, h: torch.Tensor, e: torch.Tensor, transformer_out: torch.Tensor) -> torch.Tensor:
        """Return ``A * h + B * e + transformer_out``.

        The coefficients are reshaped to ``(1, 1, dim)`` so they broadcast
        over the leading dimensions of the inputs.
        """
        # Clamp the exponent so the inner exp() cannot overflow.
        log_rate = (self.log_dt + self.log_A).clamp(-20, 20)
        A = torch.exp(-torch.exp(log_rate)).view(1, 1, -1)
        B = torch.sigmoid(self.input_gate).view(1, 1, -1)
        return A * h + B * e + transformer_out
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
class AdaptiveHalting(nn.Module):
    """Per-position halting probability from a single linear unit.

    The weight starts at zero and the bias at -2.0, so a fresh module outputs
    ``sigmoid(-2)`` (about 0.12) for every position regardless of input.
    """

    def __init__(self, dim: int) -> None:
        """Create the halting head for inputs whose last dimension is ``dim``."""
        super().__init__()
        self.halt = nn.Linear(dim, 1, bias=True)
        nn.init.zeros_(self.halt.weight)
        nn.init.constant_(self.halt.bias, -2.0)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Return probabilities in (0, 1) with the last dimension of ``h`` removed."""
        return torch.sigmoid(self.halt(h)).squeeze(-1)
|
| 514 |
+
|
| 515 |
+
|
| 516 |
class EngramBlock(nn.Module):
|
| 517 |
"""DeepSeek Engram: conditional memory via O(1) hashed N-gram lookup.
|
| 518 |
|
|
|
|
| 657 |
return gate * value
|
| 658 |
|
| 659 |
|
| 660 |
+
class SleepGate(nn.Module):
    """Persistent memory + periodic consolidation gate.

    ``write`` stores one pooled vector per batch row in a fixed-capacity ring
    buffer held in module buffers. ``read`` lets a sequence attend over the
    stored vectors with multi-head attention. When retention is enabled, each
    entry's attention weight is scaled by ``beta ** (global_step - age)``.

    ``global_step`` is only read and reset here; nothing in this class
    advances it (presumably the caller does -- confirm).
    ``o_proj`` is zero-initialised, so a fresh gate's ``read`` returns zeros.
    """

    def __init__(
        self,
        dim: int,
        cap: int = 128,
        n_heads: int = 4,
        retention_enabled: bool = True,
        retention_hidden: int = 0,
    ) -> None:
        """Build the memory buffers and the projections.

        Args:
            dim: Feature size of the hidden states.
            cap: Number of memory slots in the ring buffer.
            n_heads: Number of attention heads; ``dim // n_heads`` is the
                per-head size.
            retention_enabled: Whether to learn and apply per-entry retention
                factors.
            retention_hidden: Hidden width of the retention gate MLP. With 0
                the gate is a single linear layer.
        """
        super().__init__()
        self.dim = dim
        self.cap = cap
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5
        self.retention_enabled = retention_enabled

        # Ring-buffer state. Registered as buffers so it follows the module
        # across devices and is included in state_dict.
        self.register_buffer("mem_emb", torch.zeros(cap, dim, dtype=torch.bfloat16))
        self.register_buffer("mem_age", torch.zeros(cap, dtype=torch.long))
        self.register_buffer("mem_beta", torch.ones(cap, dtype=torch.float32))
        self.register_buffer("mem_count", torch.zeros((), dtype=torch.long))
        self.register_buffer("mem_head", torch.zeros((), dtype=torch.long))
        self.register_buffer("global_step", torch.zeros((), dtype=torch.long))

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)
        nn.init.zeros_(self.o_proj.weight)
        self.gate_scale = nn.Parameter(torch.zeros(()))

        if retention_enabled:
            if retention_hidden > 0:
                self.retention_gate: Optional[nn.Module] = nn.Sequential(
                    nn.Linear(dim, retention_hidden, bias=False),
                    nn.GELU(),
                    nn.Linear(retention_hidden, 1, bias=True),
                )
                # Bias 2.2 makes the initial sigmoid output about 0.9.
                nn.init.constant_(self.retention_gate[-1].bias, 2.2)
            else:
                self.retention_gate = nn.Linear(dim, 1, bias=True)
                nn.init.constant_(self.retention_gate.bias, 2.2)
        else:
            self.retention_gate = None

        # Most recent differentiable retention factors (training mode only).
        self._last_beta: Optional[torch.Tensor] = None

    def write(self, hidden: torch.Tensor) -> None:
        """Append one pooled memory entry per batch row.

        Each row of ``hidden`` (batch, seq, dim) is reduced to the float32
        mean of its last 16 positions (fewer if the sequence is shorter) and
        stored at the ring-buffer head, overwriting the oldest entry once the
        buffer is full. In training mode the un-detached retention factors
        are kept in ``_last_beta``.
        """
        if self.cap <= 0:
            # A zero-capacity memory has no slot to index; treat as a no-op
            # instead of failing on the first buffer assignment.
            self._last_beta = None
            return

        B, T, _ = hidden.shape
        tail_full = hidden[:, max(0, T - 16):, :].float().mean(dim=1)
        if self.retention_gate is not None:
            beta_live = torch.sigmoid(self.retention_gate(tail_full).squeeze(-1))
            self._last_beta = beta_live if self.training else None
            beta_store = beta_live.detach().float()
        else:
            self._last_beta = None
            beta_store = torch.ones(B, device=hidden.device, dtype=torch.float32)
        tail = tail_full.to(self.mem_emb.dtype).detach()

        with torch.no_grad():
            head = int(self.mem_head.item())
            count = int(self.mem_count.item())
            step = int(self.global_step.item())
            for b in range(B):
                self.mem_emb[head] = tail[b]
                self.mem_age[head] = step
                self.mem_beta[head] = beta_store[b]
                head = (head + 1) % self.cap
                if count < self.cap:
                    count += 1
            self.mem_head.fill_(head)
            self.mem_count.fill_(count)

    def read(self, x: torch.Tensor) -> torch.Tensor:
        """Attend from ``x`` (batch, seq, dim) over the stored memory entries.

        Returns zeros shaped like ``x`` when the memory is empty. Otherwise
        returns ``sigmoid(gate_scale) * o_proj(attention_output)``.
        """
        count = int(self.mem_count.item())
        if count == 0:
            return torch.zeros_like(x)

        B, T, D = x.shape
        mem = self.mem_emb[:count].clone().to(x.dtype)
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1)
        v = self.v_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1)

        attn = torch.einsum("bhtd,hmd->bhtm", q, k) * self.scale
        attn = F.softmax(attn, dim=-1)
        if self.retention_enabled:
            # Down-weight older entries by beta ** elapsed_steps, then
            # renormalise so the weights sum to one again.
            step = int(self.global_step.item())
            ages = self.mem_age[:count].to(x.device)
            delta = (step - ages).clamp(min=0).to(x.dtype)
            betas = self.mem_beta[:count].to(x.dtype).clamp(min=1e-6, max=1.0)
            weights = betas.pow(delta)
            attn = attn * weights.view(1, 1, 1, count)
            attn = attn / attn.sum(dim=-1, keepdim=True).clamp_min(1e-9)

        out = torch.einsum("bhtm,hmd->bhtd", attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        out = self.o_proj(out)
        return torch.sigmoid(self.gate_scale) * out

    @torch.no_grad()
    def reset(self) -> None:
        """Clear all memory slots and counters, including ``global_step``."""
        self.mem_emb.zero_()
        self.mem_age.zero_()
        self.mem_beta.fill_(1.0)
        self.mem_count.zero_()
        self.mem_head.zero_()
        self.global_step.zero_()
        self._last_beta = None
|
| 767 |
+
|
| 768 |
+
|
| 769 |
def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor:
|
| 770 |
M = torch.exp(logits.clamp(-10, 10))
|
| 771 |
for _ in range(n_iters):
|
|
|
|
| 933 |
return x, new_kv
|
| 934 |
|
| 935 |
|
| 936 |
+
class RecurrentDepthBlock(nn.Module):
    """A single transformer block applied repeatedly with per-token halting.

    Each iteration feeds the running state (tagged via ``loop_index_embedding``)
    plus the encoded input ``e`` through one shared ``TransformerBlock``, adds a
    loop-indexed ``DepthLoRAAdapter`` correction, merges the result back with
    ``StableRecurrentInjection`` and asks ``AdaptiveHalting`` for a halting
    value. The returned output is the halting-weighted sum of the
    per-iteration states.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        ffn_dim: int,
        dropout: float,
        sliding_window: int,
        rope_fraction: float,
        n_loops: int,
        act_threshold: float,
        lora_rank: int,
        loop_embed_dim: int,
    ) -> None:
        super().__init__()
        # Clamp configuration so a zero/negative setting still yields one pass.
        self.n_loops = max(1, n_loops)
        self.act_threshold = act_threshold
        self.loop_embed_dim = max(0, loop_embed_dim)

        # Submodules are created in this order on purpose: it fixes both the
        # state-dict layout and the order in which initial weights are drawn.
        self.norm = RMSNorm(dim)
        shared_block_config = dict(
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            head_dim=head_dim,
            ffn_dim=ffn_dim,
            dropout=dropout,
            sliding_window=sliding_window,
            rope_fraction=rope_fraction,
            engram_dim=0,
            mhc_expansion=1,
        )
        self.block = TransformerBlock(**shared_block_config)
        self.injection = StableRecurrentInjection(dim)
        self.act = AdaptiveHalting(dim)
        self.lora = DepthLoRAAdapter(dim, lora_rank, self.n_loops)

    def forward(
        self,
        h: torch.Tensor,
        e: torch.Tensor,
        token_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        use_cache: bool = False,
        n_loops: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """Run the recurrent loop and blend the per-iteration states.

        Args:
            h: Initial recurrent state; unpacked as ``(B, T, _)``.
            e: Encoded input added to the state on every iteration.
            token_ids: Forwarded unchanged to the shared block.
            past_key_values: Optional per-iteration KV entries; entry ``t`` is
                handed to iteration ``t`` when present.
            use_cache: When true, collect one KV entry per iteration and never
                stop early.
            n_loops: Optional override of the configured iteration count
                (a falsy value falls back to ``self.n_loops``).

        Returns:
            ``(output, aux, new_past)`` where ``aux`` may carry
            ``"recurrent_halt_mean"`` and ``new_past`` is ``None`` unless
            ``use_cache`` is set.
        """
        n_iter = max(1, n_loops or self.n_loops)
        batch, seq_len, _ = h.shape
        done = torch.zeros(batch, seq_len, device=h.device, dtype=torch.bool)
        spent = torch.zeros(batch, seq_len, device=h.device, dtype=h.dtype)
        blended = torch.zeros_like(h)
        kv_out: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
        state = h
        last_halt = None

        for step in range(n_iter):
            tagged = loop_index_embedding(state, step, self.loop_embed_dim)
            block_in = self.norm(tagged + e)

            step_past = None
            if past_key_values is not None and step < len(past_key_values):
                step_past = past_key_values[step]

            block_out, step_kv = self.block(
                block_in,
                is_global=True,
                past_kv=step_past,
                use_cache=use_cache,
                token_ids=token_ids,
            )
            block_out = block_out + self.lora(block_out, step)
            candidate = self.injection(state, e, block_out)

            # Tokens that already halted contribute no further halting mass.
            halt_p = self.act(candidate)
            halt_p = halt_p * (~done).to(halt_p.dtype)
            last_halt = halt_p

            # A token halts once its accumulated mass would reach the threshold;
            # on that step it receives whatever mass is still unspent.
            halting_now = (~done) & ((spent + halt_p) >= self.act_threshold)
            step_weight = torch.where(halting_now, (1.0 - spent).clamp(min=0.0), halt_p)
            blended = blended + candidate * step_weight.unsqueeze(-1)
            spent = spent + step_weight

            # Freeze the state of tokens that were halted before this step.
            state = torch.where(done.unsqueeze(-1), state, candidate)
            done = done | halting_now

            if kv_out is not None:
                kv_out.append(step_kv)
            # Early exit only without caching, so the cache always gets one
            # entry per iteration.
            if not use_cache and bool(done.all()):
                break

        # Any mass left unspent after the last iteration goes to the final state.
        leftover = (1.0 - spent).clamp(min=0.0)
        blended = blended + state * leftover.unsqueeze(-1)

        aux: Dict[str, torch.Tensor] = {}
        if last_halt is not None:
            aux["recurrent_halt_mean"] = last_halt.mean()
        return blended, aux, kv_out
|
| 1013 |
+
|
| 1014 |
+
|
| 1015 |
class TinyMemoryLM(nn.Module):
|
| 1016 |
def __init__(
|
| 1017 |
self,
|
|
|
|
| 1033 |
engram_table_size: int = 8192,
|
| 1034 |
engram_max_ngram: int = 3,
|
| 1035 |
mhc_expansion: int = 1,
|
| 1036 |
+
sleep_gate_cap: int = 0,
|
| 1037 |
+
sleep_gate_heads: int = 4,
|
| 1038 |
+
sleep_retention_enabled: bool = True,
|
| 1039 |
+
sleep_retention_hidden: int = 0,
|
| 1040 |
+
latent_think_layers: int = 0,
|
| 1041 |
+
prelude_layers: int = 0,
|
| 1042 |
+
coda_layers: int = 0,
|
| 1043 |
+
recurrent_loops: int = 0,
|
| 1044 |
+
recurrent_act_threshold: float = 0.99,
|
| 1045 |
+
recurrent_lora_rank: int = 0,
|
| 1046 |
+
recurrent_loop_embed_dim: int = 0,
|
| 1047 |
) -> None:
|
| 1048 |
super().__init__()
|
| 1049 |
self.dim = dim
|
|
|
|
| 1056 |
self.embed_tokens = nn.Embedding(vocab_size, dim)
|
| 1057 |
self.head = nn.Linear(dim, vocab_size, bias=False)
|
| 1058 |
self.head.weight = self.embed_tokens.weight
|
|
|
|
| 1059 |
self.output_bias = nn.Parameter(torch.zeros(vocab_size))
|
| 1060 |
|
| 1061 |
+
self.use_recurrent_depth = recurrent_loops > 0
|
| 1062 |
+
self.prelude_layers = max(0, prelude_layers)
|
| 1063 |
+
self.coda_layers = max(0, coda_layers)
|
| 1064 |
+
self.recurrent_loops = max(0, recurrent_loops)
|
| 1065 |
+
|
| 1066 |
+
self.blocks: Optional[nn.ModuleList] = None
|
| 1067 |
+
self.prelude: Optional[nn.ModuleList] = None
|
| 1068 |
+
self.recurrent: Optional[RecurrentDepthBlock] = None
|
| 1069 |
+
self.coda: Optional[nn.ModuleList] = None
|
| 1070 |
+
|
| 1071 |
+
def _make_blocks(n: int) -> nn.ModuleList:
|
| 1072 |
+
return nn.ModuleList([
|
| 1073 |
TransformerBlock(
|
| 1074 |
+
dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim,
|
| 1075 |
+
ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window,
|
| 1076 |
+
rope_fraction=rope_fraction, engram_dim=engram_dim,
|
| 1077 |
+
engram_heads=engram_heads, engram_table_size=engram_table_size,
|
| 1078 |
+
engram_max_ngram=engram_max_ngram, mhc_expansion=mhc_expansion,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1079 |
)
|
| 1080 |
+
for _ in range(n)
|
| 1081 |
+
])
|
| 1082 |
+
|
| 1083 |
+
if self.use_recurrent_depth:
|
| 1084 |
+
if self.prelude_layers > 0:
|
| 1085 |
+
self.prelude = _make_blocks(self.prelude_layers)
|
| 1086 |
+
self.recurrent = RecurrentDepthBlock(
|
| 1087 |
+
dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim,
|
| 1088 |
+
ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window,
|
| 1089 |
+
rope_fraction=rope_fraction, n_loops=self.recurrent_loops,
|
| 1090 |
+
act_threshold=recurrent_act_threshold, lora_rank=recurrent_lora_rank,
|
| 1091 |
+
loop_embed_dim=recurrent_loop_embed_dim or max(2, dim // 8),
|
| 1092 |
+
)
|
| 1093 |
+
if self.coda_layers > 0:
|
| 1094 |
+
self.coda = _make_blocks(self.coda_layers)
|
| 1095 |
+
else:
|
| 1096 |
+
self.blocks = _make_blocks(max(1, n_unique_layers))
|
| 1097 |
+
|
| 1098 |
self.norm = RMSNorm(dim)
|
| 1099 |
|
| 1100 |
self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1})
|
|
|
|
| 1105 |
{str(h): RMSNorm(dim) for h in self.mtp_horizons}
|
| 1106 |
)
|
| 1107 |
|
| 1108 |
+
res_scale = (2 * max(1, n_logical_layers)) ** -0.5
|
| 1109 |
+
for group in (self.blocks, self.prelude, self.coda):
|
| 1110 |
+
if group is None:
|
| 1111 |
+
continue
|
| 1112 |
+
for block in group:
|
| 1113 |
+
block.attn.wo.weight.data.mul_(res_scale)
|
| 1114 |
+
block.ffn.down.weight.data.mul_(res_scale)
|
| 1115 |
+
if self.recurrent is not None:
|
| 1116 |
+
self.recurrent.block.attn.wo.weight.data.mul_(res_scale)
|
| 1117 |
+
self.recurrent.block.ffn.down.weight.data.mul_(res_scale)
|
| 1118 |
+
|
| 1119 |
+
self.sleep_gate: Optional[SleepGate] = None
|
| 1120 |
+
if sleep_gate_cap > 0:
|
| 1121 |
+
self.sleep_gate = SleepGate(
|
| 1122 |
+
dim=dim, cap=sleep_gate_cap, n_heads=sleep_gate_heads,
|
| 1123 |
+
retention_enabled=sleep_retention_enabled,
|
| 1124 |
+
retention_hidden=sleep_retention_hidden,
|
| 1125 |
+
)
|
| 1126 |
+
|
| 1127 |
+
self.think_blocks: Optional[nn.ModuleList] = None
|
| 1128 |
+
self.think_norm: Optional[RMSNorm] = None
|
| 1129 |
+
if latent_think_layers > 0:
|
| 1130 |
+
self.think_blocks = nn.ModuleList([
|
| 1131 |
+
TransformerBlock(
|
| 1132 |
+
dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim,
|
| 1133 |
+
ffn_dim=ffn_dim, dropout=0.0, sliding_window=2048,
|
| 1134 |
+
rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1,
|
| 1135 |
+
)
|
| 1136 |
+
for _ in range(latent_think_layers)
|
| 1137 |
+
])
|
| 1138 |
+
self.think_norm = RMSNorm(dim)
|
| 1139 |
|
| 1140 |
def resize_token_embeddings(self, new_vocab_size: int) -> None:
|
| 1141 |
old_vocab_size = self.embed_tokens.num_embeddings
|
|
|
|
| 1143 |
return
|
| 1144 |
device = self.embed_tokens.weight.device
|
| 1145 |
old_embed_weight = self.embed_tokens.weight.data.clone()
|
| 1146 |
+
self.embed_tokens = nn.Embedding(new_vocab_size, self.embed_tokens.embedding_dim).to(device)
|
| 1147 |
+
self.head = nn.Linear(self.embed_tokens.embedding_dim, new_vocab_size, bias=False).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1148 |
self.head.weight = self.embed_tokens.weight
|
| 1149 |
old_bias = self.output_bias.data.clone()
|
| 1150 |
self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device))
|
|
|
|
| 1153 |
self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size]
|
| 1154 |
|
| 1155 |
def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]:
|
| 1156 |
+
if self.blocks is None:
|
| 1157 |
+
return []
|
| 1158 |
blocks_list = list(self.blocks)
|
| 1159 |
full_sequence = blocks_list + blocks_list
|
| 1160 |
+
return [(block, i) for i, block in enumerate(full_sequence[: self.n_logical_layers])]
|
|
|
|
|
|
|
| 1161 |
|
| 1162 |
def forward(
|
| 1163 |
self,
|
| 1164 |
ids: torch.Tensor,
|
| 1165 |
use_cache: bool = False,
|
| 1166 |
+
past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
|
|
|
|
|
|
|
| 1167 |
return_hidden: bool = False,
|
| 1168 |
+
) -> Tuple[torch.Tensor, Dict[int, torch.Tensor], Dict[str, torch.Tensor], Optional[torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1169 |
B, T = ids.shape
|
| 1170 |
x = self.embed_tokens(ids) * self.embed_scale_factor
|
| 1171 |
+
new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
|
| 1172 |
+
aux: Dict[str, torch.Tensor] = {}
|
| 1173 |
+
|
| 1174 |
+
if self.use_recurrent_depth:
|
| 1175 |
+
offset = 0
|
| 1176 |
+
if self.prelude is not None:
|
| 1177 |
+
for block in self.prelude:
|
| 1178 |
+
past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None
|
| 1179 |
+
x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids)
|
| 1180 |
+
if new_past_key_values is not None:
|
| 1181 |
+
new_past_key_values.append(layer_kv)
|
| 1182 |
+
offset += 1
|
| 1183 |
+
encoded = x
|
| 1184 |
+
recurrent_past = past_key_values[offset: offset + self.recurrent_loops] if past_key_values is not None else None
|
| 1185 |
+
x, recurrent_aux, recurrent_kv = self.recurrent(
|
| 1186 |
+
x, encoded, token_ids=ids, past_key_values=recurrent_past, use_cache=use_cache,
|
| 1187 |
)
|
| 1188 |
+
aux.update(recurrent_aux)
|
| 1189 |
+
if new_past_key_values is not None and recurrent_kv is not None:
|
| 1190 |
+
new_past_key_values.extend(recurrent_kv)
|
| 1191 |
+
offset += self.recurrent_loops
|
| 1192 |
+
if self.coda is not None:
|
| 1193 |
+
for block in self.coda:
|
| 1194 |
+
past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None
|
| 1195 |
+
x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids)
|
| 1196 |
+
if new_past_key_values is not None:
|
| 1197 |
+
new_past_key_values.append(layer_kv)
|
| 1198 |
+
offset += 1
|
| 1199 |
+
else:
|
| 1200 |
+
logical_layers = self._build_logical_layers()
|
| 1201 |
+
last_logical_idx = len(logical_layers) - 1
|
| 1202 |
+
for layer_idx, (block, logical_idx) in enumerate(logical_layers):
|
| 1203 |
+
is_global = logical_idx % 2 == 0 or layer_idx == last_logical_idx
|
| 1204 |
+
past_kv = past_key_values[layer_idx] if past_key_values is not None and layer_idx < len(past_key_values) else None
|
| 1205 |
+
if self.grad_checkpoint and self.training and not use_cache:
|
| 1206 |
+
x, layer_kv = checkpoint(block, x, is_global, past_kv, use_cache, ids, use_reentrant=True)
|
| 1207 |
+
else:
|
| 1208 |
+
x, layer_kv = block(x, is_global, past_kv, use_cache, ids)
|
| 1209 |
+
if new_past_key_values is not None:
|
| 1210 |
+
new_past_key_values.append(layer_kv)
|
| 1211 |
|
| 1212 |
+
x = self.norm(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1213 |
|
| 1214 |
+
if self.sleep_gate is not None:
|
| 1215 |
+
x = x + self.sleep_gate.read(x)
|
| 1216 |
+
if self.training:
|
| 1217 |
+
self.sleep_gate.write(x)
|
| 1218 |
+
|
| 1219 |
+
if self.think_blocks is not None:
|
| 1220 |
+
for think_block in self.think_blocks:
|
| 1221 |
+
x, _ = think_block(x, is_global=True)
|
| 1222 |
+
x = self.think_norm(x)
|
| 1223 |
|
|
|
|
| 1224 |
h_out = x if return_hidden else None
|
| 1225 |
logits = self.head(x)
|
| 1226 |
if self.embed_scale_factor != 1.0:
|
|
|
|
| 1240 |
mtp_logits = mtp_logits + self.output_bias
|
| 1241 |
mtp[horizon] = mtp_logits
|
| 1242 |
|
| 1243 |
+
return logits, mtp, aux, h_out, new_past_key_values
|
| 1244 |
|
| 1245 |
|
| 1246 |
# ---------------------------------------------------------------------------
|
|
|
|
| 1352 |
ctx_ids = (
|
| 1353 |
input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t
|
| 1354 |
)
|
| 1355 |
+
logits, *_ = model(ctx_ids)
|
| 1356 |
next_logits = logits[0, -1, :].clone()
|
| 1357 |
|
| 1358 |
# Logit soft-capping (Gemma-style) — prevents overconfident collapse
|
|
|
|
| 1463 |
if not tokenizer_path.exists():
|
| 1464 |
continue
|
| 1465 |
name = child.name
|
| 1466 |
+
series = None
|
| 1467 |
for ckpt_name in ("model.pt", "pretrain.pt"):
|
| 1468 |
+
ckpt_path = child / ckpt_name
|
| 1469 |
+
if ckpt_path.exists():
|
| 1470 |
+
series = _fast_series_from_checkpoint(ckpt_path)
|
| 1471 |
+
break
|
| 1472 |
+
if series is None:
|
| 1473 |
+
series = series_from_name(name) or "Sonnet"
|
| 1474 |
+
found = False
|
| 1475 |
+
for ckpt_name in ("model.pt", "model_rep.pt", "pretrain.pt"):
|
| 1476 |
ckpt_path = child / ckpt_name
|
| 1477 |
if ckpt_path.exists():
|
| 1478 |
models.append(
|
|
|
|
| 1484 |
"tokenizer_path": tokenizer_path,
|
| 1485 |
}
|
| 1486 |
)
|
| 1487 |
+
found = True
|
| 1488 |
+
if not found:
|
| 1489 |
+
step_ckpts = sorted(
|
| 1490 |
+
child.glob("checkpoint_step_*.pt"),
|
| 1491 |
+
key=lambda p: int(p.stem.rsplit("_", 1)[-1]),
|
| 1492 |
+
)
|
| 1493 |
+
if step_ckpts:
|
| 1494 |
+
ckpt_path = step_ckpts[-1]
|
| 1495 |
+
models.append(
|
| 1496 |
+
{
|
| 1497 |
+
"name": name,
|
| 1498 |
+
"checkpoint": ckpt_path.name,
|
| 1499 |
+
"series": series,
|
| 1500 |
+
"model_path": ckpt_path,
|
| 1501 |
+
"tokenizer_path": tokenizer_path,
|
| 1502 |
+
}
|
| 1503 |
+
)
|
| 1504 |
return models
|
| 1505 |
|
| 1506 |
|
|
|
|
| 1519 |
return 1
|
| 1520 |
|
| 1521 |
|
| 1522 |
+
def _detect_sleep_gate(state_dict) -> Tuple[int, int]:
|
| 1523 |
+
for key, val in state_dict.items():
|
| 1524 |
+
if key == "sleep_gate.mem_emb" and val.dim() == 2:
|
| 1525 |
+
cap = val.shape[0]
|
| 1526 |
+
return cap, 4
|
| 1527 |
+
return 0, 4
|
| 1528 |
+
|
| 1529 |
+
|
| 1530 |
+
def _detect_latent_think(state_dict) -> int:
|
| 1531 |
+
indices = {
|
| 1532 |
+
int(k.split(".")[1])
|
| 1533 |
+
for k in state_dict
|
| 1534 |
+
if k.startswith("think_blocks.") and k.split(".")[1].isdigit()
|
| 1535 |
+
}
|
| 1536 |
+
return max(indices) + 1 if indices else 0
|
| 1537 |
+
|
| 1538 |
+
|
| 1539 |
+
def _detect_prelude_layers(state_dict) -> int:
|
| 1540 |
+
indices = {
|
| 1541 |
+
int(k.split(".")[1])
|
| 1542 |
+
for k in state_dict
|
| 1543 |
+
if k.startswith("prelude.") and k.split(".")[1].isdigit()
|
| 1544 |
+
}
|
| 1545 |
+
return max(indices) + 1 if indices else 0
|
| 1546 |
+
|
| 1547 |
+
|
| 1548 |
+
def _detect_coda_layers(state_dict) -> int:
|
| 1549 |
+
indices = {
|
| 1550 |
+
int(k.split(".")[1])
|
| 1551 |
+
for k in state_dict
|
| 1552 |
+
if k.startswith("coda.") and k.split(".")[1].isdigit()
|
| 1553 |
+
}
|
| 1554 |
+
return max(indices) + 1 if indices else 0
|
| 1555 |
+
|
| 1556 |
+
|
| 1557 |
+
def _detect_recurrent_loops(state_dict) -> int:
|
| 1558 |
+
if "recurrent.norm.weight" in state_dict or "recurrent.block.attn.wq.weight" in state_dict:
|
| 1559 |
+
if "recurrent.lora.scale.weight" in state_dict:
|
| 1560 |
+
return state_dict["recurrent.lora.scale.weight"].shape[0]
|
| 1561 |
+
return 1
|
| 1562 |
+
return 0
|
| 1563 |
+
|
| 1564 |
+
|
| 1565 |
+
def _detect_recurrent_lora_rank(state_dict) -> int:
|
| 1566 |
+
for key in ("recurrent.lora.B", "recurrent.lora.down.weight"):
|
| 1567 |
+
if key in state_dict:
|
| 1568 |
+
shape = state_dict[key].shape
|
| 1569 |
+
if len(shape) == 2:
|
| 1570 |
+
return int(shape[0])
|
| 1571 |
+
return 0
|
| 1572 |
+
|
| 1573 |
+
|
| 1574 |
+
def _infer_series_from_lora_rank(rank: int) -> str | None:
|
| 1575 |
+
if rank == 0:
|
| 1576 |
+
return None
|
| 1577 |
+
if rank <= 8:
|
| 1578 |
+
return "haiku"
|
| 1579 |
+
if rank <= 16:
|
| 1580 |
+
return "sonnet"
|
| 1581 |
+
return "opus"
|
| 1582 |
+
|
| 1583 |
+
|
| 1584 |
+
def _fast_series_from_checkpoint(ckpt_path: Path) -> str | None:
|
| 1585 |
+
try:
|
| 1586 |
+
cp = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
| 1587 |
+
sd = cp.get("model_state", cp.get("state_dict", {}))
|
| 1588 |
+
rank = 0
|
| 1589 |
+
for key in ("recurrent.lora.B", "recurrent.lora.down.weight"):
|
| 1590 |
+
if key in sd:
|
| 1591 |
+
rank = int(sd[key].shape[0])
|
| 1592 |
+
break
|
| 1593 |
+
if rank == 0:
|
| 1594 |
+
return None
|
| 1595 |
+
if rank <= 8:
|
| 1596 |
+
return "Haiku"
|
| 1597 |
+
if rank <= 16:
|
| 1598 |
+
return "Sonnet"
|
| 1599 |
+
return "Opus"
|
| 1600 |
+
except Exception:
|
| 1601 |
+
pass
|
| 1602 |
+
return None
|
| 1603 |
+
|
| 1604 |
+
|
| 1605 |
def _infer_arch_from_state_dict(state_dict, cfg):
|
| 1606 |
"""Infer architecture hyper-parameters directly from checkpoint weights,
|
| 1607 |
falling back to *cfg* (series config) when a key is not found."""
|
| 1608 |
overrides = {}
|
| 1609 |
|
| 1610 |
+
has_prelude = any(k.startswith("prelude.") for k in state_dict)
|
| 1611 |
+
has_blocks = any(k.startswith("blocks.") for k in state_dict)
|
| 1612 |
+
has_recurrent = any(k.startswith("recurrent.") for k in state_dict)
|
| 1613 |
+
uses_recurrent_arch = has_prelude and has_recurrent and not has_blocks
|
| 1614 |
+
|
| 1615 |
# dim from embed_tokens.weight [vocab, dim]
|
| 1616 |
if "embed_tokens.weight" in state_dict:
|
| 1617 |
overrides["dim"] = state_dict["embed_tokens.weight"].shape[1]
|
| 1618 |
|
| 1619 |
+
if uses_recurrent_arch:
|
| 1620 |
+
if "prelude.0.ffn.gate.weight" in state_dict:
|
| 1621 |
+
overrides["ffn_dim"] = state_dict["prelude.0.ffn.gate.weight"].shape[0]
|
| 1622 |
+
overrides["n_unique_layers"] = 0
|
| 1623 |
+
src = "prelude.0"
|
| 1624 |
+
else:
|
| 1625 |
+
if "blocks.0.ffn.gate.weight" in state_dict:
|
| 1626 |
+
overrides["ffn_dim"] = state_dict["blocks.0.ffn.gate.weight"].shape[0]
|
| 1627 |
+
block_ids = {
|
| 1628 |
+
int(k.split(".")[1])
|
| 1629 |
+
for k in state_dict
|
| 1630 |
+
if k.startswith("blocks.") and k.split(".")[1].isdigit()
|
| 1631 |
+
}
|
| 1632 |
+
if block_ids:
|
| 1633 |
+
overrides["n_unique_layers"] = max(block_ids) + 1
|
| 1634 |
+
src = "blocks.0"
|
| 1635 |
|
|
|
|
| 1636 |
dim = overrides.get("dim", int(cfg.get("dim", model_config.dim)))
|
| 1637 |
+
if f"{src}.attn.wq.weight" in state_dict:
|
| 1638 |
+
wq_rows = state_dict[f"{src}.attn.wq.weight"].shape[0]
|
| 1639 |
+
if f"{src}.attn.q_norm.weight" in state_dict:
|
| 1640 |
+
head_dim = state_dict[f"{src}.attn.q_norm.weight"].shape[0]
|
| 1641 |
overrides["n_heads"] = wq_rows // head_dim
|
| 1642 |
+
if f"{src}.attn.wk.weight" in state_dict:
|
| 1643 |
+
wk_rows = state_dict[f"{src}.attn.wk.weight"].shape[0]
|
| 1644 |
+
if f"{src}.attn.k_norm.weight" in state_dict:
|
| 1645 |
+
head_dim = state_dict[f"{src}.attn.k_norm.weight"].shape[0]
|
| 1646 |
overrides["n_kv_heads"] = wk_rows // head_dim
|
| 1647 |
|
| 1648 |
+
# engram params
|
| 1649 |
for key, val in state_dict.items():
|
| 1650 |
if ".engram.embeddings." in key and key.endswith("_0") and val.dim() == 2:
|
| 1651 |
overrides["engram_table_size"] = val.shape[0]
|
| 1652 |
overrides["engram_dim"] = val.shape[1]
|
| 1653 |
break
|
|
|
|
|
|
|
| 1654 |
engram_dim = overrides.get("engram_dim", int(cfg.get("engram_dim", 0)))
|
| 1655 |
engram_max_ngram = int(cfg.get("engram_max_ngram", 2))
|
| 1656 |
if engram_dim > 0:
|
|
|
|
| 1662 |
overrides["engram_heads"] = total_branch_dim // denom
|
| 1663 |
break
|
| 1664 |
|
|
|
|
| 1665 |
merged = dict(cfg)
|
| 1666 |
merged.update(overrides)
|
| 1667 |
return merged
|
|
|
|
| 1675 |
|
| 1676 |
state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt
|
| 1677 |
|
|
|
|
|
|
|
| 1678 |
cfg = _infer_arch_from_state_dict(state_dict, cfg)
|
| 1679 |
|
| 1680 |
engram_dim = int(cfg.get("engram_dim", 0))
|
|
|
|
| 1685 |
if mhc_expansion == 1:
|
| 1686 |
mhc_expansion = int(cfg.get("mhc_expansion", 1))
|
| 1687 |
|
| 1688 |
+
ckpt_sleep_cap, ckpt_sleep_heads = _detect_sleep_gate(state_dict)
|
| 1689 |
+
sleep_gate_cap = ckpt_sleep_cap if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_cap", 0))
|
| 1690 |
+
sleep_gate_heads = ckpt_sleep_heads if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_heads", 4))
|
| 1691 |
+
sleep_retention_enabled = bool(cfg.get("sleep_retention_enabled", True))
|
| 1692 |
+
sleep_retention_hidden = int(cfg.get("sleep_retention_hidden", 0))
|
| 1693 |
+
|
| 1694 |
+
latent_think_layers = _detect_latent_think(state_dict)
|
| 1695 |
+
if latent_think_layers == 0:
|
| 1696 |
+
latent_think_layers = int(cfg.get("latent_think_layers", 0))
|
| 1697 |
+
|
| 1698 |
+
prelude_layers = _detect_prelude_layers(state_dict)
|
| 1699 |
+
coda_layers = _detect_coda_layers(state_dict)
|
| 1700 |
+
recurrent_loops = _detect_recurrent_loops(state_dict)
|
| 1701 |
+
|
| 1702 |
+
ckpt_lora_rank = _detect_recurrent_lora_rank(state_dict)
|
| 1703 |
+
if ckpt_lora_rank > 0:
|
| 1704 |
+
inferred_series = _infer_series_from_lora_rank(ckpt_lora_rank)
|
| 1705 |
+
if inferred_series and inferred_series != series.lower():
|
| 1706 |
+
series = inferred_series.capitalize()
|
| 1707 |
+
cfg = series_config(series)
|
| 1708 |
+
recurrent_lora_rank = ckpt_lora_rank
|
| 1709 |
+
else:
|
| 1710 |
+
recurrent_lora_rank = int(cfg.get("recurrent_lora_rank", 0))
|
| 1711 |
+
|
| 1712 |
+
recurrent_act_threshold = float(cfg.get("recurrent_act_threshold", 0.99))
|
| 1713 |
+
recurrent_loop_embed_dim = int(cfg.get("recurrent_loop_embed_dim", 0))
|
| 1714 |
+
|
| 1715 |
+
n_unique = int(cfg.get("n_unique_layers", model_config.n_unique_layers))
|
| 1716 |
+
|
| 1717 |
model = TinyMemoryLM(
|
| 1718 |
vocab_size=vocab_size,
|
| 1719 |
dim=int(cfg.get("dim", model_config.dim)),
|
| 1720 |
+
n_unique_layers=n_unique,
|
| 1721 |
+
n_logical_layers=int(cfg.get("n_logical_layers", model_config.n_logical_layers)),
|
|
|
|
|
|
|
| 1722 |
n_heads=int(cfg.get("n_heads", model_config.n_heads)),
|
| 1723 |
n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)),
|
| 1724 |
ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)),
|
| 1725 |
dropout=float(cfg.get("dropout", model_config.dropout)),
|
| 1726 |
+
mtp_horizons=tuple(int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)),
|
|
|
|
|
|
|
| 1727 |
grad_checkpoint=False,
|
| 1728 |
+
sliding_window=int(cfg.get("sliding_window_size", getattr(model_config, "sliding_window_size", 512))),
|
| 1729 |
+
rope_fraction=float(cfg.get("rope_fraction", getattr(model_config, "rope_fraction", 0.25))),
|
| 1730 |
+
embed_scale=bool(cfg.get("embed_scale", getattr(model_config, "embed_scale", True))),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1731 |
engram_dim=engram_dim,
|
| 1732 |
engram_heads=int(cfg.get("engram_heads", 4)),
|
| 1733 |
engram_table_size=int(cfg.get("engram_table_size", 8192)),
|
| 1734 |
engram_max_ngram=int(cfg.get("engram_max_ngram", 3)),
|
| 1735 |
mhc_expansion=mhc_expansion,
|
| 1736 |
+
sleep_gate_cap=sleep_gate_cap,
|
| 1737 |
+
sleep_gate_heads=sleep_gate_heads,
|
| 1738 |
+
sleep_retention_enabled=sleep_retention_enabled,
|
| 1739 |
+
sleep_retention_hidden=sleep_retention_hidden,
|
| 1740 |
+
latent_think_layers=latent_think_layers,
|
| 1741 |
+
prelude_layers=prelude_layers,
|
| 1742 |
+
coda_layers=coda_layers,
|
| 1743 |
+
recurrent_loops=recurrent_loops,
|
| 1744 |
+
recurrent_act_threshold=recurrent_act_threshold,
|
| 1745 |
+
recurrent_lora_rank=recurrent_lora_rank,
|
| 1746 |
+
recurrent_loop_embed_dim=recurrent_loop_embed_dim,
|
| 1747 |
)
|
| 1748 |
model.load_state_dict(state_dict, strict=False)
|
| 1749 |
model.eval()
|
|
|
|
| 1756 |
"tokenizer": tokenizer,
|
| 1757 |
"device": device,
|
| 1758 |
"series": series,
|
| 1759 |
+
"sft_mode": ckpt.get("sft_mode", None),
|
| 1760 |
+
"phase": ckpt.get("phase", None),
|
| 1761 |
}
|
| 1762 |
|
| 1763 |
|
|
|
|
| 1781 |
|
| 1782 |
print(f"Using cached {hf_id} from {local_dir}")
|
| 1783 |
|
| 1784 |
+
# Check common subdirectory names: "models/", "model/"
|
| 1785 |
+
if (local_dir / "models").exists():
|
| 1786 |
+
model_dir = local_dir / "models"
|
| 1787 |
+
elif (local_dir / "model").exists():
|
| 1788 |
+
model_dir = local_dir / "model"
|
| 1789 |
+
else:
|
| 1790 |
+
model_dir = local_dir
|
| 1791 |
model_path = model_dir / "model.pt"
|
| 1792 |
pretrain_path = model_dir / "pretrain.pt"
|
| 1793 |
tokenizer_path = model_dir / "tokenizer.json"
|
|
|
|
| 1941 |
print(f"\n{'='*60}")
|
| 1942 |
|
| 1943 |
|
| 1944 |
+
# ---------------------------------------------------------------------------
|
| 1945 |
+
# Benchmark
|
| 1946 |
+
# ---------------------------------------------------------------------------
|
| 1947 |
+
|
| 1948 |
+
# Registry of available benchmarks. Each entry carries a display label, a
# one-line description, the Hugging Face dataset as (repo id, config name or
# None), and the name of the metric it reports.
BENCHMARKS = {
    "blimp": dict(
        label="BLiMP",
        desc="Grammaticality minimal pairs (67 paradigms). Accuracy = % grammatical < ungrammatical perplexity.",
        hf_dataset=("nyu-mll/blimp", None),
        metric="accuracy",
    ),
    "wikitext2": dict(
        label="WikiText-2",
        desc="LM perplexity on Wikipedia test split. Lower is better.",
        hf_dataset=("Salesforce/wikitext", "wikitext-2-raw-v1"),
        metric="perplexity",
    ),
    "arc_easy": dict(
        label="ARC-Easy",
        desc="Multiple-choice science QA (~2.4K). Perplexity-based answer selection.",
        hf_dataset=("allenai/ai2_arc", "ARC-Easy"),
        metric="accuracy",
    ),
}
|
| 1968 |
+
|
| 1969 |
+
|
| 1970 |
+
def _score_text(model: TinyMemoryLM, tokenizer: WordTokenizer, text: str, device: str) -> float:
|
| 1971 |
+
ids = tokenizer.encode(text, add_bos=True, add_eos=False)
|
| 1972 |
+
if len(ids) < 2:
|
| 1973 |
+
return float("nan")
|
| 1974 |
+
ids_t = torch.tensor([ids], dtype=torch.long, device=device)
|
| 1975 |
+
with torch.no_grad():
|
| 1976 |
+
logits, *_ = model(ids_t)
|
| 1977 |
+
log_probs = F.log_softmax(logits[0], dim=-1)
|
| 1978 |
+
targets = ids_t[0, 1:]
|
| 1979 |
+
nll = -log_probs[range(len(targets)), targets].mean().item()
|
| 1980 |
+
return nll
|
| 1981 |
+
|
| 1982 |
+
|
| 1983 |
+
def _score_completion(model: TinyMemoryLM, tokenizer: WordTokenizer, context: str, completion: str, device: str) -> float:
|
| 1984 |
+
full_ids = tokenizer.encode(context + completion, add_bos=True, add_eos=False)
|
| 1985 |
+
ctx_ids = tokenizer.encode(context, add_bos=True, add_eos=False)
|
| 1986 |
+
n_ctx = len(ctx_ids)
|
| 1987 |
+
n_ref = len(full_ids) - n_ctx
|
| 1988 |
+
if n_ref <= 0:
|
| 1989 |
+
return float("nan")
|
| 1990 |
+
ids_t = torch.tensor([full_ids], dtype=torch.long, device=device)
|
| 1991 |
+
with torch.no_grad():
|
| 1992 |
+
logits, *_ = model(ids_t)
|
| 1993 |
+
log_probs = F.log_softmax(logits[0], dim=-1)
|
| 1994 |
+
targets = ids_t[0, 1:]
|
| 1995 |
+
ref_start = n_ctx - 1
|
| 1996 |
+
ref_end = min(ref_start + n_ref, log_probs.shape[0])
|
| 1997 |
+
if ref_start >= ref_end:
|
| 1998 |
+
return float("nan")
|
| 1999 |
+
nll = -log_probs[ref_start:ref_end][range(ref_end - ref_start), targets[ref_start:ref_end]].mean().item()
|
| 2000 |
+
return nll
|
| 2001 |
+
|
| 2002 |
+
|
| 2003 |
+
# The 67 BLiMP paradigm names; each one is passed as the dataset config name
# when loading "nyu-mll/blimp".
BLIMP_PARADIGMS = [
    "adjunct_island",
    "anaphor_gender_agreement",
    "anaphor_number_agreement",
    "animate_subject_passive",
    "animate_subject_trans",
    "causative",
    "complex_NP_island",
    "coordinate_structure_constraint_complex_left_branch",
    "coordinate_structure_constraint_object_extraction",
    "determiner_noun_agreement_1",
    "determiner_noun_agreement_2",
    "determiner_noun_agreement_irregular_1",
    "determiner_noun_agreement_irregular_2",
    "determiner_noun_agreement_with_adj_2",
    "determiner_noun_agreement_with_adj_irregular_1",
    "determiner_noun_agreement_with_adj_irregular_2",
    "determiner_noun_agreement_with_adjective_1",
    "distractor_agreement_relational_noun",
    "distractor_agreement_relative_clause",
    "drop_argument",
    "ellipsis_n_bar_1",
    "ellipsis_n_bar_2",
    "existential_there_object_raising",
    "existential_there_quantifiers_1",
    "existential_there_quantifiers_2",
    "existential_there_subject_raising",
    "expletive_it_object_raising",
    "inchoative",
    "intransitive",
    "irregular_past_participle_adjectives",
    "irregular_past_participle_verbs",
    "irregular_plural_subject_verb_agreement_1",
    "irregular_plural_subject_verb_agreement_2",
    "left_branch_island_echo_question",
    "left_branch_island_simple_question",
    "matrix_question_npi_licensor_present",
    "npi_present_1",
    "npi_present_2",
    "only_npi_licensor_present",
    "only_npi_scope",
    "passive_1",
    "passive_2",
    "principle_A_c_command",
    "principle_A_case_1",
    "principle_A_case_2",
    "principle_A_domain_1",
    "principle_A_domain_2",
    "principle_A_domain_3",
    "principle_A_reconstruction",
    "regular_plural_subject_verb_agreement_1",
    "regular_plural_subject_verb_agreement_2",
    "sentential_negation_npi_licensor_present",
    "sentential_negation_npi_scope",
    "sentential_subject_island",
    "superlative_quantifiers_1",
    "superlative_quantifiers_2",
    "tough_vs_raising_1",
    "tough_vs_raising_2",
    "transitive",
    "wh_island",
    "wh_questions_object_gap",
    "wh_questions_subject_gap",
    "wh_questions_subject_gap_long_distance",
    "wh_vs_that_no_gap",
    "wh_vs_that_no_gap_long_distance",
    "wh_vs_that_with_gap",
    "wh_vs_that_with_gap_long_distance",
]
|
| 2034 |
+
|
| 2035 |
+
|
| 2036 |
+
def _run_blimp(model, tokenizer, device, n_samples: int = 200) -> Tuple[List[str], List[float]]:
    """Evaluate the model on every BLiMP paradigm and return per-paradigm accuracy.

    For each minimal pair the model is counted correct when the grammatical
    sentence gets a strictly lower mean NLL than the ungrammatical one.

    Args:
        model: Language model handed to ``_score_text``.
        tokenizer: Tokenizer handed to ``_score_text``.
        device: Device handed to ``_score_text``.
        n_samples: Maximum number of pairs evaluated per paradigm.

    Returns:
        ``(BLIMP_PARADIGMS, accuracies)`` with one accuracy per paradigm, in
        order. An entry is NaN when the paradigm could not be loaded or no
        pair could be scored.
    """
    from datasets import load_dataset  # type: ignore

    accuracies: List[float] = []
    for paradigm in BLIMP_PARADIGMS:
        try:
            ds = load_dataset("nyu-mll/blimp", paradigm, split="train")
        except Exception as e:
            # One unavailable paradigm should not abort the whole benchmark.
            print(f" {paradigm}: skip ({e})")
            accuracies.append(float("nan"))
            continue

        items = list(ds)[:n_samples]
        correct = 0
        scored = 0
        for ex in items:
            good_nll = _score_text(model, tokenizer, ex["sentence_good"], device)
            bad_nll = _score_text(model, tokenizer, ex["sentence_bad"], device)
            if math.isnan(good_nll) or math.isnan(bad_nll):
                continue
            scored += 1
            if good_nll < bad_nll:
                correct += 1

        # Divide by the pairs actually scored: unscorable pairs are skipped
        # above and must not be counted as wrong answers.
        acc = correct / scored if scored else float("nan")
        accuracies.append(acc)
        print(f" {paradigm:50s} acc={acc:.3f}")
    return BLIMP_PARADIGMS, accuracies
|
| 2059 |
+
|
| 2060 |
+
|
| 2061 |
+
def _run_wikitext2(model, tokenizer, device, chunk_chars: int = 512, max_chunks: int = 100) -> Tuple[List[str], List[float]]:
    """Compute per-chunk perplexity on the WikiText-2 (raw) test split.

    The non-blank lines of the split are joined with newlines and cut into
    consecutive windows of ``chunk_chars`` characters. Windows of 20
    characters or fewer are dropped and at most ``max_chunks`` are kept.
    Each kept window is scored with ``_score_text`` and the score is
    exponentiated; a NaN score stays NaN.

    Returns:
        ``(labels, perplexities)`` where the labels are ``"chunk 1"``,
        ``"chunk 2"``, ... aligned with the perplexity list.
    """
    from datasets import load_dataset  # type: ignore

    dataset = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test")
    corpus = "\n".join(row["text"] for row in dataset if row["text"].strip())

    pieces = []
    for start in range(0, len(corpus), chunk_chars):
        piece = corpus[start:start + chunk_chars]
        if len(piece) > 20:
            pieces.append(piece)
    pieces = pieces[:max_chunks]
    total = len(pieces)

    labels: List[str] = []
    ppls: List[float] = []
    for number, piece in enumerate(pieces, start=1):
        nll = _score_text(model, tokenizer, piece, device)
        if math.isnan(nll):
            ppl = float("nan")
        else:
            ppl = math.exp(nll)
        labels.append(f"chunk {number}")
        ppls.append(ppl)

        # Progress line every tenth chunk, averaging only the finite values.
        if number % 10 != 0:
            continue
        finite = [p for p in ppls if not math.isnan(p)]
        running = sum(finite) / len(finite) if finite else float("nan")
        print(f" chunk {number}/{total} running mean ppl={running:.2f}")

    return labels, ppls
|
| 2079 |
+
|
| 2080 |
+
|
| 2081 |
+
def _run_arc_easy(model, tokenizer, device, max_samples: int = 200) -> Tuple[List[str], List[float]]:
    """Grade the model on ARC-Easy multiple-choice questions.

    Each answer option is scored with ``_score_completion`` as a continuation
    of ``"Question: ...\\nAnswer:"``. The option with the lowest non-NaN score
    is the prediction. A question whose options all score NaN is recorded as
    NaN and excluded from the printed accuracy.

    Returns:
        ``(labels, scores)`` where the labels are ``"Q1"``, ``"Q2"``, ... and
        each score is 1.0 (correct), 0.0 (wrong) or NaN (not gradable).
    """
    from datasets import load_dataset  # type: ignore

    dataset = load_dataset("allenai/ai2_arc", "ARC-Easy", split="test")
    questions = list(dataset)[:max_samples]

    labels: List[str] = []
    scores: List[float] = []
    for number, item in enumerate(questions, start=1):
        option_texts = item["choices"]["text"]
        option_keys = item["choices"]["label"]
        gold = item["answerKey"]
        prompt = f"Question: {item['question']}\nAnswer:"

        nlls = []
        for option in option_texts:
            nlls.append(_score_completion(model, tokenizer, prompt, f" {option}", device))

        if all(math.isnan(v) for v in nlls):
            outcome = float("nan")
        else:
            # NaN options can never win: rank them as +inf, take the first minimum.
            ranked = [float("inf") if math.isnan(v) else v for v in nlls]
            winner = ranked.index(min(ranked))
            outcome = 1.0 if option_keys[winner] == gold else 0.0

        scores.append(outcome)
        labels.append(f"Q{number}")

    graded = [s for s in scores if not math.isnan(s)]
    n_valid = len(graded)
    acc = sum(graded) / n_valid if n_valid else float("nan")
    print(f" {n_valid} questions evaluated, accuracy={acc:.3f}")
    return labels, scores
|
| 2105 |
+
|
| 2106 |
+
|
| 2107 |
+
def run_benchmark_mode() -> None:
    """Interactively benchmark selected models and save a comparison bar chart.

    Flow:
      1. Ask which entry of ``BENCHMARKS`` to run (default: the first).
      2. List local models found under ``<script dir>/runs`` plus the
         ``HUGGINGFACE_MODELS`` entries and ask which ones to evaluate.
      3. Load each selected model, run the benchmark, and reduce the
         per-item values to their mean (NaN values ignored).
      4. Plot one bar per model, best first, and save it as
         ``benchmark_<key>.png`` next to this script.

    A model that fails to load or fails during the benchmark is reported and
    skipped, so the remaining models are still compared. Returns early
    (printing a message) on missing matplotlib, invalid input, or no results.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")  # headless backend: the chart is only written to disk
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib not installed. pip install matplotlib")
        return

    # --- choose the benchmark -------------------------------------------
    bench_keys = list(BENCHMARKS.keys())
    print("\nBenchmarks:")
    for i, k in enumerate(bench_keys):
        b = BENCHMARKS[k]
        print(f" [{i + 1}] {b['label']} — {b['desc']}")
    print("Select benchmark [1]:", end=" ", flush=True)
    try:
        b_choice = input().strip() or "1"
    except (EOFError, KeyboardInterrupt):
        print()
        return
    # isdecimal() only accepts characters int() can parse; isdigit() also
    # accepts e.g. superscript digits, which would make int() raise.
    if not (b_choice.isdecimal() and 1 <= int(b_choice) <= len(bench_keys)):
        print("Invalid selection.")
        return
    bench_key = bench_keys[int(b_choice) - 1]
    bench = BENCHMARKS[bench_key]
    print(f"Benchmark: {bench['label']}")

    # --- choose the models ----------------------------------------------
    root = Path(__file__).resolve().parent
    runs_dir = root / "runs"
    all_models = discover_models(runs_dir)

    model_entries: List[dict] = []
    for m in all_models:
        model_entries.append({"label": f"[LOCAL] {m['name']}/{m['checkpoint']}", "type": "local", "meta": m})
    for hf_name, hf_id in HUGGINGFACE_MODELS.items():
        model_entries.append({"label": f"[HF] {hf_name}", "type": "hf", "hf_id": hf_id, "hf_name": hf_name})

    if not model_entries:
        print("No models found.")
        return

    print("\nAvailable models:")
    for i, e in enumerate(model_entries):
        print(f" [{i + 1}] {e['label']}")
    print(" [a] All models")
    print("Select models (comma-separated or 'a'):", end=" ", flush=True)
    try:
        raw = input().strip()
    except (EOFError, KeyboardInterrupt):
        print()
        return

    if raw.lower() == "a":
        selected = list(range(len(model_entries)))
    else:
        selected = []
        for tok in raw.split(","):
            tok = tok.strip()
            if tok.isdecimal() and 1 <= int(tok) <= len(model_entries):
                idx = int(tok) - 1
                # Ignore repeats so the same model is not evaluated twice.
                if idx not in selected:
                    selected.append(idx)
    if not selected:
        print("No valid selection.")
        return

    # --- run the benchmark on every selected model ------------------------
    # Any key other than "blimp"/"wikitext2" falls back to ARC-Easy.
    runner = {"blimp": _run_blimp, "wikitext2": _run_wikitext2}.get(bench_key, _run_arc_easy)

    all_results: List[dict] = []
    for idx in selected:
        entry = model_entries[idx]
        print(f"\n{'='*60}\nLoading {entry['label']}...")
        try:
            if entry["type"] == "local":
                m = entry["meta"]
                bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"])
            else:
                bundle = load_huggingface_model(entry["hf_id"], root / ".hf_cache")
        except Exception as e:
            print(f" Failed: {e}")
            continue

        model = bundle["model"]
        tokenizer = bundle["tokenizer"]
        device = str(bundle["device"])
        model.eval()

        try:
            _, y_vals = runner(model, tokenizer, device)
        except Exception as e:
            # Treat a benchmark failure like a load failure: report it and
            # keep going so the other models still get compared.
            print(f" Benchmark failed: {e}")
            continue

        valid = [v for v in y_vals if not math.isnan(v)]
        summary = sum(valid) / len(valid) if valid else float("nan")
        all_results.append({"label": entry["label"], "y": y_vals, "summary": summary})

    if not all_results:
        print("No results to plot.")
        return

    # --- rank and plot ----------------------------------------------------
    metric = bench["metric"]
    # NaN has no defined sort order, so rank only the real summaries
    # (lower is better for perplexity, higher otherwise), NaN models last.
    ranked = sorted(
        (r for r in all_results if not math.isnan(r["summary"])),
        key=lambda r: (r["summary"], r["label"]),
        reverse=(metric != "perplexity"),
    )
    ranked.extend(r for r in all_results if math.isnan(r["summary"]))
    summaries = [r["summary"] for r in ranked]
    model_labels = [r["label"] for r in ranked]
    n = len(summaries)
    # NOTE(review): index 0 is the best model and maps to the red end of
    # RdYlGn; confirm whether best-is-red is intended.
    colors = [plt.cm.RdYlGn(i / max(n - 1, 1)) for i in range(n)]
    # Draw NaN summaries as zero-height bars labelled "n/a".
    heights = [0.0 if math.isnan(s) else s for s in summaries]

    fig, ax = plt.subplots(figsize=(max(6, n * 1.4), 6))
    try:
        bars = ax.bar(range(n), heights, color=colors, edgecolor="black")
        for bar, val in zip(bars, summaries):
            caption = "n/a" if math.isnan(val) else f"{val:.3f}"
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
                    caption, ha="center", va="bottom", fontsize=9, fontweight="bold")

        ylabel = "Mean Perplexity (↓ better)" if metric == "perplexity" else "Mean Accuracy (↑ better)"
        ax.set_ylabel(ylabel)
        ax.set_title(f"{bench['label']} Benchmark — Model Comparison")
        ax.set_xticks(range(n))
        ax.set_xticklabels(model_labels, rotation=20, ha="right", fontsize=9)
        if metric == "accuracy":
            ax.set_ylim(0, 1.05)
        ax.grid(True, axis="y", alpha=0.3)
        plt.tight_layout()

        out_path = root / f"benchmark_{bench_key}.png"
        plt.savefig(str(out_path), dpi=150)
    finally:
        # Release the figure; pyplot otherwise keeps it alive for the process.
        plt.close(fig)
    print(f"\nChart saved to {out_path}")

    try:
        import subprocess
        subprocess.Popen(["xdg-open", str(out_path)])
    except OSError:
        # Opening a viewer is a convenience only; xdg-open may not exist here.
        pass
|
| 2241 |
+
|
| 2242 |
+
|
| 2243 |
# ---------------------------------------------------------------------------
|
| 2244 |
# Interactive CLI
|
| 2245 |
# ---------------------------------------------------------------------------
|
|
|
|
| 2310 |
# ---------------------------------------------------------------------------
|
| 2311 |
|
| 2312 |
# Sampling presets keyed by mode id. "sft_mode" is "chat" for the chat
# presets and False for the pretrain (plain continuation) presets; the
# remaining fields are the generation settings for that preset.
MODES = {
    "chat-coherent": dict(
        label="Chat — Coherent",
        desc="structured, consistent, strong repetition control",
        sft_mode="chat",
        temperature=0.35,
        top_k=20,
        top_p=0.88,
        min_p=0.10,
        no_repeat_ngram_size=4,
        repetition_penalty=1.22,
        logit_soft_cap=20.0,
        loop_penalty=20.0,
        max_new_tokens=4096,
        context_window=2048,
    ),
    "chat-variants": dict(
        label="Chat — Variants",
        desc="creative, diverse, more surprising outputs",
        sft_mode="chat",
        temperature=0.65,
        top_k=60,
        top_p=0.92,
        min_p=0.05,
        no_repeat_ngram_size=4,
        repetition_penalty=1.12,
        logit_soft_cap=20.0,
        loop_penalty=14.0,
        max_new_tokens=4096,
        context_window=2048,
    ),
    "pretrain-coherent": dict(
        label="Pretrain — Coherent",
        desc="grounded continuation, low temperature, tight sampling",
        sft_mode=False,
        temperature=0.3,
        top_k=20,
        top_p=0.85,
        min_p=0.10,
        no_repeat_ngram_size=4,
        repetition_penalty=1.2,
        logit_soft_cap=20.0,
        loop_penalty=20.0,
        max_new_tokens=4096,
        context_window=2048,
    ),
    "pretrain-variants": dict(
        label="Pretrain — Variants",
        desc="free-form continuation, higher temperature, more exploration",
        sft_mode=False,
        temperature=0.7,
        top_k=60,
        top_p=0.93,
        min_p=0.04,
        no_repeat_ngram_size=4,
        repetition_penalty=1.12,
        logit_soft_cap=20.0,
        loop_penalty=12.0,
        max_new_tokens=4096,
        context_window=2048,
    ),
}
|
|
|
|
| 2465 |
{"version": "TMLM-Haiku-2", "hf_id": "CompactAI-O/TMLM-Haiku-2"},
|
| 2466 |
{"version": "TMLM-Haiku-1.3", "hf_id": "CompactAI-O/TMLM-Haiku-1.3"},
|
| 2467 |
{"version": "TMLM-Haiku-1", "hf_id": "CompactAI-O/TMLM-Haiku-1"},
|
| 2468 |
+
{"version": "Glint-1", "hf_id": "CompactAI-O/Glint-1"},
|
| 2469 |
]
|
| 2470 |
|
| 2471 |
+
# Repo ids that are always probed in addition to the name-search results,
# because their names do not match the TMLM-Haiku search term (e.g. Glint-1).
_EXTRA_REPOS = ["CompactAI-O/Glint-1"]
|
| 2472 |
+
|
| 2473 |
|
| 2474 |
def _probe_repo(hf_id: str) -> dict | None:
|
| 2475 |
"""Return entry dict for one repo, or None if no usable checkpoints found."""
|
|
|
|
| 2497 |
|
| 2498 |
_LABELS = {
|
| 2499 |
"model.pt": ("Chat (SFT)", False),
|
| 2500 |
+
"model_rep.pt": ("Chat (anti-repetition)", False),
|
| 2501 |
"pretrain.pt": ("Pretrain (base)", True),
|
| 2502 |
}
|
| 2503 |
|
|
|
|
| 2538 |
infos = [type("M", (), {"id": e["hf_id"]})() for e in _FALLBACK_COLLECTION]
|
| 2539 |
|
| 2540 |
entries = []
|
| 2541 |
+
seen_ids: set = set()
|
| 2542 |
for info in infos:
|
| 2543 |
repo_id = info.id
|
| 2544 |
if _SEARCH.lower() not in repo_id.lower():
|
|
|
|
| 2546 |
entry = _probe_repo(repo_id)
|
| 2547 |
if entry:
|
| 2548 |
entries.append(entry)
|
| 2549 |
+
seen_ids.add(repo_id)
|
| 2550 |
+
|
| 2551 |
+
# Always include extra repos (e.g. Glint-1) not caught by TMLM-Haiku search
|
| 2552 |
+
for repo_id in _EXTRA_REPOS:
|
| 2553 |
+
if repo_id not in seen_ids:
|
| 2554 |
+
entry = _probe_repo(repo_id)
|
| 2555 |
+
if entry:
|
| 2556 |
+
entries.append(entry)
|
| 2557 |
+
seen_ids.add(repo_id)
|
| 2558 |
|
| 2559 |
if not entries:
|
| 2560 |
print(" No models found; using fallback list.")
|
|
|
|
| 2561 |
for fb in _FALLBACK_COLLECTION:
|
| 2562 |
e = _probe_repo(fb["hf_id"])
|
| 2563 |
if e:
|
|
|
|
| 2645 |
|
| 2646 |
|
| 2647 |
def main() -> None:
|
| 2648 |
+
import argparse
|
| 2649 |
+
parser = argparse.ArgumentParser()
|
| 2650 |
+
parser.add_argument("--compare", "-c", action="store_true")
|
| 2651 |
+
parser.add_argument("--prompt", "-p", type=str, default="Hello")
|
| 2652 |
+
mode_group = parser.add_mutually_exclusive_group()
|
| 2653 |
+
mode_group.add_argument("--pretrain", action="store_true")
|
| 2654 |
+
mode_group.add_argument("--sft", action="store_true")
|
| 2655 |
+
args, _ = parser.parse_known_args()
|
| 2656 |
+
|
| 2657 |
print("=" * 56)
|
| 2658 |
+
print(" CompactAI-O Interactive Chat")
|
| 2659 |
print(" Models: huggingface.co/CompactAI-O")
|
| 2660 |
print("=" * 56)
|
| 2661 |
|
| 2662 |
+
if args.compare:
|
| 2663 |
+
prefetch_huggingface_models()
|
| 2664 |
+
cfg = pick_mode(is_pretrain=args.pretrain)
|
| 2665 |
+
prompt_label = "You" if cfg["sft_mode"] else "Prompt"
|
| 2666 |
+
while True:
|
| 2667 |
+
print(f"{prompt_label}:", end=" ", flush=True)
|
| 2668 |
+
prompt = sys.stdin.readline().strip()
|
| 2669 |
+
if not prompt or prompt in ("/quit", "/exit", "/q"):
|
| 2670 |
+
break
|
| 2671 |
+
compare_all_models(prompt, cfg)
|
| 2672 |
+
return
|
| 2673 |
+
|
| 2674 |
collection = fetch_collection()
|
| 2675 |
if not collection:
|
| 2676 |
print("No models found. Check your internet connection.")
|
|
|
|
| 2679 |
entry = pick_version(collection)
|
| 2680 |
fname, is_pretrain = pick_checkpoint(entry)
|
| 2681 |
|
| 2682 |
+
if args.pretrain:
|
| 2683 |
+
is_pretrain = True
|
| 2684 |
+
elif args.sft:
|
| 2685 |
+
is_pretrain = False
|
| 2686 |
+
|
| 2687 |
root = Path(__file__).resolve().parent
|
| 2688 |
cache_dir = root / "cache" / "huggingface"
|
| 2689 |
cache_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 2703 |
print(f"Loading {entry['version']} / {fname} ...")
|
| 2704 |
bundle = load_local_model(model_path, tokenizer_path, "Haiku")
|
| 2705 |
|
| 2706 |
+
# Use checkpoint-embedded sft_mode/phase if available
|
| 2707 |
+
sft_mode_flag = bundle.get("sft_mode")
|
| 2708 |
+
phase_flag = bundle.get("phase")
|
| 2709 |
+
if sft_mode_flag is not None and not args.pretrain and not args.sft:
|
| 2710 |
+
is_pretrain = not sft_mode_flag
|
| 2711 |
+
elif phase_flag is not None and not args.pretrain and not args.sft:
|
| 2712 |
+
is_pretrain = phase_flag == "pretrain"
|
| 2713 |
+
|
| 2714 |
+
print("\nChoose action:")
|
| 2715 |
+
print(" [1] Chat with this model")
|
| 2716 |
+
print(" [2] Compare ALL models (local + HuggingFace)")
|
| 2717 |
+
print(" [3] Run Benchmark (BLiMP / WikiText-2 / ARC-Easy)")
|
| 2718 |
+
print("Select [1]:", end=" ", flush=True)
|
| 2719 |
+
choice = sys.stdin.readline().strip() or "1"
|
| 2720 |
+
|
| 2721 |
+
if choice == "1":
|
| 2722 |
+
cfg = pick_mode(is_pretrain)
|
| 2723 |
+
_run_loop(bundle, cfg)
|
| 2724 |
+
elif choice == "2":
|
| 2725 |
+
print("\nDownloading/preparing HuggingFace models...")
|
| 2726 |
+
prefetch_huggingface_models()
|
| 2727 |
+
cfg = pick_mode(is_pretrain)
|
| 2728 |
+
prompt_label = "You" if cfg["sft_mode"] else "Prompt"
|
| 2729 |
+
while True:
|
| 2730 |
+
print(f"{prompt_label}:", end=" ", flush=True)
|
| 2731 |
+
prompt = sys.stdin.readline().strip()
|
| 2732 |
+
if not prompt or prompt in ("/quit", "/exit", "/q"):
|
| 2733 |
+
break
|
| 2734 |
+
compare_all_models(prompt, cfg)
|
| 2735 |
+
elif choice == "3":
|
| 2736 |
+
run_benchmark_mode()
|
| 2737 |
+
else:
|
| 2738 |
+
print("Enter 1, 2, or 3")
|
| 2739 |
|
| 2740 |
|
| 2741 |
if __name__ == "__main__":
|
| 2742 |
main()
|
| 2743 |
+
|