CompactAI commited on
Commit
16b04a8
·
verified ·
1 Parent(s): b99d86a

Update downloads/interactive.py

Browse files
Files changed (1) hide show
  1. downloads/interactive.py +1011 -156
downloads/interactive.py CHANGED
@@ -25,10 +25,12 @@ import torch.nn as nn
25
  import torch.nn.functional as F
26
  from torch.utils.checkpoint import checkpoint
27
 
 
28
  HUGGINGFACE_MODELS = {
29
  "TMLM-Haiku-1": "CompactAI-O/TMLM-Haiku-1",
30
  "TMLM-Haiku-1.3": "CompactAI-O/TMLM-Haiku-1.3",
31
  "TMLM-Haiku-2": "CompactAI-O/TMLM-Haiku-2",
 
32
  }
33
 
34
 
@@ -79,6 +81,15 @@ MODEL_SERIES = {
79
  "engram_table_size": 64,
80
  "engram_max_ngram": 2,
81
  "mhc_expansion": 2,
 
 
 
 
 
 
 
 
 
82
  },
83
  "sonnet": {
84
  "dim": 1024,
@@ -95,6 +106,15 @@ MODEL_SERIES = {
95
  "engram_table_size": 4096,
96
  "engram_max_ngram": 2,
97
  "mhc_expansion": 2,
 
 
 
 
 
 
 
 
 
98
  },
99
  "opus": {
100
  "dim": 1536,
@@ -111,6 +131,15 @@ MODEL_SERIES = {
111
  "engram_table_size": 8192,
112
  "engram_max_ngram": 2,
113
  "mhc_expansion": 4,
 
 
 
 
 
 
 
 
 
114
  },
115
  }
116
 
@@ -422,6 +451,68 @@ class SwiGLU(nn.Module):
422
  return out
423
 
424
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
  class EngramBlock(nn.Module):
426
  """DeepSeek Engram: conditional memory via O(1) hashed N-gram lookup.
427
 
@@ -566,6 +657,115 @@ class EngramBlock(nn.Module):
566
  return gate * value
567
 
568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
569
  def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor:
570
  M = torch.exp(logits.clamp(-10, 10))
571
  for _ in range(n_iters):
@@ -733,6 +933,85 @@ class TransformerBlock(nn.Module):
733
  return x, new_kv
734
 
735
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
736
  class TinyMemoryLM(nn.Module):
737
  def __init__(
738
  self,
@@ -754,6 +1033,17 @@ class TinyMemoryLM(nn.Module):
754
  engram_table_size: int = 8192,
755
  engram_max_ngram: int = 3,
756
  mhc_expansion: int = 1,
 
 
 
 
 
 
 
 
 
 
 
757
  ) -> None:
758
  super().__init__()
759
  self.dim = dim
@@ -766,29 +1056,45 @@ class TinyMemoryLM(nn.Module):
766
  self.embed_tokens = nn.Embedding(vocab_size, dim)
767
  self.head = nn.Linear(dim, vocab_size, bias=False)
768
  self.head.weight = self.embed_tokens.weight
769
-
770
  self.output_bias = nn.Parameter(torch.zeros(vocab_size))
771
 
772
- self.blocks = nn.ModuleList(
773
- [
 
 
 
 
 
 
 
 
 
 
774
  TransformerBlock(
775
- dim=dim,
776
- n_heads=n_heads,
777
- n_kv_heads=n_kv_heads,
778
- head_dim=head_dim,
779
- ffn_dim=ffn_dim,
780
- dropout=dropout,
781
- sliding_window=sliding_window,
782
- rope_fraction=rope_fraction,
783
- engram_dim=engram_dim,
784
- engram_heads=engram_heads,
785
- engram_table_size=engram_table_size,
786
- engram_max_ngram=engram_max_ngram,
787
- mhc_expansion=mhc_expansion,
788
  )
789
- for _ in range(n_unique_layers)
790
- ]
791
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
792
  self.norm = RMSNorm(dim)
793
 
794
  self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1})
@@ -799,10 +1105,37 @@ class TinyMemoryLM(nn.Module):
799
  {str(h): RMSNorm(dim) for h in self.mtp_horizons}
800
  )
801
 
802
- res_scale = (2 * n_unique_layers) ** -0.5
803
- for block in self.blocks:
804
- block.attn.wo.weight.data.mul_(res_scale)
805
- block.ffn.down.weight.data.mul_(res_scale)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
806
 
807
  def resize_token_embeddings(self, new_vocab_size: int) -> None:
808
  old_vocab_size = self.embed_tokens.num_embeddings
@@ -810,12 +1143,8 @@ class TinyMemoryLM(nn.Module):
810
  return
811
  device = self.embed_tokens.weight.device
812
  old_embed_weight = self.embed_tokens.weight.data.clone()
813
- self.embed_tokens = nn.Embedding(
814
- new_vocab_size, self.embed_tokens.embedding_dim
815
- ).to(device)
816
- self.head = nn.Linear(
817
- self.embed_tokens.embedding_dim, new_vocab_size, bias=False
818
- ).to(device)
819
  self.head.weight = self.embed_tokens.weight
820
  old_bias = self.output_bias.data.clone()
821
  self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device))
@@ -824,62 +1153,74 @@ class TinyMemoryLM(nn.Module):
824
  self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size]
825
 
826
  def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]:
827
- logical = []
 
828
  blocks_list = list(self.blocks)
829
  full_sequence = blocks_list + blocks_list
830
- for logical_idx, block in enumerate(full_sequence[: self.n_logical_layers]):
831
- logical.append((block, logical_idx))
832
- return logical
833
 
834
  def forward(
835
  self,
836
  ids: torch.Tensor,
837
  use_cache: bool = False,
838
- past_key_values: Optional[
839
- List[Optional[Tuple[torch.Tensor, torch.Tensor]]]
840
- ] = None,
841
  return_hidden: bool = False,
842
- ) -> Tuple[
843
- torch.Tensor,
844
- Dict[int, torch.Tensor],
845
- Optional[torch.Tensor],
846
- Optional[List[Tuple[torch.Tensor, torch.Tensor]]],
847
- ]:
848
  B, T = ids.shape
849
  x = self.embed_tokens(ids) * self.embed_scale_factor
850
- token_ids = ids
851
-
852
- logical_layers = self._build_logical_layers()
853
- new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = (
854
- [] if use_cache else None
855
- )
856
-
857
- last_logical_idx = len(logical_layers) - 1
858
- for layer_idx, (block, logical_idx) in enumerate(logical_layers):
859
- is_global = logical_idx % 2 == 0 or layer_idx == last_logical_idx
860
- past_kv = (
861
- past_key_values[layer_idx]
862
- if past_key_values is not None and layer_idx < len(past_key_values)
863
- else None
 
 
864
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
865
 
866
- if self.grad_checkpoint and self.training and not use_cache:
867
- x, layer_kv = checkpoint(
868
- block,
869
- x,
870
- is_global,
871
- past_kv,
872
- use_cache,
873
- token_ids,
874
- use_reentrant=True,
875
- )
876
- else:
877
- x, layer_kv = block(x, is_global, past_kv, use_cache, token_ids)
878
 
879
- if new_past_key_values is not None:
880
- new_past_key_values.append(layer_kv)
 
 
 
 
 
 
 
881
 
882
- x = self.norm(x)
883
  h_out = x if return_hidden else None
884
  logits = self.head(x)
885
  if self.embed_scale_factor != 1.0:
@@ -899,7 +1240,7 @@ class TinyMemoryLM(nn.Module):
899
  mtp_logits = mtp_logits + self.output_bias
900
  mtp[horizon] = mtp_logits
901
 
902
- return logits, mtp, h_out, new_past_key_values
903
 
904
 
905
  # ---------------------------------------------------------------------------
@@ -1011,7 +1352,7 @@ def generate(
1011
  ctx_ids = (
1012
  input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t
1013
  )
1014
- logits, _, _, _ = model(ctx_ids)
1015
  next_logits = logits[0, -1, :].clone()
1016
 
1017
  # Logit soft-capping (Gemma-style) — prevents overconfident collapse
@@ -1122,8 +1463,16 @@ def discover_models(runs_dir: Path) -> List[dict]:
1122
  if not tokenizer_path.exists():
1123
  continue
1124
  name = child.name
1125
- series = series_from_name(name) or "Sonnet"
1126
  for ckpt_name in ("model.pt", "pretrain.pt"):
 
 
 
 
 
 
 
 
1127
  ckpt_path = child / ckpt_name
1128
  if ckpt_path.exists():
1129
  models.append(
@@ -1135,6 +1484,23 @@ def discover_models(runs_dir: Path) -> List[dict]:
1135
  "tokenizer_path": tokenizer_path,
1136
  }
1137
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1138
  return models
1139
 
1140
 
@@ -1153,49 +1519,138 @@ def _detect_mhc(state_dict):
1153
  return 1
1154
 
1155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1156
  def _infer_arch_from_state_dict(state_dict, cfg):
1157
  """Infer architecture hyper-parameters directly from checkpoint weights,
1158
  falling back to *cfg* (series config) when a key is not found."""
1159
  overrides = {}
1160
 
 
 
 
 
 
1161
  # dim from embed_tokens.weight [vocab, dim]
1162
  if "embed_tokens.weight" in state_dict:
1163
  overrides["dim"] = state_dict["embed_tokens.weight"].shape[1]
1164
 
1165
- # ffn_dim from blocks.0.ffn.gate.weight [ffn_dim, dim]
1166
- if "blocks.0.ffn.gate.weight" in state_dict:
1167
- overrides["ffn_dim"] = state_dict["blocks.0.ffn.gate.weight"].shape[0]
1168
-
1169
- # n_unique_layers – count block indices
1170
- block_ids = {
1171
- int(k.split(".")[1])
1172
- for k in state_dict
1173
- if k.startswith("blocks.") and k.split(".")[1].isdigit()
1174
- }
1175
- if block_ids:
1176
- overrides["n_unique_layers"] = max(block_ids) + 1
 
 
 
 
1177
 
1178
- # n_heads from wq [n_heads*head_dim, dim] and wk [n_kv*head_dim, dim]
1179
  dim = overrides.get("dim", int(cfg.get("dim", model_config.dim)))
1180
- if "blocks.0.attn.wq.weight" in state_dict:
1181
- wq_rows = state_dict["blocks.0.attn.wq.weight"].shape[0] # n_heads * head_dim
1182
- if "blocks.0.attn.q_norm.weight" in state_dict:
1183
- head_dim = state_dict["blocks.0.attn.q_norm.weight"].shape[0]
1184
  overrides["n_heads"] = wq_rows // head_dim
1185
- if "blocks.0.attn.wk.weight" in state_dict:
1186
- wk_rows = state_dict["blocks.0.attn.wk.weight"].shape[0]
1187
- if "blocks.0.attn.k_norm.weight" in state_dict:
1188
- head_dim = state_dict["blocks.0.attn.k_norm.weight"].shape[0]
1189
  overrides["n_kv_heads"] = wk_rows // head_dim
1190
 
1191
- # engram params from blocks.0.engram.embeddings.*_0 [table_size, engram_dim]
1192
  for key, val in state_dict.items():
1193
  if ".engram.embeddings." in key and key.endswith("_0") and val.dim() == 2:
1194
  overrides["engram_table_size"] = val.shape[0]
1195
  overrides["engram_dim"] = val.shape[1]
1196
  break
1197
- # engram_heads from branch_conv [total_branch_dim, 1, 4]
1198
- # total_branch_dim = engram_dim * n_heads * (max_ngram - 1)
1199
  engram_dim = overrides.get("engram_dim", int(cfg.get("engram_dim", 0)))
1200
  engram_max_ngram = int(cfg.get("engram_max_ngram", 2))
1201
  if engram_dim > 0:
@@ -1207,7 +1662,6 @@ def _infer_arch_from_state_dict(state_dict, cfg):
1207
  overrides["engram_heads"] = total_branch_dim // denom
1208
  break
1209
 
1210
- # merge: checkpoint values take priority over series config
1211
  merged = dict(cfg)
1212
  merged.update(overrides)
1213
  return merged
@@ -1221,8 +1675,6 @@ def load_local_model(model_path: Path, tokenizer_path: Path, series: str) -> dic
1221
 
1222
  state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt
1223
 
1224
- # Infer architecture from checkpoint weights so config mismatches are
1225
- # handled automatically.
1226
  cfg = _infer_arch_from_state_dict(state_dict, cfg)
1227
 
1228
  engram_dim = int(cfg.get("engram_dim", 0))
@@ -1233,38 +1685,65 @@ def load_local_model(model_path: Path, tokenizer_path: Path, series: str) -> dic
1233
  if mhc_expansion == 1:
1234
  mhc_expansion = int(cfg.get("mhc_expansion", 1))
1235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1236
  model = TinyMemoryLM(
1237
  vocab_size=vocab_size,
1238
  dim=int(cfg.get("dim", model_config.dim)),
1239
- n_unique_layers=int(cfg.get("n_unique_layers", model_config.n_unique_layers)),
1240
- n_logical_layers=int(
1241
- cfg.get("n_logical_layers", model_config.n_logical_layers)
1242
- ),
1243
  n_heads=int(cfg.get("n_heads", model_config.n_heads)),
1244
  n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)),
1245
  ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)),
1246
  dropout=float(cfg.get("dropout", model_config.dropout)),
1247
- mtp_horizons=tuple(
1248
- int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)
1249
- ),
1250
  grad_checkpoint=False,
1251
- sliding_window=int(
1252
- cfg.get(
1253
- "sliding_window_size",
1254
- getattr(model_config, "sliding_window_size", 512),
1255
- )
1256
- ),
1257
- rope_fraction=float(
1258
- cfg.get("rope_fraction", getattr(model_config, "rope_fraction", 0.25))
1259
- ),
1260
- embed_scale=bool(
1261
- cfg.get("embed_scale", getattr(model_config, "embed_scale", True))
1262
- ),
1263
  engram_dim=engram_dim,
1264
  engram_heads=int(cfg.get("engram_heads", 4)),
1265
  engram_table_size=int(cfg.get("engram_table_size", 8192)),
1266
  engram_max_ngram=int(cfg.get("engram_max_ngram", 3)),
1267
  mhc_expansion=mhc_expansion,
 
 
 
 
 
 
 
 
 
 
 
1268
  )
1269
  model.load_state_dict(state_dict, strict=False)
1270
  model.eval()
@@ -1277,6 +1756,8 @@ def load_local_model(model_path: Path, tokenizer_path: Path, series: str) -> dic
1277
  "tokenizer": tokenizer,
1278
  "device": device,
1279
  "series": series,
 
 
1280
  }
1281
 
1282
 
@@ -1300,7 +1781,13 @@ def download_huggingface_model(hf_id: str, cache_dir: Path) -> dict:
1300
 
1301
  print(f"Using cached {hf_id} from {local_dir}")
1302
 
1303
- model_dir = local_dir / "model" if (local_dir / "model").exists() else local_dir
 
 
 
 
 
 
1304
  model_path = model_dir / "model.pt"
1305
  pretrain_path = model_dir / "pretrain.pt"
1306
  tokenizer_path = model_dir / "tokenizer.json"
@@ -1454,6 +1941,305 @@ def compare_all_models(prompt: str, cfg: dict) -> None:
1454
  print(f"\n{'='*60}")
1455
 
1456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1457
  # ---------------------------------------------------------------------------
1458
  # Interactive CLI
1459
  # ---------------------------------------------------------------------------
@@ -1524,66 +2310,64 @@ def pick_model(runs_dir: Path) -> tuple[dict, str]:
1524
  # ---------------------------------------------------------------------------
1525
 
1526
  MODES = {
1527
- # Chat — two flavours
1528
  "chat-coherent": {
1529
  "label": "Chat — Coherent",
1530
  "desc": "structured, consistent, strong repetition control",
1531
  "sft_mode": "chat",
1532
- "temperature": 0.5,
1533
- "top_k": 40,
1534
- "top_p": 0.9,
1535
- "min_p": 0.06,
1536
- "no_repeat_ngram_size": 5,
1537
- "repetition_penalty": 1.15,
1538
- "logit_soft_cap": 25.0,
1539
- "loop_penalty": 15.0,
1540
- "max_new_tokens": 256,
1541
  "context_window": 2048,
1542
  },
1543
  "chat-variants": {
1544
  "label": "Chat — Variants",
1545
  "desc": "creative, diverse, more surprising outputs",
1546
  "sft_mode": "chat",
1547
- "temperature": 0.72,
1548
- "top_k": 50,
1549
- "top_p": 0.93,
1550
- "min_p": 0.04,
1551
- "no_repeat_ngram_size": 5,
1552
- "repetition_penalty": 1.1,
1553
- "logit_soft_cap": 25.0,
1554
- "loop_penalty": 15.0,
1555
- "max_new_tokens": 256,
1556
  "context_window": 2048,
1557
  },
1558
- # Pretrain — two flavours
1559
  "pretrain-coherent": {
1560
  "label": "Pretrain — Coherent",
1561
  "desc": "grounded continuation, low temperature, tight sampling",
1562
  "sft_mode": False,
1563
- "temperature": 0.25,
1564
  "top_k": 20,
1565
  "top_p": 0.85,
1566
  "min_p": 0.10,
1567
- "no_repeat_ngram_size": 5,
1568
  "repetition_penalty": 1.2,
1569
- "logit_soft_cap": 25.0,
1570
- "loop_penalty": 15.0,
1571
- "max_new_tokens": 256,
1572
  "context_window": 2048,
1573
  },
1574
  "pretrain-variants": {
1575
  "label": "Pretrain — Variants",
1576
  "desc": "free-form continuation, higher temperature, more exploration",
1577
  "sft_mode": False,
1578
- "temperature": 0.72,
1579
  "top_k": 60,
1580
- "top_p": 0.95,
1581
- "min_p": 0.03,
1582
  "no_repeat_ngram_size": 4,
1583
- "repetition_penalty": 1.1,
1584
- "logit_soft_cap": 25.0,
1585
  "loop_penalty": 12.0,
1586
- "max_new_tokens": 256,
1587
  "context_window": 2048,
1588
  },
1589
  }
@@ -1681,8 +2465,11 @@ _FALLBACK_COLLECTION = [
1681
  {"version": "TMLM-Haiku-2", "hf_id": "CompactAI-O/TMLM-Haiku-2"},
1682
  {"version": "TMLM-Haiku-1.3", "hf_id": "CompactAI-O/TMLM-Haiku-1.3"},
1683
  {"version": "TMLM-Haiku-1", "hf_id": "CompactAI-O/TMLM-Haiku-1"},
 
1684
  ]
1685
 
 
 
1686
 
1687
  def _probe_repo(hf_id: str) -> dict | None:
1688
  """Return entry dict for one repo, or None if no usable checkpoints found."""
@@ -1710,6 +2497,7 @@ def _probe_repo(hf_id: str) -> dict | None:
1710
 
1711
  _LABELS = {
1712
  "model.pt": ("Chat (SFT)", False),
 
1713
  "pretrain.pt": ("Pretrain (base)", True),
1714
  }
1715
 
@@ -1750,6 +2538,7 @@ def fetch_collection() -> list[dict]:
1750
  infos = [type("M", (), {"id": e["hf_id"]})() for e in _FALLBACK_COLLECTION]
1751
 
1752
  entries = []
 
1753
  for info in infos:
1754
  repo_id = info.id
1755
  if _SEARCH.lower() not in repo_id.lower():
@@ -1757,10 +2546,18 @@ def fetch_collection() -> list[dict]:
1757
  entry = _probe_repo(repo_id)
1758
  if entry:
1759
  entries.append(entry)
 
 
 
 
 
 
 
 
 
1760
 
1761
  if not entries:
1762
  print(" No models found; using fallback list.")
1763
- entries = []
1764
  for fb in _FALLBACK_COLLECTION:
1765
  e = _probe_repo(fb["hf_id"])
1766
  if e:
@@ -1848,11 +2645,32 @@ def pick_checkpoint(entry: dict) -> tuple[str, bool]:
1848
 
1849
 
1850
  def main() -> None:
 
 
 
 
 
 
 
 
 
1851
  print("=" * 56)
1852
- print(" TMLM-Haiku Interactive Chat")
1853
  print(" Models: huggingface.co/CompactAI-O")
1854
  print("=" * 56)
1855
 
 
 
 
 
 
 
 
 
 
 
 
 
1856
  collection = fetch_collection()
1857
  if not collection:
1858
  print("No models found. Check your internet connection.")
@@ -1861,6 +2679,11 @@ def main() -> None:
1861
  entry = pick_version(collection)
1862
  fname, is_pretrain = pick_checkpoint(entry)
1863
 
 
 
 
 
 
1864
  root = Path(__file__).resolve().parent
1865
  cache_dir = root / "cache" / "huggingface"
1866
  cache_dir.mkdir(parents=True, exist_ok=True)
@@ -1880,9 +2703,41 @@ def main() -> None:
1880
  print(f"Loading {entry['version']} / {fname} ...")
1881
  bundle = load_local_model(model_path, tokenizer_path, "Haiku")
1882
 
1883
- cfg = pick_mode(is_pretrain)
1884
- _run_loop(bundle, cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1885
 
1886
 
1887
  if __name__ == "__main__":
1888
  main()
 
 
25
  import torch.nn.functional as F
26
  from torch.utils.checkpoint import checkpoint
27
 
28
+
29
  HUGGINGFACE_MODELS = {
30
  "TMLM-Haiku-1": "CompactAI-O/TMLM-Haiku-1",
31
  "TMLM-Haiku-1.3": "CompactAI-O/TMLM-Haiku-1.3",
32
  "TMLM-Haiku-2": "CompactAI-O/TMLM-Haiku-2",
33
+ "Glint-1": "CompactAI-O/Glint-1",
34
  }
35
 
36
 
 
81
  "engram_table_size": 64,
82
  "engram_max_ngram": 2,
83
  "mhc_expansion": 2,
84
+ "sleep_gate_cap": 0,
85
+ "sleep_gate_heads": 4,
86
+ "latent_think_layers": 0,
87
+ "prelude_layers": 0,
88
+ "coda_layers": 0,
89
+ "recurrent_loops": 0,
90
+ "recurrent_act_threshold": 0.9,
91
+ "recurrent_lora_rank": 0,
92
+ "recurrent_loop_embed_dim": 0,
93
  },
94
  "sonnet": {
95
  "dim": 1024,
 
106
  "engram_table_size": 4096,
107
  "engram_max_ngram": 2,
108
  "mhc_expansion": 2,
109
+ "sleep_gate_cap": 0,
110
+ "sleep_gate_heads": 8,
111
+ "latent_think_layers": 0,
112
+ "prelude_layers": 0,
113
+ "coda_layers": 0,
114
+ "recurrent_loops": 0,
115
+ "recurrent_act_threshold": 0.99,
116
+ "recurrent_lora_rank": 0,
117
+ "recurrent_loop_embed_dim": 0,
118
  },
119
  "opus": {
120
  "dim": 1536,
 
131
  "engram_table_size": 8192,
132
  "engram_max_ngram": 2,
133
  "mhc_expansion": 4,
134
+ "sleep_gate_cap": 0,
135
+ "sleep_gate_heads": 8,
136
+ "latent_think_layers": 0,
137
+ "prelude_layers": 0,
138
+ "coda_layers": 0,
139
+ "recurrent_loops": 0,
140
+ "recurrent_act_threshold": 0.99,
141
+ "recurrent_lora_rank": 0,
142
+ "recurrent_loop_embed_dim": 0,
143
  },
144
  }
145
 
 
451
  return out
452
 
453
 
454
+ def loop_index_embedding(h: torch.Tensor, loop_t: int, loop_dim: int, theta: float = 10000.0) -> torch.Tensor:
455
+ if loop_dim <= 0:
456
+ return h
457
+ loop_dim = min(loop_dim, h.shape[-1])
458
+ if loop_dim % 2 == 1:
459
+ loop_dim -= 1
460
+ if loop_dim <= 0:
461
+ return h
462
+ inv_freq = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
463
+ phase = torch.tensor(float(loop_t), device=h.device, dtype=h.dtype) * inv_freq
464
+ loop_embed = torch.cat([phase.sin(), phase.cos()], dim=0).view(1, 1, loop_dim)
465
+ out = h.clone()
466
+ out[..., :loop_dim] = out[..., :loop_dim] + loop_embed
467
+ return out
468
+
469
+
470
+ class DepthLoRAAdapter(nn.Module):
471
+ def __init__(self, dim: int, rank: int, max_loops: int) -> None:
472
+ super().__init__()
473
+ self.rank = max(0, rank)
474
+ if self.rank <= 0:
475
+ self.down = None
476
+ self.B = None
477
+ self.scale = None
478
+ return
479
+ self.down = nn.Linear(dim, self.rank, bias=False)
480
+ self.B = nn.Parameter(torch.randn(self.rank, dim) * 0.02)
481
+ self.scale = nn.Embedding(max(1, max_loops), self.rank)
482
+ nn.init.zeros_(self.scale.weight)
483
+
484
+ def forward(self, x: torch.Tensor, loop_t: int) -> torch.Tensor:
485
+ if self.rank <= 0 or self.down is None or self.B is None or self.scale is None:
486
+ return torch.zeros_like(x)
487
+ t_idx = min(loop_t, self.scale.num_embeddings - 1)
488
+ scale = self.scale(torch.tensor(t_idx, device=x.device))
489
+ return (self.down(x) * scale) @ self.B
490
+
491
+
492
+ class StableRecurrentInjection(nn.Module):
493
+ def __init__(self, dim: int) -> None:
494
+ super().__init__()
495
+ self.log_A = nn.Parameter(torch.full((dim,), -2.0))
496
+ self.log_dt = nn.Parameter(torch.full((dim,), -2.0))
497
+ self.input_gate = nn.Parameter(torch.zeros(dim))
498
+
499
+ def forward(self, h: torch.Tensor, e: torch.Tensor, transformer_out: torch.Tensor) -> torch.Tensor:
500
+ A = torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20))).view(1, 1, -1)
501
+ B = torch.sigmoid(self.input_gate).view(1, 1, -1)
502
+ return A * h + B * e + transformer_out
503
+
504
+
505
+ class AdaptiveHalting(nn.Module):
506
+ def __init__(self, dim: int) -> None:
507
+ super().__init__()
508
+ self.halt = nn.Linear(dim, 1, bias=True)
509
+ nn.init.zeros_(self.halt.weight)
510
+ nn.init.constant_(self.halt.bias, -2.0)
511
+
512
+ def forward(self, h: torch.Tensor) -> torch.Tensor:
513
+ return torch.sigmoid(self.halt(h)).squeeze(-1)
514
+
515
+
516
  class EngramBlock(nn.Module):
517
  """DeepSeek Engram: conditional memory via O(1) hashed N-gram lookup.
518
 
 
657
  return gate * value
658
 
659
 
660
+ class SleepGate(nn.Module):
661
+ """Persistent memory + periodic consolidation gate."""
662
+
663
+ def __init__(
664
+ self,
665
+ dim: int,
666
+ cap: int = 128,
667
+ n_heads: int = 4,
668
+ retention_enabled: bool = True,
669
+ retention_hidden: int = 0,
670
+ ) -> None:
671
+ super().__init__()
672
+ self.dim = dim
673
+ self.cap = cap
674
+ self.n_heads = n_heads
675
+ self.head_dim = dim // n_heads
676
+ self.scale = self.head_dim ** -0.5
677
+ self.retention_enabled = retention_enabled
678
+
679
+ self.register_buffer("mem_emb", torch.zeros(cap, dim, dtype=torch.bfloat16))
680
+ self.register_buffer("mem_age", torch.zeros(cap, dtype=torch.long))
681
+ self.register_buffer("mem_beta", torch.ones(cap, dtype=torch.float32))
682
+ self.register_buffer("mem_count", torch.zeros((), dtype=torch.long))
683
+ self.register_buffer("mem_head", torch.zeros((), dtype=torch.long))
684
+ self.register_buffer("global_step", torch.zeros((), dtype=torch.long))
685
+
686
+ self.q_proj = nn.Linear(dim, dim, bias=False)
687
+ self.k_proj = nn.Linear(dim, dim, bias=False)
688
+ self.v_proj = nn.Linear(dim, dim, bias=False)
689
+ self.o_proj = nn.Linear(dim, dim, bias=False)
690
+ nn.init.zeros_(self.o_proj.weight)
691
+ self.gate_scale = nn.Parameter(torch.zeros(()))
692
+
693
+ if retention_enabled:
694
+ if retention_hidden > 0:
695
+ self.retention_gate: Optional[nn.Module] = nn.Sequential(
696
+ nn.Linear(dim, retention_hidden, bias=False),
697
+ nn.GELU(),
698
+ nn.Linear(retention_hidden, 1, bias=True),
699
+ )
700
+ nn.init.constant_(self.retention_gate[-1].bias, 2.2)
701
+ else:
702
+ self.retention_gate = nn.Linear(dim, 1, bias=True)
703
+ nn.init.constant_(self.retention_gate.bias, 2.2)
704
+ else:
705
+ self.retention_gate = None
706
+
707
+ self._last_beta: Optional[torch.Tensor] = None
708
+
709
+ def write(self, hidden: torch.Tensor) -> None:
710
+ B, T, _ = hidden.shape
711
+ tail_full = hidden[:, max(0, T - 16):, :].float().mean(dim=1)
712
+ if self.retention_gate is not None:
713
+ beta_live = torch.sigmoid(self.retention_gate(tail_full).squeeze(-1))
714
+ self._last_beta = beta_live if self.training else None
715
+ beta_store = beta_live.detach().float()
716
+ else:
717
+ self._last_beta = None
718
+ beta_store = torch.ones(B, device=hidden.device, dtype=torch.float32)
719
+ tail = tail_full.to(self.mem_emb.dtype).detach()
720
+ with torch.no_grad():
721
+ head = int(self.mem_head.item())
722
+ count = int(self.mem_count.item())
723
+ step = int(self.global_step.item())
724
+ for b in range(B):
725
+ self.mem_emb[head] = tail[b]
726
+ self.mem_age[head] = step
727
+ self.mem_beta[head] = beta_store[b]
728
+ head = (head + 1) % self.cap
729
+ if count < self.cap:
730
+ count += 1
731
+ self.mem_head.fill_(head)
732
+ self.mem_count.fill_(count)
733
+
734
+ def read(self, x: torch.Tensor) -> torch.Tensor:
735
+ count = int(self.mem_count.item())
736
+ if count == 0:
737
+ return torch.zeros_like(x)
738
+ B, T, D = x.shape
739
+ mem = self.mem_emb[:count].clone().to(x.dtype)
740
+ q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
741
+ k = self.k_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1)
742
+ v = self.v_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1)
743
+ attn = torch.einsum("bhtd,hmd->bhtm", q, k) * self.scale
744
+ attn = F.softmax(attn, dim=-1)
745
+ if self.retention_enabled:
746
+ step = int(self.global_step.item())
747
+ ages = self.mem_age[:count].to(x.device)
748
+ delta = (step - ages).clamp(min=0).to(x.dtype)
749
+ betas = self.mem_beta[:count].to(x.dtype).clamp(min=1e-6, max=1.0)
750
+ weights = betas.pow(delta)
751
+ attn = attn * weights.view(1, 1, 1, count)
752
+ attn = attn / attn.sum(dim=-1, keepdim=True).clamp_min(1e-9)
753
+ out = torch.einsum("bhtm,hmd->bhtd", attn, v)
754
+ out = out.transpose(1, 2).contiguous().view(B, T, D)
755
+ out = self.o_proj(out)
756
+ return torch.sigmoid(self.gate_scale) * out
757
+
758
+ @torch.no_grad()
759
+ def reset(self) -> None:
760
+ self.mem_emb.zero_()
761
+ self.mem_age.zero_()
762
+ self.mem_beta.fill_(1.0)
763
+ self.mem_count.zero_()
764
+ self.mem_head.zero_()
765
+ self.global_step.zero_()
766
+ self._last_beta = None
767
+
768
+
769
  def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor:
770
  M = torch.exp(logits.clamp(-10, 10))
771
  for _ in range(n_iters):
 
933
  return x, new_kv
934
 
935
 
936
+ class RecurrentDepthBlock(nn.Module):
937
+ def __init__(
938
+ self,
939
+ dim: int,
940
+ n_heads: int,
941
+ n_kv_heads: int,
942
+ head_dim: int,
943
+ ffn_dim: int,
944
+ dropout: float,
945
+ sliding_window: int,
946
+ rope_fraction: float,
947
+ n_loops: int,
948
+ act_threshold: float,
949
+ lora_rank: int,
950
+ loop_embed_dim: int,
951
+ ) -> None:
952
+ super().__init__()
953
+ self.n_loops = max(1, n_loops)
954
+ self.act_threshold = act_threshold
955
+ self.loop_embed_dim = max(0, loop_embed_dim)
956
+ self.norm = RMSNorm(dim)
957
+ self.block = TransformerBlock(
958
+ dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim,
959
+ ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window,
960
+ rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1,
961
+ )
962
+ self.injection = StableRecurrentInjection(dim)
963
+ self.act = AdaptiveHalting(dim)
964
+ self.lora = DepthLoRAAdapter(dim, lora_rank, self.n_loops)
965
+
966
+ def forward(
967
+ self,
968
+ h: torch.Tensor,
969
+ e: torch.Tensor,
970
+ token_ids: Optional[torch.Tensor] = None,
971
+ past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
972
+ use_cache: bool = False,
973
+ n_loops: Optional[int] = None,
974
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
975
+ loops = max(1, n_loops or self.n_loops)
976
+ B, T, _ = h.shape
977
+ halted = torch.zeros(B, T, device=h.device, dtype=torch.bool)
978
+ cumulative_p = torch.zeros(B, T, device=h.device, dtype=h.dtype)
979
+ output = torch.zeros_like(h)
980
+ new_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
981
+ current = h
982
+ final_halt = None
983
+
984
+ for t in range(loops):
985
+ h_loop = loop_index_embedding(current, t, self.loop_embed_dim)
986
+ combined = self.norm(h_loop + e)
987
+ past_kv = None
988
+ if past_key_values is not None and t < len(past_key_values):
989
+ past_kv = past_key_values[t]
990
+ trans_out, layer_kv = self.block(combined, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=token_ids)
991
+ trans_out = trans_out + self.lora(trans_out, t)
992
+ next_h = self.injection(current, e, trans_out)
993
+ p = self.act(next_h)
994
+ p = p * (~halted).to(p.dtype)
995
+ final_halt = p
996
+ should_halt = (~halted) & ((cumulative_p + p) >= self.act_threshold)
997
+ update_weight = torch.where(should_halt, (1.0 - cumulative_p).clamp(min=0.0), p)
998
+ output = output + next_h * update_weight.unsqueeze(-1)
999
+ cumulative_p = cumulative_p + update_weight
1000
+ current = torch.where(halted.unsqueeze(-1), current, next_h)
1001
+ halted = halted | should_halt
1002
+ if new_past is not None:
1003
+ new_past.append(layer_kv)
1004
+ if not use_cache and bool(halted.all()):
1005
+ break
1006
+
1007
+ remainder = (1.0 - cumulative_p).clamp(min=0.0)
1008
+ output = output + current * remainder.unsqueeze(-1)
1009
+ aux: Dict[str, torch.Tensor] = {}
1010
+ if final_halt is not None:
1011
+ aux["recurrent_halt_mean"] = final_halt.mean()
1012
+ return output, aux, new_past
1013
+
1014
+
1015
  class TinyMemoryLM(nn.Module):
1016
  def __init__(
1017
  self,
 
1033
  engram_table_size: int = 8192,
1034
  engram_max_ngram: int = 3,
1035
  mhc_expansion: int = 1,
1036
+ sleep_gate_cap: int = 0,
1037
+ sleep_gate_heads: int = 4,
1038
+ sleep_retention_enabled: bool = True,
1039
+ sleep_retention_hidden: int = 0,
1040
+ latent_think_layers: int = 0,
1041
+ prelude_layers: int = 0,
1042
+ coda_layers: int = 0,
1043
+ recurrent_loops: int = 0,
1044
+ recurrent_act_threshold: float = 0.99,
1045
+ recurrent_lora_rank: int = 0,
1046
+ recurrent_loop_embed_dim: int = 0,
1047
  ) -> None:
1048
  super().__init__()
1049
  self.dim = dim
 
1056
  self.embed_tokens = nn.Embedding(vocab_size, dim)
1057
  self.head = nn.Linear(dim, vocab_size, bias=False)
1058
  self.head.weight = self.embed_tokens.weight
 
1059
  self.output_bias = nn.Parameter(torch.zeros(vocab_size))
1060
 
1061
+ self.use_recurrent_depth = recurrent_loops > 0
1062
+ self.prelude_layers = max(0, prelude_layers)
1063
+ self.coda_layers = max(0, coda_layers)
1064
+ self.recurrent_loops = max(0, recurrent_loops)
1065
+
1066
+ self.blocks: Optional[nn.ModuleList] = None
1067
+ self.prelude: Optional[nn.ModuleList] = None
1068
+ self.recurrent: Optional[RecurrentDepthBlock] = None
1069
+ self.coda: Optional[nn.ModuleList] = None
1070
+
1071
+ def _make_blocks(n: int) -> nn.ModuleList:
1072
+ return nn.ModuleList([
1073
  TransformerBlock(
1074
+ dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim,
1075
+ ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window,
1076
+ rope_fraction=rope_fraction, engram_dim=engram_dim,
1077
+ engram_heads=engram_heads, engram_table_size=engram_table_size,
1078
+ engram_max_ngram=engram_max_ngram, mhc_expansion=mhc_expansion,
 
 
 
 
 
 
 
 
1079
  )
1080
+ for _ in range(n)
1081
+ ])
1082
+
1083
+ if self.use_recurrent_depth:
1084
+ if self.prelude_layers > 0:
1085
+ self.prelude = _make_blocks(self.prelude_layers)
1086
+ self.recurrent = RecurrentDepthBlock(
1087
+ dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim,
1088
+ ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window,
1089
+ rope_fraction=rope_fraction, n_loops=self.recurrent_loops,
1090
+ act_threshold=recurrent_act_threshold, lora_rank=recurrent_lora_rank,
1091
+ loop_embed_dim=recurrent_loop_embed_dim or max(2, dim // 8),
1092
+ )
1093
+ if self.coda_layers > 0:
1094
+ self.coda = _make_blocks(self.coda_layers)
1095
+ else:
1096
+ self.blocks = _make_blocks(max(1, n_unique_layers))
1097
+
1098
  self.norm = RMSNorm(dim)
1099
 
1100
  self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1})
 
1105
  {str(h): RMSNorm(dim) for h in self.mtp_horizons}
1106
  )
1107
 
1108
+ res_scale = (2 * max(1, n_logical_layers)) ** -0.5
1109
+ for group in (self.blocks, self.prelude, self.coda):
1110
+ if group is None:
1111
+ continue
1112
+ for block in group:
1113
+ block.attn.wo.weight.data.mul_(res_scale)
1114
+ block.ffn.down.weight.data.mul_(res_scale)
1115
+ if self.recurrent is not None:
1116
+ self.recurrent.block.attn.wo.weight.data.mul_(res_scale)
1117
+ self.recurrent.block.ffn.down.weight.data.mul_(res_scale)
1118
+
1119
+ self.sleep_gate: Optional[SleepGate] = None
1120
+ if sleep_gate_cap > 0:
1121
+ self.sleep_gate = SleepGate(
1122
+ dim=dim, cap=sleep_gate_cap, n_heads=sleep_gate_heads,
1123
+ retention_enabled=sleep_retention_enabled,
1124
+ retention_hidden=sleep_retention_hidden,
1125
+ )
1126
+
1127
+ self.think_blocks: Optional[nn.ModuleList] = None
1128
+ self.think_norm: Optional[RMSNorm] = None
1129
+ if latent_think_layers > 0:
1130
+ self.think_blocks = nn.ModuleList([
1131
+ TransformerBlock(
1132
+ dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim,
1133
+ ffn_dim=ffn_dim, dropout=0.0, sliding_window=2048,
1134
+ rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1,
1135
+ )
1136
+ for _ in range(latent_think_layers)
1137
+ ])
1138
+ self.think_norm = RMSNorm(dim)
1139
 
1140
  def resize_token_embeddings(self, new_vocab_size: int) -> None:
1141
  old_vocab_size = self.embed_tokens.num_embeddings
 
1143
  return
1144
  device = self.embed_tokens.weight.device
1145
  old_embed_weight = self.embed_tokens.weight.data.clone()
1146
+ self.embed_tokens = nn.Embedding(new_vocab_size, self.embed_tokens.embedding_dim).to(device)
1147
+ self.head = nn.Linear(self.embed_tokens.embedding_dim, new_vocab_size, bias=False).to(device)
 
 
 
 
1148
  self.head.weight = self.embed_tokens.weight
1149
  old_bias = self.output_bias.data.clone()
1150
  self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device))
 
1153
  self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size]
1154
 
1155
  def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]:
1156
+ if self.blocks is None:
1157
+ return []
1158
  blocks_list = list(self.blocks)
1159
  full_sequence = blocks_list + blocks_list
1160
+ return [(block, i) for i, block in enumerate(full_sequence[: self.n_logical_layers])]
 
 
1161
 
1162
  def forward(
1163
  self,
1164
  ids: torch.Tensor,
1165
  use_cache: bool = False,
1166
+ past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
 
 
1167
  return_hidden: bool = False,
1168
+ ) -> Tuple[torch.Tensor, Dict[int, torch.Tensor], Dict[str, torch.Tensor], Optional[torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
 
 
 
 
 
1169
  B, T = ids.shape
1170
  x = self.embed_tokens(ids) * self.embed_scale_factor
1171
+ new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
1172
+ aux: Dict[str, torch.Tensor] = {}
1173
+
1174
+ if self.use_recurrent_depth:
1175
+ offset = 0
1176
+ if self.prelude is not None:
1177
+ for block in self.prelude:
1178
+ past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None
1179
+ x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids)
1180
+ if new_past_key_values is not None:
1181
+ new_past_key_values.append(layer_kv)
1182
+ offset += 1
1183
+ encoded = x
1184
+ recurrent_past = past_key_values[offset: offset + self.recurrent_loops] if past_key_values is not None else None
1185
+ x, recurrent_aux, recurrent_kv = self.recurrent(
1186
+ x, encoded, token_ids=ids, past_key_values=recurrent_past, use_cache=use_cache,
1187
  )
1188
+ aux.update(recurrent_aux)
1189
+ if new_past_key_values is not None and recurrent_kv is not None:
1190
+ new_past_key_values.extend(recurrent_kv)
1191
+ offset += self.recurrent_loops
1192
+ if self.coda is not None:
1193
+ for block in self.coda:
1194
+ past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None
1195
+ x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids)
1196
+ if new_past_key_values is not None:
1197
+ new_past_key_values.append(layer_kv)
1198
+ offset += 1
1199
+ else:
1200
+ logical_layers = self._build_logical_layers()
1201
+ last_logical_idx = len(logical_layers) - 1
1202
+ for layer_idx, (block, logical_idx) in enumerate(logical_layers):
1203
+ is_global = logical_idx % 2 == 0 or layer_idx == last_logical_idx
1204
+ past_kv = past_key_values[layer_idx] if past_key_values is not None and layer_idx < len(past_key_values) else None
1205
+ if self.grad_checkpoint and self.training and not use_cache:
1206
+ x, layer_kv = checkpoint(block, x, is_global, past_kv, use_cache, ids, use_reentrant=True)
1207
+ else:
1208
+ x, layer_kv = block(x, is_global, past_kv, use_cache, ids)
1209
+ if new_past_key_values is not None:
1210
+ new_past_key_values.append(layer_kv)
1211
 
1212
+ x = self.norm(x)
 
 
 
 
 
 
 
 
 
 
 
1213
 
1214
+ if self.sleep_gate is not None:
1215
+ x = x + self.sleep_gate.read(x)
1216
+ if self.training:
1217
+ self.sleep_gate.write(x)
1218
+
1219
+ if self.think_blocks is not None:
1220
+ for think_block in self.think_blocks:
1221
+ x, _ = think_block(x, is_global=True)
1222
+ x = self.think_norm(x)
1223
 
 
1224
  h_out = x if return_hidden else None
1225
  logits = self.head(x)
1226
  if self.embed_scale_factor != 1.0:
 
1240
  mtp_logits = mtp_logits + self.output_bias
1241
  mtp[horizon] = mtp_logits
1242
 
1243
+ return logits, mtp, aux, h_out, new_past_key_values
1244
 
1245
 
1246
  # ---------------------------------------------------------------------------
 
1352
  ctx_ids = (
1353
  input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t
1354
  )
1355
+ logits, *_ = model(ctx_ids)
1356
  next_logits = logits[0, -1, :].clone()
1357
 
1358
  # Logit soft-capping (Gemma-style) — prevents overconfident collapse
 
1463
  if not tokenizer_path.exists():
1464
  continue
1465
  name = child.name
1466
+ series = None
1467
  for ckpt_name in ("model.pt", "pretrain.pt"):
1468
+ ckpt_path = child / ckpt_name
1469
+ if ckpt_path.exists():
1470
+ series = _fast_series_from_checkpoint(ckpt_path)
1471
+ break
1472
+ if series is None:
1473
+ series = series_from_name(name) or "Sonnet"
1474
+ found = False
1475
+ for ckpt_name in ("model.pt", "model_rep.pt", "pretrain.pt"):
1476
  ckpt_path = child / ckpt_name
1477
  if ckpt_path.exists():
1478
  models.append(
 
1484
  "tokenizer_path": tokenizer_path,
1485
  }
1486
  )
1487
+ found = True
1488
+ if not found:
1489
+ step_ckpts = sorted(
1490
+ child.glob("checkpoint_step_*.pt"),
1491
+ key=lambda p: int(p.stem.rsplit("_", 1)[-1]),
1492
+ )
1493
+ if step_ckpts:
1494
+ ckpt_path = step_ckpts[-1]
1495
+ models.append(
1496
+ {
1497
+ "name": name,
1498
+ "checkpoint": ckpt_path.name,
1499
+ "series": series,
1500
+ "model_path": ckpt_path,
1501
+ "tokenizer_path": tokenizer_path,
1502
+ }
1503
+ )
1504
  return models
1505
 
1506
 
 
1519
  return 1
1520
 
1521
 
1522
+ def _detect_sleep_gate(state_dict) -> Tuple[int, int]:
1523
+ for key, val in state_dict.items():
1524
+ if key == "sleep_gate.mem_emb" and val.dim() == 2:
1525
+ cap = val.shape[0]
1526
+ return cap, 4
1527
+ return 0, 4
1528
+
1529
+
1530
+ def _detect_latent_think(state_dict) -> int:
1531
+ indices = {
1532
+ int(k.split(".")[1])
1533
+ for k in state_dict
1534
+ if k.startswith("think_blocks.") and k.split(".")[1].isdigit()
1535
+ }
1536
+ return max(indices) + 1 if indices else 0
1537
+
1538
+
1539
+ def _detect_prelude_layers(state_dict) -> int:
1540
+ indices = {
1541
+ int(k.split(".")[1])
1542
+ for k in state_dict
1543
+ if k.startswith("prelude.") and k.split(".")[1].isdigit()
1544
+ }
1545
+ return max(indices) + 1 if indices else 0
1546
+
1547
+
1548
+ def _detect_coda_layers(state_dict) -> int:
1549
+ indices = {
1550
+ int(k.split(".")[1])
1551
+ for k in state_dict
1552
+ if k.startswith("coda.") and k.split(".")[1].isdigit()
1553
+ }
1554
+ return max(indices) + 1 if indices else 0
1555
+
1556
+
1557
+ def _detect_recurrent_loops(state_dict) -> int:
1558
+ if "recurrent.norm.weight" in state_dict or "recurrent.block.attn.wq.weight" in state_dict:
1559
+ if "recurrent.lora.scale.weight" in state_dict:
1560
+ return state_dict["recurrent.lora.scale.weight"].shape[0]
1561
+ return 1
1562
+ return 0
1563
+
1564
+
1565
+ def _detect_recurrent_lora_rank(state_dict) -> int:
1566
+ for key in ("recurrent.lora.B", "recurrent.lora.down.weight"):
1567
+ if key in state_dict:
1568
+ shape = state_dict[key].shape
1569
+ if len(shape) == 2:
1570
+ return int(shape[0])
1571
+ return 0
1572
+
1573
+
1574
+ def _infer_series_from_lora_rank(rank: int) -> str | None:
1575
+ if rank == 0:
1576
+ return None
1577
+ if rank <= 8:
1578
+ return "haiku"
1579
+ if rank <= 16:
1580
+ return "sonnet"
1581
+ return "opus"
1582
+
1583
+
1584
+ def _fast_series_from_checkpoint(ckpt_path: Path) -> str | None:
1585
+ try:
1586
+ cp = torch.load(ckpt_path, map_location="cpu", weights_only=False)
1587
+ sd = cp.get("model_state", cp.get("state_dict", {}))
1588
+ rank = 0
1589
+ for key in ("recurrent.lora.B", "recurrent.lora.down.weight"):
1590
+ if key in sd:
1591
+ rank = int(sd[key].shape[0])
1592
+ break
1593
+ if rank == 0:
1594
+ return None
1595
+ if rank <= 8:
1596
+ return "Haiku"
1597
+ if rank <= 16:
1598
+ return "Sonnet"
1599
+ return "Opus"
1600
+ except Exception:
1601
+ pass
1602
+ return None
1603
+
1604
+
1605
  def _infer_arch_from_state_dict(state_dict, cfg):
1606
  """Infer architecture hyper-parameters directly from checkpoint weights,
1607
  falling back to *cfg* (series config) when a key is not found."""
1608
  overrides = {}
1609
 
1610
+ has_prelude = any(k.startswith("prelude.") for k in state_dict)
1611
+ has_blocks = any(k.startswith("blocks.") for k in state_dict)
1612
+ has_recurrent = any(k.startswith("recurrent.") for k in state_dict)
1613
+ uses_recurrent_arch = has_prelude and has_recurrent and not has_blocks
1614
+
1615
  # dim from embed_tokens.weight [vocab, dim]
1616
  if "embed_tokens.weight" in state_dict:
1617
  overrides["dim"] = state_dict["embed_tokens.weight"].shape[1]
1618
 
1619
+ if uses_recurrent_arch:
1620
+ if "prelude.0.ffn.gate.weight" in state_dict:
1621
+ overrides["ffn_dim"] = state_dict["prelude.0.ffn.gate.weight"].shape[0]
1622
+ overrides["n_unique_layers"] = 0
1623
+ src = "prelude.0"
1624
+ else:
1625
+ if "blocks.0.ffn.gate.weight" in state_dict:
1626
+ overrides["ffn_dim"] = state_dict["blocks.0.ffn.gate.weight"].shape[0]
1627
+ block_ids = {
1628
+ int(k.split(".")[1])
1629
+ for k in state_dict
1630
+ if k.startswith("blocks.") and k.split(".")[1].isdigit()
1631
+ }
1632
+ if block_ids:
1633
+ overrides["n_unique_layers"] = max(block_ids) + 1
1634
+ src = "blocks.0"
1635
 
 
1636
  dim = overrides.get("dim", int(cfg.get("dim", model_config.dim)))
1637
+ if f"{src}.attn.wq.weight" in state_dict:
1638
+ wq_rows = state_dict[f"{src}.attn.wq.weight"].shape[0]
1639
+ if f"{src}.attn.q_norm.weight" in state_dict:
1640
+ head_dim = state_dict[f"{src}.attn.q_norm.weight"].shape[0]
1641
  overrides["n_heads"] = wq_rows // head_dim
1642
+ if f"{src}.attn.wk.weight" in state_dict:
1643
+ wk_rows = state_dict[f"{src}.attn.wk.weight"].shape[0]
1644
+ if f"{src}.attn.k_norm.weight" in state_dict:
1645
+ head_dim = state_dict[f"{src}.attn.k_norm.weight"].shape[0]
1646
  overrides["n_kv_heads"] = wk_rows // head_dim
1647
 
1648
+ # engram params
1649
  for key, val in state_dict.items():
1650
  if ".engram.embeddings." in key and key.endswith("_0") and val.dim() == 2:
1651
  overrides["engram_table_size"] = val.shape[0]
1652
  overrides["engram_dim"] = val.shape[1]
1653
  break
 
 
1654
  engram_dim = overrides.get("engram_dim", int(cfg.get("engram_dim", 0)))
1655
  engram_max_ngram = int(cfg.get("engram_max_ngram", 2))
1656
  if engram_dim > 0:
 
1662
  overrides["engram_heads"] = total_branch_dim // denom
1663
  break
1664
 
 
1665
  merged = dict(cfg)
1666
  merged.update(overrides)
1667
  return merged
 
1675
 
1676
  state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt
1677
 
 
 
1678
  cfg = _infer_arch_from_state_dict(state_dict, cfg)
1679
 
1680
  engram_dim = int(cfg.get("engram_dim", 0))
 
1685
  if mhc_expansion == 1:
1686
  mhc_expansion = int(cfg.get("mhc_expansion", 1))
1687
 
1688
+ ckpt_sleep_cap, ckpt_sleep_heads = _detect_sleep_gate(state_dict)
1689
+ sleep_gate_cap = ckpt_sleep_cap if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_cap", 0))
1690
+ sleep_gate_heads = ckpt_sleep_heads if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_heads", 4))
1691
+ sleep_retention_enabled = bool(cfg.get("sleep_retention_enabled", True))
1692
+ sleep_retention_hidden = int(cfg.get("sleep_retention_hidden", 0))
1693
+
1694
+ latent_think_layers = _detect_latent_think(state_dict)
1695
+ if latent_think_layers == 0:
1696
+ latent_think_layers = int(cfg.get("latent_think_layers", 0))
1697
+
1698
+ prelude_layers = _detect_prelude_layers(state_dict)
1699
+ coda_layers = _detect_coda_layers(state_dict)
1700
+ recurrent_loops = _detect_recurrent_loops(state_dict)
1701
+
1702
+ ckpt_lora_rank = _detect_recurrent_lora_rank(state_dict)
1703
+ if ckpt_lora_rank > 0:
1704
+ inferred_series = _infer_series_from_lora_rank(ckpt_lora_rank)
1705
+ if inferred_series and inferred_series != series.lower():
1706
+ series = inferred_series.capitalize()
1707
+ cfg = series_config(series)
1708
+ recurrent_lora_rank = ckpt_lora_rank
1709
+ else:
1710
+ recurrent_lora_rank = int(cfg.get("recurrent_lora_rank", 0))
1711
+
1712
+ recurrent_act_threshold = float(cfg.get("recurrent_act_threshold", 0.99))
1713
+ recurrent_loop_embed_dim = int(cfg.get("recurrent_loop_embed_dim", 0))
1714
+
1715
+ n_unique = int(cfg.get("n_unique_layers", model_config.n_unique_layers))
1716
+
1717
  model = TinyMemoryLM(
1718
  vocab_size=vocab_size,
1719
  dim=int(cfg.get("dim", model_config.dim)),
1720
+ n_unique_layers=n_unique,
1721
+ n_logical_layers=int(cfg.get("n_logical_layers", model_config.n_logical_layers)),
 
 
1722
  n_heads=int(cfg.get("n_heads", model_config.n_heads)),
1723
  n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)),
1724
  ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)),
1725
  dropout=float(cfg.get("dropout", model_config.dropout)),
1726
+ mtp_horizons=tuple(int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)),
 
 
1727
  grad_checkpoint=False,
1728
+ sliding_window=int(cfg.get("sliding_window_size", getattr(model_config, "sliding_window_size", 512))),
1729
+ rope_fraction=float(cfg.get("rope_fraction", getattr(model_config, "rope_fraction", 0.25))),
1730
+ embed_scale=bool(cfg.get("embed_scale", getattr(model_config, "embed_scale", True))),
 
 
 
 
 
 
 
 
 
1731
  engram_dim=engram_dim,
1732
  engram_heads=int(cfg.get("engram_heads", 4)),
1733
  engram_table_size=int(cfg.get("engram_table_size", 8192)),
1734
  engram_max_ngram=int(cfg.get("engram_max_ngram", 3)),
1735
  mhc_expansion=mhc_expansion,
1736
+ sleep_gate_cap=sleep_gate_cap,
1737
+ sleep_gate_heads=sleep_gate_heads,
1738
+ sleep_retention_enabled=sleep_retention_enabled,
1739
+ sleep_retention_hidden=sleep_retention_hidden,
1740
+ latent_think_layers=latent_think_layers,
1741
+ prelude_layers=prelude_layers,
1742
+ coda_layers=coda_layers,
1743
+ recurrent_loops=recurrent_loops,
1744
+ recurrent_act_threshold=recurrent_act_threshold,
1745
+ recurrent_lora_rank=recurrent_lora_rank,
1746
+ recurrent_loop_embed_dim=recurrent_loop_embed_dim,
1747
  )
1748
  model.load_state_dict(state_dict, strict=False)
1749
  model.eval()
 
1756
  "tokenizer": tokenizer,
1757
  "device": device,
1758
  "series": series,
1759
+ "sft_mode": ckpt.get("sft_mode", None),
1760
+ "phase": ckpt.get("phase", None),
1761
  }
1762
 
1763
 
 
1781
 
1782
  print(f"Using cached {hf_id} from {local_dir}")
1783
 
1784
+ # Check common subdirectory names: "models/", "model/"
1785
+ if (local_dir / "models").exists():
1786
+ model_dir = local_dir / "models"
1787
+ elif (local_dir / "model").exists():
1788
+ model_dir = local_dir / "model"
1789
+ else:
1790
+ model_dir = local_dir
1791
  model_path = model_dir / "model.pt"
1792
  pretrain_path = model_dir / "pretrain.pt"
1793
  tokenizer_path = model_dir / "tokenizer.json"
 
1941
  print(f"\n{'='*60}")
1942
 
1943
 
1944
+ # ---------------------------------------------------------------------------
1945
+ # Benchmark
1946
+ # ---------------------------------------------------------------------------
1947
+
1948
+ BENCHMARKS = {
1949
+ "blimp": {
1950
+ "label": "BLiMP",
1951
+ "desc": "Grammaticality minimal pairs (67 paradigms). Accuracy = % grammatical < ungrammatical perplexity.",
1952
+ "hf_dataset": ("nyu-mll/blimp", None),
1953
+ "metric": "accuracy",
1954
+ },
1955
+ "wikitext2": {
1956
+ "label": "WikiText-2",
1957
+ "desc": "LM perplexity on Wikipedia test split. Lower is better.",
1958
+ "hf_dataset": ("Salesforce/wikitext", "wikitext-2-raw-v1"),
1959
+ "metric": "perplexity",
1960
+ },
1961
+ "arc_easy": {
1962
+ "label": "ARC-Easy",
1963
+ "desc": "Multiple-choice science QA (~2.4K). Perplexity-based answer selection.",
1964
+ "hf_dataset": ("allenai/ai2_arc", "ARC-Easy"),
1965
+ "metric": "accuracy",
1966
+ },
1967
+ }
1968
+
1969
+
1970
+ def _score_text(model: TinyMemoryLM, tokenizer: WordTokenizer, text: str, device: str) -> float:
1971
+ ids = tokenizer.encode(text, add_bos=True, add_eos=False)
1972
+ if len(ids) < 2:
1973
+ return float("nan")
1974
+ ids_t = torch.tensor([ids], dtype=torch.long, device=device)
1975
+ with torch.no_grad():
1976
+ logits, *_ = model(ids_t)
1977
+ log_probs = F.log_softmax(logits[0], dim=-1)
1978
+ targets = ids_t[0, 1:]
1979
+ nll = -log_probs[range(len(targets)), targets].mean().item()
1980
+ return nll
1981
+
1982
+
1983
+ def _score_completion(model: TinyMemoryLM, tokenizer: WordTokenizer, context: str, completion: str, device: str) -> float:
1984
+ full_ids = tokenizer.encode(context + completion, add_bos=True, add_eos=False)
1985
+ ctx_ids = tokenizer.encode(context, add_bos=True, add_eos=False)
1986
+ n_ctx = len(ctx_ids)
1987
+ n_ref = len(full_ids) - n_ctx
1988
+ if n_ref <= 0:
1989
+ return float("nan")
1990
+ ids_t = torch.tensor([full_ids], dtype=torch.long, device=device)
1991
+ with torch.no_grad():
1992
+ logits, *_ = model(ids_t)
1993
+ log_probs = F.log_softmax(logits[0], dim=-1)
1994
+ targets = ids_t[0, 1:]
1995
+ ref_start = n_ctx - 1
1996
+ ref_end = min(ref_start + n_ref, log_probs.shape[0])
1997
+ if ref_start >= ref_end:
1998
+ return float("nan")
1999
+ nll = -log_probs[ref_start:ref_end][range(ref_end - ref_start), targets[ref_start:ref_end]].mean().item()
2000
+ return nll
2001
+
2002
+
2003
+ BLIMP_PARADIGMS = [
2004
+ "adjunct_island", "anaphor_gender_agreement", "anaphor_number_agreement",
2005
+ "animate_subject_passive", "animate_subject_trans", "causative",
2006
+ "complex_NP_island", "coordinate_structure_constraint_complex_left_branch",
2007
+ "coordinate_structure_constraint_object_extraction",
2008
+ "determiner_noun_agreement_1", "determiner_noun_agreement_2",
2009
+ "determiner_noun_agreement_irregular_1", "determiner_noun_agreement_irregular_2",
2010
+ "determiner_noun_agreement_with_adj_2", "determiner_noun_agreement_with_adj_irregular_1",
2011
+ "determiner_noun_agreement_with_adj_irregular_2", "determiner_noun_agreement_with_adjective_1",
2012
+ "distractor_agreement_relational_noun", "distractor_agreement_relative_clause",
2013
+ "drop_argument", "ellipsis_n_bar_1", "ellipsis_n_bar_2",
2014
+ "existential_there_object_raising", "existential_there_quantifiers_1",
2015
+ "existential_there_quantifiers_2", "existential_there_subject_raising",
2016
+ "expletive_it_object_raising", "inchoative", "intransitive",
2017
+ "irregular_past_participle_adjectives", "irregular_past_participle_verbs",
2018
+ "irregular_plural_subject_verb_agreement_1", "irregular_plural_subject_verb_agreement_2",
2019
+ "left_branch_island_echo_question", "left_branch_island_simple_question",
2020
+ "matrix_question_npi_licensor_present", "npi_present_1", "npi_present_2",
2021
+ "only_npi_licensor_present", "only_npi_scope", "passive_1", "passive_2",
2022
+ "principle_A_c_command", "principle_A_case_1", "principle_A_case_2",
2023
+ "principle_A_domain_1", "principle_A_domain_2", "principle_A_domain_3",
2024
+ "principle_A_reconstruction", "regular_plural_subject_verb_agreement_1",
2025
+ "regular_plural_subject_verb_agreement_2", "sentential_negation_npi_licensor_present",
2026
+ "sentential_negation_npi_scope", "sentential_subject_island",
2027
+ "superlative_quantifiers_1", "superlative_quantifiers_2",
2028
+ "tough_vs_raising_1", "tough_vs_raising_2", "transitive", "wh_island",
2029
+ "wh_questions_object_gap", "wh_questions_subject_gap",
2030
+ "wh_questions_subject_gap_long_distance", "wh_vs_that_no_gap",
2031
+ "wh_vs_that_no_gap_long_distance", "wh_vs_that_with_gap",
2032
+ "wh_vs_that_with_gap_long_distance",
2033
+ ]
2034
+
2035
+
2036
+ def _run_blimp(model, tokenizer, device, n_samples: int = 200) -> Tuple[List[str], List[float]]:
2037
+ from datasets import load_dataset # type: ignore
2038
+ accuracies: List[float] = []
2039
+ for paradigm in BLIMP_PARADIGMS:
2040
+ try:
2041
+ ds = load_dataset("nyu-mll/blimp", paradigm, split="train")
2042
+ except Exception as e:
2043
+ print(f" {paradigm}: skip ({e})")
2044
+ accuracies.append(float("nan"))
2045
+ continue
2046
+ items = list(ds)[:n_samples]
2047
+ correct = 0
2048
+ for ex in items:
2049
+ good_nll = _score_text(model, tokenizer, ex["sentence_good"], device)
2050
+ bad_nll = _score_text(model, tokenizer, ex["sentence_bad"], device)
2051
+ if math.isnan(good_nll) or math.isnan(bad_nll):
2052
+ continue
2053
+ if good_nll < bad_nll:
2054
+ correct += 1
2055
+ acc = correct / len(items) if items else float("nan")
2056
+ accuracies.append(acc)
2057
+ print(f" {paradigm:50s} acc={acc:.3f}")
2058
+ return BLIMP_PARADIGMS, accuracies
2059
+
2060
+
2061
+ def _run_wikitext2(model, tokenizer, device, chunk_chars: int = 512, max_chunks: int = 100) -> Tuple[List[str], List[float]]:
2062
+ from datasets import load_dataset # type: ignore
2063
+ ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test")
2064
+ full_text = "\n".join(ex["text"] for ex in ds if ex["text"].strip())
2065
+ chunks = [full_text[i:i + chunk_chars] for i in range(0, len(full_text), chunk_chars)]
2066
+ chunks = [c for c in chunks if len(c) > 20][:max_chunks]
2067
+ labels: List[str] = []
2068
+ ppls: List[float] = []
2069
+ for i, chunk in enumerate(chunks):
2070
+ nll = _score_text(model, tokenizer, chunk, device)
2071
+ ppl = math.exp(nll) if not math.isnan(nll) else float("nan")
2072
+ labels.append(f"chunk {i + 1}")
2073
+ ppls.append(ppl)
2074
+ if (i + 1) % 10 == 0:
2075
+ valid = [v for v in ppls if not math.isnan(v)]
2076
+ mean = sum(valid) / len(valid) if valid else float("nan")
2077
+ print(f" chunk {i + 1}/{len(chunks)} running mean ppl={mean:.2f}")
2078
+ return labels, ppls
2079
+
2080
+
2081
+ def _run_arc_easy(model, tokenizer, device, max_samples: int = 200) -> Tuple[List[str], List[float]]:
2082
+ from datasets import load_dataset # type: ignore
2083
+ ds = load_dataset("allenai/ai2_arc", "ARC-Easy", split="test")
2084
+ items = list(ds)[:max_samples]
2085
+ labels: List[str] = []
2086
+ scores: List[float] = []
2087
+ for i, ex in enumerate(items):
2088
+ question = ex["question"]
2089
+ choices = ex["choices"]["text"]
2090
+ choice_labels = ex["choices"]["label"]
2091
+ answer_key = ex["answerKey"]
2092
+ context = f"Question: {question}\nAnswer:"
2093
+ nlls = [_score_completion(model, tokenizer, context, f" {c}", device) for c in choices]
2094
+ if all(math.isnan(v) for v in nlls):
2095
+ scores.append(float("nan"))
2096
+ else:
2097
+ best_idx = min(range(len(nlls)), key=lambda j: nlls[j] if not math.isnan(nlls[j]) else float("inf"))
2098
+ predicted = choice_labels[best_idx]
2099
+ scores.append(1.0 if predicted == answer_key else 0.0)
2100
+ labels.append(f"Q{i + 1}")
2101
+ n_valid = sum(1 for s in scores if not math.isnan(s))
2102
+ acc = sum(s for s in scores if not math.isnan(s)) / n_valid if n_valid else float("nan")
2103
+ print(f" {n_valid} questions evaluated, accuracy={acc:.3f}")
2104
+ return labels, scores
2105
+
2106
+
2107
+ def run_benchmark_mode() -> None:
2108
+ try:
2109
+ import matplotlib
2110
+ matplotlib.use("Agg")
2111
+ import matplotlib.pyplot as plt
2112
+ except ImportError:
2113
+ print("matplotlib not installed. pip install matplotlib")
2114
+ return
2115
+
2116
+ bench_keys = list(BENCHMARKS.keys())
2117
+ print("\nBenchmarks:")
2118
+ for i, k in enumerate(bench_keys):
2119
+ b = BENCHMARKS[k]
2120
+ print(f" [{i + 1}] {b['label']} — {b['desc']}")
2121
+ print("Select benchmark [1]:", end=" ", flush=True)
2122
+ try:
2123
+ b_choice = input().strip() or "1"
2124
+ except (EOFError, KeyboardInterrupt):
2125
+ print()
2126
+ return
2127
+ if not (b_choice.isdigit() and 1 <= int(b_choice) <= len(bench_keys)):
2128
+ print("Invalid selection.")
2129
+ return
2130
+ bench_key = bench_keys[int(b_choice) - 1]
2131
+ bench = BENCHMARKS[bench_key]
2132
+ print(f"Benchmark: {bench['label']}")
2133
+
2134
+ root = Path(__file__).resolve().parent
2135
+ runs_dir = root / "runs"
2136
+ all_models = discover_models(runs_dir)
2137
+
2138
+ model_entries: List[dict] = []
2139
+ for m in all_models:
2140
+ model_entries.append({"label": f"[LOCAL] {m['name']}/{m['checkpoint']}", "type": "local", "meta": m})
2141
+ for hf_name, hf_id in HUGGINGFACE_MODELS.items():
2142
+ model_entries.append({"label": f"[HF] {hf_name}", "type": "hf", "hf_id": hf_id, "hf_name": hf_name})
2143
+
2144
+ if not model_entries:
2145
+ print("No models found.")
2146
+ return
2147
+
2148
+ print("\nAvailable models:")
2149
+ for i, e in enumerate(model_entries):
2150
+ print(f" [{i + 1}] {e['label']}")
2151
+ print(" [a] All models")
2152
+ print("Select models (comma-separated or 'a'):", end=" ", flush=True)
2153
+ try:
2154
+ raw = input().strip()
2155
+ except (EOFError, KeyboardInterrupt):
2156
+ print()
2157
+ return
2158
+
2159
+ if raw.lower() == "a":
2160
+ selected = list(range(len(model_entries)))
2161
+ else:
2162
+ selected = []
2163
+ for tok in raw.split(","):
2164
+ tok = tok.strip()
2165
+ if tok.isdigit() and 1 <= int(tok) <= len(model_entries):
2166
+ selected.append(int(tok) - 1)
2167
+ if not selected:
2168
+ print("No valid selection.")
2169
+ return
2170
+
2171
+ all_results: List[dict] = []
2172
+ shared_x_labels: Optional[List[str]] = None
2173
+
2174
+ for idx in selected:
2175
+ entry = model_entries[idx]
2176
+ print(f"\n{'='*60}\nLoading {entry['label']}...")
2177
+ try:
2178
+ if entry["type"] == "local":
2179
+ m = entry["meta"]
2180
+ bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"])
2181
+ else:
2182
+ bundle = load_huggingface_model(entry["hf_id"], root / ".hf_cache")
2183
+ except Exception as e:
2184
+ print(f" Failed: {e}")
2185
+ continue
2186
+
2187
+ model = bundle["model"]
2188
+ tokenizer = bundle["tokenizer"]
2189
+ device = str(bundle["device"])
2190
+ model.eval()
2191
+
2192
+ if bench_key == "blimp":
2193
+ x_labels, y_vals = _run_blimp(model, tokenizer, device)
2194
+ elif bench_key == "wikitext2":
2195
+ x_labels, y_vals = _run_wikitext2(model, tokenizer, device)
2196
+ else:
2197
+ x_labels, y_vals = _run_arc_easy(model, tokenizer, device)
2198
+
2199
+ if shared_x_labels is None:
2200
+ shared_x_labels = x_labels
2201
+
2202
+ valid = [v for v in y_vals if not math.isnan(v)]
2203
+ summary = sum(valid) / len(valid) if valid else float("nan")
2204
+ all_results.append({"label": entry["label"], "y": y_vals, "summary": summary})
2205
+
2206
+ if not all_results or shared_x_labels is None:
2207
+ print("No results to plot.")
2208
+ return
2209
+
2210
+ metric = bench["metric"]
2211
+ paired = sorted(zip([r["summary"] for r in all_results], [r["label"] for r in all_results]),
2212
+ reverse=(metric != "perplexity"))
2213
+ summaries, model_labels = zip(*paired) if paired else ([], [])
2214
+ n = len(summaries)
2215
+ colors = [plt.cm.RdYlGn(i / max(n - 1, 1)) for i in range(n)]
2216
+
2217
+ fig, ax = plt.subplots(figsize=(max(6, n * 1.4), 6))
2218
+ bars = ax.bar(range(n), summaries, color=colors, edgecolor="black")
2219
+ for bar, val in zip(bars, summaries):
2220
+ ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
2221
+ f"{val:.3f}", ha="center", va="bottom", fontsize=9, fontweight="bold")
2222
+
2223
+ ylabel = "Mean Perplexity (↓ better)" if metric == "perplexity" else "Mean Accuracy (↑ better)"
2224
+ ax.set_ylabel(ylabel)
2225
+ ax.set_title(f"{bench['label']} Benchmark — Model Comparison")
2226
+ ax.set_xticks(range(n))
2227
+ ax.set_xticklabels(model_labels, rotation=20, ha="right", fontsize=9)
2228
+ if metric == "accuracy":
2229
+ ax.set_ylim(0, 1.05)
2230
+ ax.grid(True, axis="y", alpha=0.3)
2231
+ plt.tight_layout()
2232
+
2233
+ out_path = root / f"benchmark_{bench_key}.png"
2234
+ plt.savefig(str(out_path), dpi=150)
2235
+ print(f"\nChart saved to {out_path}")
2236
+ try:
2237
+ import subprocess
2238
+ subprocess.Popen(["xdg-open", str(out_path)])
2239
+ except Exception:
2240
+ pass
2241
+
2242
+
2243
  # ---------------------------------------------------------------------------
2244
  # Interactive CLI
2245
  # ---------------------------------------------------------------------------
 
2310
  # ---------------------------------------------------------------------------
2311
 
2312
  MODES = {
 
2313
  "chat-coherent": {
2314
  "label": "Chat — Coherent",
2315
  "desc": "structured, consistent, strong repetition control",
2316
  "sft_mode": "chat",
2317
+ "temperature": 0.35,
2318
+ "top_k": 20,
2319
+ "top_p": 0.88,
2320
+ "min_p": 0.10,
2321
+ "no_repeat_ngram_size": 4,
2322
+ "repetition_penalty": 1.22,
2323
+ "logit_soft_cap": 20.0,
2324
+ "loop_penalty": 20.0,
2325
+ "max_new_tokens": 4096,
2326
  "context_window": 2048,
2327
  },
2328
  "chat-variants": {
2329
  "label": "Chat — Variants",
2330
  "desc": "creative, diverse, more surprising outputs",
2331
  "sft_mode": "chat",
2332
+ "temperature": 0.65,
2333
+ "top_k": 60,
2334
+ "top_p": 0.92,
2335
+ "min_p": 0.05,
2336
+ "no_repeat_ngram_size": 4,
2337
+ "repetition_penalty": 1.12,
2338
+ "logit_soft_cap": 20.0,
2339
+ "loop_penalty": 14.0,
2340
+ "max_new_tokens": 4096,
2341
  "context_window": 2048,
2342
  },
 
2343
  "pretrain-coherent": {
2344
  "label": "Pretrain — Coherent",
2345
  "desc": "grounded continuation, low temperature, tight sampling",
2346
  "sft_mode": False,
2347
+ "temperature": 0.3,
2348
  "top_k": 20,
2349
  "top_p": 0.85,
2350
  "min_p": 0.10,
2351
+ "no_repeat_ngram_size": 4,
2352
  "repetition_penalty": 1.2,
2353
+ "logit_soft_cap": 20.0,
2354
+ "loop_penalty": 20.0,
2355
+ "max_new_tokens": 4096,
2356
  "context_window": 2048,
2357
  },
2358
  "pretrain-variants": {
2359
  "label": "Pretrain — Variants",
2360
  "desc": "free-form continuation, higher temperature, more exploration",
2361
  "sft_mode": False,
2362
+ "temperature": 0.7,
2363
  "top_k": 60,
2364
+ "top_p": 0.93,
2365
+ "min_p": 0.04,
2366
  "no_repeat_ngram_size": 4,
2367
+ "repetition_penalty": 1.12,
2368
+ "logit_soft_cap": 20.0,
2369
  "loop_penalty": 12.0,
2370
+ "max_new_tokens": 4096,
2371
  "context_window": 2048,
2372
  },
2373
  }
 
2465
  {"version": "TMLM-Haiku-2", "hf_id": "CompactAI-O/TMLM-Haiku-2"},
2466
  {"version": "TMLM-Haiku-1.3", "hf_id": "CompactAI-O/TMLM-Haiku-1.3"},
2467
  {"version": "TMLM-Haiku-1", "hf_id": "CompactAI-O/TMLM-Haiku-1"},
2468
+ {"version": "Glint-1", "hf_id": "CompactAI-O/Glint-1"},
2469
  ]
2470
 
2471
+ _EXTRA_REPOS = ["CompactAI-O/Glint-1"]
2472
+
2473
 
2474
  def _probe_repo(hf_id: str) -> dict | None:
2475
  """Return entry dict for one repo, or None if no usable checkpoints found."""
 
2497
 
2498
  _LABELS = {
2499
  "model.pt": ("Chat (SFT)", False),
2500
+ "model_rep.pt": ("Chat (anti-repetition)", False),
2501
  "pretrain.pt": ("Pretrain (base)", True),
2502
  }
2503
 
 
2538
  infos = [type("M", (), {"id": e["hf_id"]})() for e in _FALLBACK_COLLECTION]
2539
 
2540
  entries = []
2541
+ seen_ids: set = set()
2542
  for info in infos:
2543
  repo_id = info.id
2544
  if _SEARCH.lower() not in repo_id.lower():
 
2546
  entry = _probe_repo(repo_id)
2547
  if entry:
2548
  entries.append(entry)
2549
+ seen_ids.add(repo_id)
2550
+
2551
+ # Always include extra repos (e.g. Glint-1) not caught by TMLM-Haiku search
2552
+ for repo_id in _EXTRA_REPOS:
2553
+ if repo_id not in seen_ids:
2554
+ entry = _probe_repo(repo_id)
2555
+ if entry:
2556
+ entries.append(entry)
2557
+ seen_ids.add(repo_id)
2558
 
2559
  if not entries:
2560
  print(" No models found; using fallback list.")
 
2561
  for fb in _FALLBACK_COLLECTION:
2562
  e = _probe_repo(fb["hf_id"])
2563
  if e:
 
2645
 
2646
 
2647
  def main() -> None:
2648
+ import argparse
2649
+ parser = argparse.ArgumentParser()
2650
+ parser.add_argument("--compare", "-c", action="store_true")
2651
+ parser.add_argument("--prompt", "-p", type=str, default="Hello")
2652
+ mode_group = parser.add_mutually_exclusive_group()
2653
+ mode_group.add_argument("--pretrain", action="store_true")
2654
+ mode_group.add_argument("--sft", action="store_true")
2655
+ args, _ = parser.parse_known_args()
2656
+
2657
  print("=" * 56)
2658
+ print(" CompactAI-O Interactive Chat")
2659
  print(" Models: huggingface.co/CompactAI-O")
2660
  print("=" * 56)
2661
 
2662
+ if args.compare:
2663
+ prefetch_huggingface_models()
2664
+ cfg = pick_mode(is_pretrain=args.pretrain)
2665
+ prompt_label = "You" if cfg["sft_mode"] else "Prompt"
2666
+ while True:
2667
+ print(f"{prompt_label}:", end=" ", flush=True)
2668
+ prompt = sys.stdin.readline().strip()
2669
+ if not prompt or prompt in ("/quit", "/exit", "/q"):
2670
+ break
2671
+ compare_all_models(prompt, cfg)
2672
+ return
2673
+
2674
  collection = fetch_collection()
2675
  if not collection:
2676
  print("No models found. Check your internet connection.")
 
2679
  entry = pick_version(collection)
2680
  fname, is_pretrain = pick_checkpoint(entry)
2681
 
2682
+ if args.pretrain:
2683
+ is_pretrain = True
2684
+ elif args.sft:
2685
+ is_pretrain = False
2686
+
2687
  root = Path(__file__).resolve().parent
2688
  cache_dir = root / "cache" / "huggingface"
2689
  cache_dir.mkdir(parents=True, exist_ok=True)
 
2703
  print(f"Loading {entry['version']} / {fname} ...")
2704
  bundle = load_local_model(model_path, tokenizer_path, "Haiku")
2705
 
2706
+ # Use checkpoint-embedded sft_mode/phase if available
2707
+ sft_mode_flag = bundle.get("sft_mode")
2708
+ phase_flag = bundle.get("phase")
2709
+ if sft_mode_flag is not None and not args.pretrain and not args.sft:
2710
+ is_pretrain = not sft_mode_flag
2711
+ elif phase_flag is not None and not args.pretrain and not args.sft:
2712
+ is_pretrain = phase_flag == "pretrain"
2713
+
2714
+ print("\nChoose action:")
2715
+ print(" [1] Chat with this model")
2716
+ print(" [2] Compare ALL models (local + HuggingFace)")
2717
+ print(" [3] Run Benchmark (BLiMP / WikiText-2 / ARC-Easy)")
2718
+ print("Select [1]:", end=" ", flush=True)
2719
+ choice = sys.stdin.readline().strip() or "1"
2720
+
2721
+ if choice == "1":
2722
+ cfg = pick_mode(is_pretrain)
2723
+ _run_loop(bundle, cfg)
2724
+ elif choice == "2":
2725
+ print("\nDownloading/preparing HuggingFace models...")
2726
+ prefetch_huggingface_models()
2727
+ cfg = pick_mode(is_pretrain)
2728
+ prompt_label = "You" if cfg["sft_mode"] else "Prompt"
2729
+ while True:
2730
+ print(f"{prompt_label}:", end=" ", flush=True)
2731
+ prompt = sys.stdin.readline().strip()
2732
+ if not prompt or prompt in ("/quit", "/exit", "/q"):
2733
+ break
2734
+ compare_all_models(prompt, cfg)
2735
+ elif choice == "3":
2736
+ run_benchmark_mode()
2737
+ else:
2738
+ print("Enter 1, 2, or 3")
2739
 
2740
 
2741
  if __name__ == "__main__":
2742
  main()
2743
+