#!/usr/bin/env python3 """Public-facing TMLM-Haiku interactive CLI. Pulls models from the CompactAI-O HuggingFace collection: https://huggingface.co/collections/CompactAI-O/tmlm-haiku-series """ from __future__ import annotations #!/usr/bin/env python3 from __future__ import annotations import hashlib import json import math import os import string import sys from dataclasses import dataclass from pathlib import Path from typing import Dict, Iterator, List, Optional, Sequence, Tuple import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint HUGGINGFACE_MODELS = { "TMLM-Haiku-1": "CompactAI-O/TMLM-Haiku-1", "TMLM-Haiku-1.3": "CompactAI-O/TMLM-Haiku-1.3", "TMLM-Haiku-2": "CompactAI-O/TMLM-Haiku-2", "Glint-1": "CompactAI-O/Glint-1", } # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- @dataclass class ModelConfig: dim: int = 128 n_unique_layers: int = 8 n_logical_layers: int = 16 n_heads: int = 4 n_kv_heads: int = 2 ffn_dim: int = 224 dropout: float = 0.0 seq_len: int = 2048 sliding_window_size: int = 512 mtp_horizons: Tuple[int, ...] = (2, 3, 4) rope_fraction: float = 0.5 embed_scale: bool = True logit_soft_cap: float = -1.0 quantization: str = "nvfp4" @property def head_dim(self) -> int: return self.dim // self.n_heads model_config = ModelConfig() MODEL_SERIES = { "haiku": { "dim": 64, "n_unique_layers": 12, "n_logical_layers": 24, "n_heads": 4, "n_kv_heads": 2, "ffn_dim": 384, "dropout": 0.0, "seq_len": 2048, "sliding_window_size": 2048, "mtp_horizons": (), "rope_fraction": 0.5, "engram_dim": 8, "engram_heads": 2, "engram_table_size": 64, "engram_max_ngram": 2, "mhc_expansion": 2, "sleep_gate_cap": 0, "sleep_gate_heads": 4, "latent_think_layers": 0, "prelude_layers": 0, "coda_layers": 0, "recurrent_loops": 0, "recurrent_act_threshold": 0.9, "recurrent_lora_rank": 0, "recurrent_loop_embed_dim": 0, }, "sonnet": { "dim": 1024, "n_unique_layers": 20, "n_logical_layers": 40, "n_heads": 16, "n_kv_heads": 4, "ffn_dim": 4096, "dropout": 0.0, "seq_len": 2048, "mtp_horizons": (2,), "engram_dim": 32, "engram_heads": 8, "engram_table_size": 4096, "engram_max_ngram": 2, "mhc_expansion": 2, "sleep_gate_cap": 0, "sleep_gate_heads": 8, "latent_think_layers": 0, "prelude_layers": 0, "coda_layers": 0, "recurrent_loops": 0, "recurrent_act_threshold": 0.99, "recurrent_lora_rank": 0, "recurrent_loop_embed_dim": 0, }, "opus": { "dim": 1536, "n_unique_layers": 18, "n_logical_layers": 36, "n_heads": 16, "n_kv_heads": 4, "ffn_dim": 5888, "dropout": 0.0, "seq_len": 2048, "mtp_horizons": (2,), "engram_dim": 64, "engram_heads": 8, "engram_table_size": 8192, "engram_max_ngram": 2, "mhc_expansion": 4, "sleep_gate_cap": 0, "sleep_gate_heads": 8, "latent_think_layers": 0, "prelude_layers": 0, "coda_layers": 0, "recurrent_loops": 0, "recurrent_act_threshold": 0.99, "recurrent_lora_rank": 0, "recurrent_loop_embed_dim": 0, }, } # --------------------------------------------------------------------------- # Tokenizer # --------------------------------------------------------------------------- FORMAT_TOKENS = [ "<|user|>", "<|assistant|>", "<|system|>", "<|start_header_id|>", "<|end_header_id|>", "<|begin_of_thought|>", "<|end_of_thought|>", "<|begin_of_solution|>", "<|end_of_solution|>", ] class WordTokenizer: def __init__( self, extra_chars: str = "", format_tokens: Optional[List[str]] = None ) -> None: base = string.ascii_letters + string.digits + string.punctuation + " \n\t\r" fallback_chars = sorted(set(base + extra_chars)) self.core_special = ["", "", "", ""] self.format_tokens = ( list(format_tokens) if format_tokens else list(FORMAT_TOKENS) ) self.special = list(self.core_special) + list(self.format_tokens) self.id_to_token: List[str] = ( list(self.core_special) + self.format_tokens + fallback_chars ) self.token_to_id: Dict[str, int] = { t: i for i, t in enumerate(self.id_to_token) } self.special_multi_tokens = sorted( [t for t in self.special if len(t) > 1], key=len, reverse=True ) self.multi_char_tokens = self.special_multi_tokens self.dynamic_additions = 0 @property def pad_id(self) -> int: return self.token_to_id[""] @property def bos_id(self) -> int: return self.token_to_id[""] @property def eos_id(self) -> int: return self.token_to_id[""] @property def unk_id(self) -> int: return self.token_to_id[""] @property def vocab_size(self) -> int: return len(self.id_to_token) def maybe_add_char(self, ch: str) -> bool: if ch in self.token_to_id: return False self.token_to_id[ch] = len(self.id_to_token) self.id_to_token.append(ch) self.dynamic_additions += 1 return True def iter_lexical_tokens(self, text: str) -> Iterator[str]: i = 0 n = len(text) while i < n: matched_special = False for token in self.special_multi_tokens: if text.startswith(token, i): yield token i += len(token) matched_special = True break if matched_special: continue yield text[i] i += 1 def encode( self, text: str, add_bos: bool = False, add_eos: bool = False ) -> List[int]: out: List[int] = [] if add_bos: out.append(self.bos_id) unk = self.unk_id t2i = self.token_to_id for tok in self.iter_lexical_tokens(text): out.append(t2i.get(tok, unk)) if add_eos: out.append(self.eos_id) return out def decode(self, ids: Sequence[int], skip_special: bool = True) -> str: pieces: List[str] = [] for idx in ids: if int(idx) < 0 or int(idx) >= len(self.id_to_token): continue tok = self.id_to_token[int(idx)] if skip_special and tok in self.special: continue pieces.append(tok) return "".join(pieces) @classmethod def load(cls, path: Path) -> WordTokenizer: with path.open("r", encoding="utf-8") as f: data = json.load(f) format_tokens = data.get("format_tokens", FORMAT_TOKENS) tokenizer = cls(extra_chars="", format_tokens=format_tokens) tokenizer.id_to_token = data["id_to_token"] tokenizer.token_to_id = {t: i for i, t in enumerate(tokenizer.id_to_token)} tokenizer.special = list(tokenizer.core_special) + list(tokenizer.format_tokens) tokenizer.special_multi_tokens = sorted( [t for t in tokenizer.special if len(t) > 1], key=len, reverse=True ) tokenizer.multi_char_tokens = tokenizer.special_multi_tokens return tokenizer LetterTokenizer = WordTokenizer # --------------------------------------------------------------------------- # Model # --------------------------------------------------------------------------- class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: if hasattr(torch.nn.functional, "rms_norm"): return torch.nn.functional.rms_norm( x, self.weight.shape, self.weight, self.eps ) x_fp = x.float() rms = torch.rsqrt(x_fp.pow(2).mean(dim=-1, keepdim=True) + self.eps) return (x_fp * rms).to(dtype=x.dtype) * self.weight class RotaryEmbedding(nn.Module): def __init__(self, dim: int, base: float = 10000.0) -> None: super().__init__() inv = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv, persistent=False) def cos_sin( self, seq_len: int, device: torch.device, dtype: torch.dtype ) -> Tuple[torch.Tensor, torch.Tensor]: t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) freqs = torch.outer(t, self.inv_freq) emb = torch.cat([freqs, freqs], dim=-1) cos = emb.cos()[None, None, :, :].to(dtype=dtype) sin = emb.sin()[None, None, :, :].to(dtype=dtype) return cos, sin def _rotate_half(x: torch.Tensor) -> torch.Tensor: x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) class CausalSelfAttention(nn.Module): def __init__( self, dim: int, n_heads: int, n_kv_heads: int, head_dim: int, dropout: float, sliding_window: int, rope_fraction: float, ) -> None: super().__init__() self.dim = dim self.n_heads = n_heads self.n_kv_heads = n_kv_heads self.head_dim = head_dim self.n_rep = n_heads // n_kv_heads self.dropout = dropout self.sliding_window = sliding_window self.wq = nn.Linear(dim, n_heads * head_dim, bias=False) self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False) self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False) self.wo = nn.Linear(n_heads * head_dim, dim, bias=False) self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2) self.rope = RotaryEmbedding(self.rope_dim) self.q_norm = RMSNorm(head_dim) self.k_norm = RMSNorm(head_dim) self.output_gate = nn.Parameter(torch.ones(n_heads)) def forward( self, x: torch.Tensor, is_global: bool, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: B, T, _ = x.shape q = self.wq(x).view(B, T, self.n_heads, self.head_dim) k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim) v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim) q = self.q_norm(q) k = self.k_norm(k) q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) past_len = past_kv[0].shape[2] if past_kv is not None else 0 cos, sin = self.rope.cos_sin(T + past_len, x.device, q.dtype) cos_slice = cos[:, :, past_len : past_len + T, :] sin_slice = sin[:, :, past_len : past_len + T, :] q_rope = q[..., : self.rope_dim] q_pass = q[..., self.rope_dim :] k_rope = k[..., : self.rope_dim] k_pass = k[..., self.rope_dim :] q_rope = (q_rope * cos_slice) + (_rotate_half(q_rope) * sin_slice) k_rope = (k_rope * cos_slice) + (_rotate_half(k_rope) * sin_slice) q = torch.cat([q_rope, q_pass], dim=-1) k = torch.cat([k_rope, k_pass], dim=-1) if past_kv is not None: k = torch.cat([past_kv[0], k], dim=2) v = torch.cat([past_kv[1], v], dim=2) new_kv = (k, v) if use_cache else None S = k.shape[2] if self.n_rep > 1: k = ( k[:, :, None, :, :] .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim) .reshape(B, self.n_heads, S, self.head_dim) ) v = ( v[:, :, None, :, :] .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim) .reshape(B, self.n_heads, S, self.head_dim) ) drop_p = self.dropout if (self.training and torch.is_grad_enabled()) else 0.0 if is_global: if past_kv is None and T > 1: out = F.scaled_dot_product_attention( q, k, v, is_causal=True, dropout_p=drop_p ) else: out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p) else: T_q = q.shape[2] q_pos = torch.arange(past_len, past_len + T_q, device=q.device).unsqueeze(1) k_pos = torch.arange(S, device=q.device).unsqueeze(0) mask = (q_pos >= k_pos) & ((q_pos - k_pos) < self.sliding_window) out = F.scaled_dot_product_attention( q, k, v, attn_mask=mask.unsqueeze(0).unsqueeze(0), dropout_p=drop_p ) gate = torch.sigmoid(self.output_gate).view(1, self.n_heads, 1, 1) out = out * gate out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) out = self.wo(out) return out, new_kv class SwiGLU(nn.Module): def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None: super().__init__() self.gate = nn.Linear(dim, hidden_dim, bias=False) self.up = nn.Linear(dim, hidden_dim, bias=False) self.down = nn.Linear(hidden_dim, dim, bias=False) self.drop = nn.Dropout(dropout) nn.init.normal_(self.gate.weight, std=dim**-0.5) nn.init.normal_(self.up.weight, std=dim**-0.5) nn.init.normal_(self.down.weight, std=hidden_dim**-0.5) def forward(self, x: torch.Tensor) -> torch.Tensor: h = F.silu(self.gate(x)) * self.up(x) out = self.down(h) if self.training and torch.is_grad_enabled(): out = self.drop(out) return out def loop_index_embedding(h: torch.Tensor, loop_t: int, loop_dim: int, theta: float = 10000.0) -> torch.Tensor: if loop_dim <= 0: return h loop_dim = min(loop_dim, h.shape[-1]) if loop_dim % 2 == 1: loop_dim -= 1 if loop_dim <= 0: return h inv_freq = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim)) phase = torch.tensor(float(loop_t), device=h.device, dtype=h.dtype) * inv_freq loop_embed = torch.cat([phase.sin(), phase.cos()], dim=0).view(1, 1, loop_dim) out = h.clone() out[..., :loop_dim] = out[..., :loop_dim] + loop_embed return out class DepthLoRAAdapter(nn.Module): def __init__(self, dim: int, rank: int, max_loops: int) -> None: super().__init__() self.rank = max(0, rank) if self.rank <= 0: self.down = None self.B = None self.scale = None return self.down = nn.Linear(dim, self.rank, bias=False) self.B = nn.Parameter(torch.randn(self.rank, dim) * 0.02) self.scale = nn.Embedding(max(1, max_loops), self.rank) nn.init.zeros_(self.scale.weight) def forward(self, x: torch.Tensor, loop_t: int) -> torch.Tensor: if self.rank <= 0 or self.down is None or self.B is None or self.scale is None: return torch.zeros_like(x) t_idx = min(loop_t, self.scale.num_embeddings - 1) scale = self.scale(torch.tensor(t_idx, device=x.device)) return (self.down(x) * scale) @ self.B class StableRecurrentInjection(nn.Module): def __init__(self, dim: int) -> None: super().__init__() self.log_A = nn.Parameter(torch.full((dim,), -2.0)) self.log_dt = nn.Parameter(torch.full((dim,), -2.0)) self.input_gate = nn.Parameter(torch.zeros(dim)) def forward(self, h: torch.Tensor, e: torch.Tensor, transformer_out: torch.Tensor) -> torch.Tensor: A = torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20))).view(1, 1, -1) B = torch.sigmoid(self.input_gate).view(1, 1, -1) return A * h + B * e + transformer_out class AdaptiveHalting(nn.Module): def __init__(self, dim: int) -> None: super().__init__() self.halt = nn.Linear(dim, 1, bias=True) nn.init.zeros_(self.halt.weight) nn.init.constant_(self.halt.bias, -2.0) def forward(self, h: torch.Tensor) -> torch.Tensor: return torch.sigmoid(self.halt(h)).squeeze(-1) class EngramBlock(nn.Module): """DeepSeek Engram: conditional memory via O(1) hashed N-gram lookup. Stores common token-pair/triplet patterns in an embedding table and retrieves them with multi-head hashing. A context-aware gate (using the current hidden state as query) decides how much of the retrieved memory to inject into the residual stream. Reference: DeepSeek-AI, "Conditional Memory via Scalable Lookup" (2025). """ def __init__( self, dim: int, engram_dim: int, n_heads: int = 4, table_size: int = 8192, max_ngram: int = 3, ) -> None: super().__init__() self.dim = dim self.engram_dim = engram_dim self.n_heads = n_heads self.table_size = table_size self.max_ngram = max_ngram # One embedding table per (ngram_order, hash_head) self.embeddings = nn.ParameterDict() for n in range(2, max_ngram + 1): for k in range(n_heads): self.embeddings[f"{n}_{k}"] = nn.Parameter( torch.randn(table_size, engram_dim) * (engram_dim**-0.5) ) # Fixed hash parameters (non-learnable, deterministic) for n in range(2, max_ngram + 1): for k in range(n_heads): seed = int(hashlib.md5(f"engram_{n}_{k}".encode()).hexdigest()[:8], 16) rng = torch.Generator().manual_seed(seed) a = torch.randint(1, 2**31, (1,), generator=rng).item() b = torch.randint(0, 2**31, (1,), generator=rng).item() self.register_buffer( f"hash_a_{n}_{k}", torch.tensor(a), persistent=False ) self.register_buffer( f"hash_b_{n}_{k}", torch.tensor(b), persistent=False ) # Causal convolution over N-gram branch outputs (kernel=4, dilation=max_ngram) total_branch_dim = engram_dim * n_heads * (max_ngram - 1) self.branch_conv = nn.Conv1d( total_branch_dim, total_branch_dim, kernel_size=4, dilation=max_ngram, padding=0, groups=total_branch_dim, bias=True, ) nn.init.zeros_(self.branch_conv.weight) nn.init.zeros_(self.branch_conv.bias) # Context-aware gating: hidden state as query, memory as key/value self.gate_query = nn.Linear(dim, engram_dim, bias=False) self.gate_key = nn.Linear(total_branch_dim, engram_dim, bias=False) self.gate_value = nn.Linear(total_branch_dim, dim, bias=False) self.gate_scale = engram_dim**-0.5 def _hash_ngram(self, token_ids: torch.Tensor, n: int, k: int) -> torch.Tensor: """Hash n-gram token sequences into table indices. Args: token_ids: (B, T) token IDs n: n-gram order (2 = bigram, 3 = trigram) k: hash head index Returns: indices: (B, T) integer indices into embedding table """ a = getattr(self, f"hash_a_{n}_{k}") b = getattr(self, f"hash_b_{n}_{k}") B, T = token_ids.shape # Pad left with zeros so every position has a valid n-gram padded = F.pad(token_ids, (n - 1, 0), value=0) # (B, T+n-1) # Polynomial rolling hash combined = torch.zeros(B, T, dtype=torch.long, device=token_ids.device) for i in range(n): combined = combined * 31 + padded[:, i : i + T].long() indices = ((a * combined) ^ b) % self.table_size return indices def forward( self, hidden: torch.Tensor, token_ids: Optional[torch.Tensor] = None ) -> torch.Tensor: """Forward pass. Args: hidden: (B, T, dim) current hidden state token_ids: (B, T) input token IDs for n-gram hashing. If None, uses argmax of hidden projections as proxy. Returns: output: (B, T, dim) memory injection for residual stream """ B, T, _ = hidden.shape if token_ids is None: # Fallback: derive pseudo-token-ids from hidden state token_ids = hidden.mean(dim=-1).long() % self.table_size # Retrieve and concatenate across n-gram orders and hash heads branch_outputs = [] for n in range(2, self.max_ngram + 1): for k in range(self.n_heads): indices = self._hash_ngram(token_ids, n, k) # (B, T) table = self.embeddings[f"{n}_{k}"] # (table_size, engram_dim) retrieved = table[indices] # (B, T, engram_dim) branch_outputs.append(retrieved) # (B, T, engram_dim * n_heads * (max_ngram - 1)) memory = torch.cat(branch_outputs, dim=-1) # Causal convolution over sequence dimension # Pad left for causality (kernel_size - 1 = 3) conv_in = memory.transpose(1, 2) # (B, C, T) conv_in = F.pad( conv_in, ((self.branch_conv.kernel_size[0] - 1) * self.branch_conv.dilation[0], 0), ) conv_out = self.branch_conv(conv_in) # (B, C, T) memory = conv_out.transpose(1, 2) # (B, T, C) # Context-aware gating query = self.gate_query(hidden) # (B, T, engram_dim) key = self.gate_key(memory) # (B, T, engram_dim) gate = torch.sigmoid( (query * key).sum(dim=-1, keepdim=True) * self.gate_scale ) # (B, T, 1) value = self.gate_value(memory) # (B, T, dim) return gate * value class SleepGate(nn.Module): """Persistent memory + periodic consolidation gate.""" def __init__( self, dim: int, cap: int = 128, n_heads: int = 4, retention_enabled: bool = True, retention_hidden: int = 0, ) -> None: super().__init__() self.dim = dim self.cap = cap self.n_heads = n_heads self.head_dim = dim // n_heads self.scale = self.head_dim ** -0.5 self.retention_enabled = retention_enabled self.register_buffer("mem_emb", torch.zeros(cap, dim, dtype=torch.bfloat16)) self.register_buffer("mem_age", torch.zeros(cap, dtype=torch.long)) self.register_buffer("mem_beta", torch.ones(cap, dtype=torch.float32)) self.register_buffer("mem_count", torch.zeros((), dtype=torch.long)) self.register_buffer("mem_head", torch.zeros((), dtype=torch.long)) self.register_buffer("global_step", torch.zeros((), dtype=torch.long)) self.q_proj = nn.Linear(dim, dim, bias=False) self.k_proj = nn.Linear(dim, dim, bias=False) self.v_proj = nn.Linear(dim, dim, bias=False) self.o_proj = nn.Linear(dim, dim, bias=False) nn.init.zeros_(self.o_proj.weight) self.gate_scale = nn.Parameter(torch.zeros(())) if retention_enabled: if retention_hidden > 0: self.retention_gate: Optional[nn.Module] = nn.Sequential( nn.Linear(dim, retention_hidden, bias=False), nn.GELU(), nn.Linear(retention_hidden, 1, bias=True), ) nn.init.constant_(self.retention_gate[-1].bias, 2.2) else: self.retention_gate = nn.Linear(dim, 1, bias=True) nn.init.constant_(self.retention_gate.bias, 2.2) else: self.retention_gate = None self._last_beta: Optional[torch.Tensor] = None def write(self, hidden: torch.Tensor) -> None: B, T, _ = hidden.shape tail_full = hidden[:, max(0, T - 16):, :].float().mean(dim=1) if self.retention_gate is not None: beta_live = torch.sigmoid(self.retention_gate(tail_full).squeeze(-1)) self._last_beta = beta_live if self.training else None beta_store = beta_live.detach().float() else: self._last_beta = None beta_store = torch.ones(B, device=hidden.device, dtype=torch.float32) tail = tail_full.to(self.mem_emb.dtype).detach() with torch.no_grad(): head = int(self.mem_head.item()) count = int(self.mem_count.item()) step = int(self.global_step.item()) for b in range(B): self.mem_emb[head] = tail[b] self.mem_age[head] = step self.mem_beta[head] = beta_store[b] head = (head + 1) % self.cap if count < self.cap: count += 1 self.mem_head.fill_(head) self.mem_count.fill_(count) def read(self, x: torch.Tensor) -> torch.Tensor: count = int(self.mem_count.item()) if count == 0: return torch.zeros_like(x) B, T, D = x.shape mem = self.mem_emb[:count].clone().to(x.dtype) q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) k = self.k_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1) v = self.v_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1) attn = torch.einsum("bhtd,hmd->bhtm", q, k) * self.scale attn = F.softmax(attn, dim=-1) if self.retention_enabled: step = int(self.global_step.item()) ages = self.mem_age[:count].to(x.device) delta = (step - ages).clamp(min=0).to(x.dtype) betas = self.mem_beta[:count].to(x.dtype).clamp(min=1e-6, max=1.0) weights = betas.pow(delta) attn = attn * weights.view(1, 1, 1, count) attn = attn / attn.sum(dim=-1, keepdim=True).clamp_min(1e-9) out = torch.einsum("bhtm,hmd->bhtd", attn, v) out = out.transpose(1, 2).contiguous().view(B, T, D) out = self.o_proj(out) return torch.sigmoid(self.gate_scale) * out @torch.no_grad() def reset(self) -> None: self.mem_emb.zero_() self.mem_age.zero_() self.mem_beta.fill_(1.0) self.mem_count.zero_() self.mem_head.zero_() self.global_step.zero_() self._last_beta = None def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor: M = torch.exp(logits.clamp(-10, 10)) for _ in range(n_iters): M = M / M.sum(dim=-1, keepdim=True).clamp(min=1e-10) M = M / M.sum(dim=-2, keepdim=True).clamp(min=1e-10) return M class ManifoldHyperConnection(nn.Module): def __init__(self, dim: int, expansion: int = 2) -> None: super().__init__() self.dim = dim self.expansion = expansion n = expansion self.expand_fn = "duplicate" self.collapse_fn = "mean" self.bias_pre = nn.Parameter(torch.zeros(1, n)) self.bias_post = nn.Parameter(torch.zeros(1, n)) self.bias_res = nn.Parameter(torch.zeros(n, n)) self.theta_pre = nn.Linear(n * dim, n, bias=False) self.theta_post = nn.Linear(n * dim, n, bias=False) self.theta_res = nn.Linear(n * dim, n * n, bias=False) self.alpha_pre = nn.Parameter(torch.tensor(0.0)) self.alpha_post = nn.Parameter(torch.tensor(0.0)) self.alpha_res = nn.Parameter(torch.tensor(0.0)) def _compute_mappings( self, x_expanded: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: B, T, _ = x_expanded.shape n = self.expansion x_norm = F.rms_norm(x_expanded, [x_expanded.shape[-1]]) d_pre = torch.tanh(self.theta_pre(x_norm)) d_post = torch.tanh(self.theta_post(x_norm)) d_res = self.theta_res(x_norm) H_pre_raw = torch.sigmoid(self.alpha_pre * d_pre + self.bias_pre) H_post_raw = 2.0 * torch.sigmoid(self.alpha_post * d_post + self.bias_post) H_res_raw = (self.alpha_res * d_res + self.bias_res.reshape(1, 1, -1)).reshape( B, T, n, n ) H_res = _sinkhorn_knopp(H_res_raw) return H_pre_raw.unsqueeze(-2), H_post_raw.unsqueeze(-2), H_res def expand_stream(self, x: torch.Tensor) -> torch.Tensor: return x.repeat(1, 1, self.expansion) def collapse_stream(self, x_expanded: torch.Tensor) -> torch.Tensor: B, T, _ = x_expanded.shape n = self.expansion C = self.dim return x_expanded.view(B, T, n, C).mean(dim=-2) def pre_mix(self, x_expanded: torch.Tensor, H_pre: torch.Tensor) -> torch.Tensor: B, T, _ = x_expanded.shape n = self.expansion x_streams = x_expanded.view(B, T, n, self.dim) return (H_pre @ x_streams).squeeze(-2) def post_res_mix( self, layer_output: torch.Tensor, x_expanded: torch.Tensor, H_post: torch.Tensor, H_res: torch.Tensor, ) -> torch.Tensor: B, T, _ = x_expanded.shape n = self.expansion C = self.dim x_streams = x_expanded.view(B, T, n, C) mixed = torch.matmul(H_res, x_streams) post_out = torch.matmul(H_post.transpose(-2, -1), layer_output.unsqueeze(-2)) result = mixed + post_out return result.reshape(B, T, n * C) class TransformerBlock(nn.Module): def __init__( self, dim: int, n_heads: int, n_kv_heads: int, head_dim: int, ffn_dim: int, dropout: float, sliding_window: int, rope_fraction: float, engram_dim: int = 0, engram_heads: int = 4, engram_table_size: int = 8192, engram_max_ngram: int = 3, mhc_expansion: int = 1, ) -> None: super().__init__() self.norm1 = RMSNorm(dim) self.attn = CausalSelfAttention( dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, dropout=dropout, sliding_window=sliding_window, rope_fraction=rope_fraction, ) self.norm2 = RMSNorm(dim) self.ffn = SwiGLU(dim, ffn_dim, dropout) self.use_engram = engram_dim > 0 if self.use_engram: self.engram = EngramBlock( dim=dim, engram_dim=engram_dim, n_heads=engram_heads, table_size=engram_table_size, max_ngram=engram_max_ngram, ) self.engram_norm = RMSNorm(dim) self.use_mhc = mhc_expansion > 1 if self.use_mhc: self.mhc_attn = ManifoldHyperConnection(dim, expansion=mhc_expansion) self.mhc_ffn = ManifoldHyperConnection(dim, expansion=mhc_expansion) def forward( self, x: torch.Tensor, is_global: bool, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, token_ids: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: if self.use_mhc: x_exp = self.mhc_attn.expand_stream(x) H_pre, H_post, H_res = self.mhc_attn._compute_mappings(x_exp) attn_in = self.mhc_attn.pre_mix(x_exp, H_pre) attn_out, new_kv = self.attn( self.norm1(attn_in), is_global, past_kv, use_cache ) x_exp = self.mhc_attn.post_res_mix(attn_out, x_exp, H_post, H_res) if self.use_engram: collapsed = self.mhc_attn.collapse_stream(x_exp) collapsed = collapsed + self.engram( self.engram_norm(collapsed), token_ids=token_ids ) x_exp = self.mhc_attn.expand_stream(collapsed) H_pre2, H_post2, H_res2 = self.mhc_ffn._compute_mappings(x_exp) ffn_in = self.mhc_ffn.pre_mix(x_exp, H_pre2) ffn_out = self.ffn(self.norm2(ffn_in)) x_exp = self.mhc_ffn.post_res_mix(ffn_out, x_exp, H_post2, H_res2) x = self.mhc_attn.collapse_stream(x_exp) else: attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache) x = x + attn_out if self.use_engram: x = x + self.engram(self.engram_norm(x), token_ids=token_ids) x = x + self.ffn(self.norm2(x)) return x, new_kv class RecurrentDepthBlock(nn.Module): def __init__( self, dim: int, n_heads: int, n_kv_heads: int, head_dim: int, ffn_dim: int, dropout: float, sliding_window: int, rope_fraction: float, n_loops: int, act_threshold: float, lora_rank: int, loop_embed_dim: int, ) -> None: super().__init__() self.n_loops = max(1, n_loops) self.act_threshold = act_threshold self.loop_embed_dim = max(0, loop_embed_dim) self.norm = RMSNorm(dim) self.block = TransformerBlock( dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window, rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1, ) self.injection = StableRecurrentInjection(dim) self.act = AdaptiveHalting(dim) self.lora = DepthLoRAAdapter(dim, lora_rank, self.n_loops) def forward( self, h: torch.Tensor, e: torch.Tensor, token_ids: Optional[torch.Tensor] = None, past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None, use_cache: bool = False, n_loops: Optional[int] = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: loops = max(1, n_loops or self.n_loops) B, T, _ = h.shape halted = torch.zeros(B, T, device=h.device, dtype=torch.bool) cumulative_p = torch.zeros(B, T, device=h.device, dtype=h.dtype) output = torch.zeros_like(h) new_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None current = h final_halt = None for t in range(loops): h_loop = loop_index_embedding(current, t, self.loop_embed_dim) combined = self.norm(h_loop + e) past_kv = None if past_key_values is not None and t < len(past_key_values): past_kv = past_key_values[t] trans_out, layer_kv = self.block(combined, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=token_ids) trans_out = trans_out + self.lora(trans_out, t) next_h = self.injection(current, e, trans_out) p = self.act(next_h) p = p * (~halted).to(p.dtype) final_halt = p should_halt = (~halted) & ((cumulative_p + p) >= self.act_threshold) update_weight = torch.where(should_halt, (1.0 - cumulative_p).clamp(min=0.0), p) output = output + next_h * update_weight.unsqueeze(-1) cumulative_p = cumulative_p + update_weight current = torch.where(halted.unsqueeze(-1), current, next_h) halted = halted | should_halt if new_past is not None: new_past.append(layer_kv) if not use_cache and bool(halted.all()): break remainder = (1.0 - cumulative_p).clamp(min=0.0) output = output + current * remainder.unsqueeze(-1) aux: Dict[str, torch.Tensor] = {} if final_halt is not None: aux["recurrent_halt_mean"] = final_halt.mean() return output, aux, new_past class TinyMemoryLM(nn.Module): def __init__( self, vocab_size: int, dim: int, n_unique_layers: int, n_logical_layers: int, n_heads: int, n_kv_heads: int, ffn_dim: int, dropout: float, mtp_horizons: Sequence[int], grad_checkpoint: bool, sliding_window: int = 512, rope_fraction: float = 0.5, embed_scale: bool = True, engram_dim: int = 0, engram_heads: int = 4, engram_table_size: int = 8192, engram_max_ngram: int = 3, mhc_expansion: int = 1, sleep_gate_cap: int = 0, sleep_gate_heads: int = 4, sleep_retention_enabled: bool = True, sleep_retention_hidden: int = 0, latent_think_layers: int = 0, prelude_layers: int = 0, coda_layers: int = 0, recurrent_loops: int = 0, recurrent_act_threshold: float = 0.99, recurrent_lora_rank: int = 0, recurrent_loop_embed_dim: int = 0, ) -> None: super().__init__() self.dim = dim self.n_unique_layers = n_unique_layers self.n_logical_layers = n_logical_layers self.grad_checkpoint = grad_checkpoint self.embed_scale_factor = math.sqrt(dim) if embed_scale else 1.0 head_dim = dim // n_heads self.embed_tokens = nn.Embedding(vocab_size, dim) self.head = nn.Linear(dim, vocab_size, bias=False) self.head.weight = self.embed_tokens.weight self.output_bias = nn.Parameter(torch.zeros(vocab_size)) self.use_recurrent_depth = recurrent_loops > 0 self.prelude_layers = max(0, prelude_layers) self.coda_layers = max(0, coda_layers) self.recurrent_loops = max(0, recurrent_loops) self.blocks: Optional[nn.ModuleList] = None self.prelude: Optional[nn.ModuleList] = None self.recurrent: Optional[RecurrentDepthBlock] = None self.coda: Optional[nn.ModuleList] = None def _make_blocks(n: int) -> nn.ModuleList: return nn.ModuleList([ TransformerBlock( dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window, rope_fraction=rope_fraction, engram_dim=engram_dim, engram_heads=engram_heads, engram_table_size=engram_table_size, engram_max_ngram=engram_max_ngram, mhc_expansion=mhc_expansion, ) for _ in range(n) ]) if self.use_recurrent_depth: if self.prelude_layers > 0: self.prelude = _make_blocks(self.prelude_layers) self.recurrent = RecurrentDepthBlock( dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window, rope_fraction=rope_fraction, n_loops=self.recurrent_loops, act_threshold=recurrent_act_threshold, lora_rank=recurrent_lora_rank, loop_embed_dim=recurrent_loop_embed_dim or max(2, dim // 8), ) if self.coda_layers > 0: self.coda = _make_blocks(self.coda_layers) else: self.blocks = _make_blocks(max(1, n_unique_layers)) self.norm = RMSNorm(dim) self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1}) self.mtp_adapters = nn.ModuleDict( {str(h): nn.Linear(dim, dim, bias=False) for h in self.mtp_horizons} ) self.mtp_norms = nn.ModuleDict( {str(h): RMSNorm(dim) for h in self.mtp_horizons} ) res_scale = (2 * max(1, n_logical_layers)) ** -0.5 for group in (self.blocks, self.prelude, self.coda): if group is None: continue for block in group: block.attn.wo.weight.data.mul_(res_scale) block.ffn.down.weight.data.mul_(res_scale) if self.recurrent is not None: self.recurrent.block.attn.wo.weight.data.mul_(res_scale) self.recurrent.block.ffn.down.weight.data.mul_(res_scale) self.sleep_gate: Optional[SleepGate] = None if sleep_gate_cap > 0: self.sleep_gate = SleepGate( dim=dim, cap=sleep_gate_cap, n_heads=sleep_gate_heads, retention_enabled=sleep_retention_enabled, retention_hidden=sleep_retention_hidden, ) self.think_blocks: Optional[nn.ModuleList] = None self.think_norm: Optional[RMSNorm] = None if latent_think_layers > 0: self.think_blocks = nn.ModuleList([ TransformerBlock( dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, ffn_dim=ffn_dim, dropout=0.0, sliding_window=2048, rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1, ) for _ in range(latent_think_layers) ]) self.think_norm = RMSNorm(dim) def resize_token_embeddings(self, new_vocab_size: int) -> None: old_vocab_size = self.embed_tokens.num_embeddings if new_vocab_size == old_vocab_size: return device = self.embed_tokens.weight.device old_embed_weight = self.embed_tokens.weight.data.clone() self.embed_tokens = nn.Embedding(new_vocab_size, self.embed_tokens.embedding_dim).to(device) self.head = nn.Linear(self.embed_tokens.embedding_dim, new_vocab_size, bias=False).to(device) self.head.weight = self.embed_tokens.weight old_bias = self.output_bias.data.clone() self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device)) copy_size = min(old_vocab_size, new_vocab_size) self.output_bias.data[:copy_size] = old_bias[:copy_size] self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size] def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]: if self.blocks is None: return [] blocks_list = list(self.blocks) full_sequence = blocks_list + blocks_list return [(block, i) for i, block in enumerate(full_sequence[: self.n_logical_layers])] def forward( self, ids: torch.Tensor, use_cache: bool = False, past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None, return_hidden: bool = False, ) -> Tuple[torch.Tensor, Dict[int, torch.Tensor], Dict[str, torch.Tensor], Optional[torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: B, T = ids.shape x = self.embed_tokens(ids) * self.embed_scale_factor new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None aux: Dict[str, torch.Tensor] = {} if self.use_recurrent_depth: offset = 0 if self.prelude is not None: for block in self.prelude: past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids) if new_past_key_values is not None: new_past_key_values.append(layer_kv) offset += 1 encoded = x recurrent_past = past_key_values[offset: offset + self.recurrent_loops] if past_key_values is not None else None x, recurrent_aux, recurrent_kv = self.recurrent( x, encoded, token_ids=ids, past_key_values=recurrent_past, use_cache=use_cache, ) aux.update(recurrent_aux) if new_past_key_values is not None and recurrent_kv is not None: new_past_key_values.extend(recurrent_kv) offset += self.recurrent_loops if self.coda is not None: for block in self.coda: past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids) if new_past_key_values is not None: new_past_key_values.append(layer_kv) offset += 1 else: logical_layers = self._build_logical_layers() last_logical_idx = len(logical_layers) - 1 for layer_idx, (block, logical_idx) in enumerate(logical_layers): is_global = logical_idx % 2 == 0 or layer_idx == last_logical_idx past_kv = past_key_values[layer_idx] if past_key_values is not None and layer_idx < len(past_key_values) else None if self.grad_checkpoint and self.training and not use_cache: x, layer_kv = checkpoint(block, x, is_global, past_kv, use_cache, ids, use_reentrant=True) else: x, layer_kv = block(x, is_global, past_kv, use_cache, ids) if new_past_key_values is not None: new_past_key_values.append(layer_kv) x = self.norm(x) if self.sleep_gate is not None: x = x + self.sleep_gate.read(x) if self.training: self.sleep_gate.write(x) if self.think_blocks is not None: for think_block in self.think_blocks: x, _ = think_block(x, is_global=True) x = self.think_norm(x) h_out = x if return_hidden else None logits = self.head(x) if self.embed_scale_factor != 1.0: logits = logits / self.embed_scale_factor logits = logits + self.output_bias mtp: Dict[int, torch.Tensor] = {} if self.mtp_horizons and self.training: for horizon in self.mtp_horizons: if horizon > 1 and horizon <= T - 1: shifted_h = x[:, :-horizon, :] adapted_h = self.mtp_adapters[str(horizon)](shifted_h) adapted_h = self.mtp_norms[str(horizon)](adapted_h) mtp_logits = self.head(adapted_h) if self.embed_scale_factor != 1.0: mtp_logits = mtp_logits / self.embed_scale_factor mtp_logits = mtp_logits + self.output_bias mtp[horizon] = mtp_logits return logits, mtp, aux, h_out, new_past_key_values # --------------------------------------------------------------------------- # Generation # --------------------------------------------------------------------------- def build_stop_token_ids(tokenizer: WordTokenizer) -> set: stop_tokens = {tokenizer.eos_id} for tok in ("<|user|>", "<|system|>", "<|assistant|>"): tid = tokenizer.token_to_id.get(tok) if tid is not None: stop_tokens.add(int(tid)) return stop_tokens def apply_no_repeat_ngram( logits: torch.Tensor, token_history: Sequence[int], ngram_size: int, ) -> torch.Tensor: if ngram_size <= 1 or len(token_history) < max(0, ngram_size - 1): return logits prefix = tuple(token_history[-(ngram_size - 1) :]) if ngram_size > 1 else tuple() banned: set = set() for i in range(len(token_history) - ngram_size + 1): if tuple(token_history[i : i + ngram_size - 1]) == prefix: banned.add(int(token_history[i + ngram_size - 1])) if not banned: return logits out = logits.clone() banned_ids = torch.tensor(sorted(banned), device=logits.device, dtype=torch.long) out[banned_ids] = float("-inf") return out def apply_loop_penalty( logits: torch.Tensor, tokenizer: WordTokenizer, generated_text: str, penalty: float = 5.0, ) -> torch.Tensor: """Detect repeated substring loops and penalise continuation tokens.""" if len(generated_text) < 16: return logits out = logits.clone() for span_len in [24, 16, 12, 8]: if len(generated_text) < span_len * 2: continue suffix = generated_text[-span_len:] prev = generated_text[:-span_len].rfind(suffix) if prev == -1: continue next_pos = prev + span_len if next_pos < len(generated_text): next_char = generated_text[next_pos] tid = tokenizer.token_to_id.get(next_char) if tid is not None: out[tid] -= penalty break return out def apply_min_p(logits: torch.Tensor, min_p: float) -> torch.Tensor: """Filter tokens below min_p fraction of the top token probability.""" if min_p <= 0.0: return logits probs = torch.softmax(logits, dim=-1) threshold = probs.max() * min_p out = logits.clone() out[probs < threshold] = float("-inf") return out def generate( model: TinyMemoryLM, tokenizer: WordTokenizer, prompt: str, max_new_tokens: int = 256, temperature: float = 0.8, top_k: int = 16, top_p: float = 0.95, repetition_penalty: float = 1.0, device: str = "cuda", sft_mode: bool = True, stream: bool = True, no_repeat_ngram_size: int = 0, context_window: int = 2048, logit_soft_cap: float = 15.0, min_p: float = 0.05, loop_penalty: float = 5.0, ) -> str: if sft_mode: full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n" else: full_prompt = prompt input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False) input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device) visible_tokens: List[str] = [] stop_token_ids = build_stop_token_ids(tokenizer) generated_text = "" generated_ids: List[int] = [] # Full history (prompt + generated) for ngram blocking — prevents echoing prompt full_ids_history: List[int] = list(input_ids) with torch.no_grad(): for _ in range(max_new_tokens): ctx_ids = ( input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t ) logits, *_ = model(ctx_ids) next_logits = logits[0, -1, :].clone() # Logit soft-capping (Gemma-style) — prevents overconfident collapse if logit_soft_cap > 0: next_logits = logit_soft_cap * torch.tanh(next_logits / logit_soft_cap) raw_next_logits = next_logits.clone() # Repetition penalty on previously generated tokens if repetition_penalty != 1.0 and generated_ids: for tok_id in set(generated_ids): if next_logits[tok_id] > 0: next_logits[tok_id] /= repetition_penalty else: next_logits[tok_id] *= repetition_penalty # No-repeat n-gram blocking on generated tokens only if no_repeat_ngram_size > 0 and generated_ids: next_logits = apply_no_repeat_ngram(next_logits, generated_ids, no_repeat_ngram_size) # Substring loop detection next_logits = apply_loop_penalty(next_logits, tokenizer, generated_text, penalty=loop_penalty) # Temperature scaling if temperature != 1.0: next_logits = next_logits / max(temperature, 1e-6) # Min-p filtering — remove tokens below min_p * max_prob if min_p > 0: next_logits = apply_min_p(next_logits, min_p) # Top-k filtering if top_k > 0: v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0))) next_logits[next_logits < v[-1]] = float("-inf") # Top-p (nucleus) filtering if 0.0 < top_p < 1.0: sorted_logits, sorted_indices = torch.sort(next_logits, descending=True) sorted_probs = torch.softmax(sorted_logits, dim=-1) cumulative_probs = torch.cumsum(sorted_probs, dim=-1) remove_mask = cumulative_probs > top_p remove_mask[0] = False indices_to_remove = sorted_indices[remove_mask] next_logits[indices_to_remove] = float("-inf") # Fallback if all tokens masked if not torch.isfinite(next_logits).any(): next_logits = raw_next_logits if temperature != 1.0: next_logits = next_logits / max(temperature, 1e-6) if temperature == 0: next_id = torch.argmax(next_logits).item() else: probs = torch.softmax(next_logits, dim=-1) next_id = torch.multinomial(probs, num_samples=1).item() if next_id in stop_token_ids: break token_str = ( tokenizer.id_to_token[next_id] if next_id < len(tokenizer.id_to_token) else "" ) generated_ids.append(next_id) full_ids_history.append(next_id) if token_str not in tokenizer.special: visible_tokens.append(token_str) generated_text += token_str if stream: print(token_str, end="", flush=True) input_ids_t = torch.cat( [input_ids_t, torch.tensor([[next_id]], device=device)], dim=1 ) if stream: print() return "".join(visible_tokens) # --------------------------------------------------------------------------- # Local model loading # --------------------------------------------------------------------------- def series_from_name(name: str) -> str | None: lower = (name or "").lower() if "haiku" in lower: return "Haiku" if "sonnet" in lower: return "Sonnet" if "opus" in lower: return "Opus" return None def series_config(series: str) -> dict[str, object]: return MODEL_SERIES.get(series.lower(), MODEL_SERIES["sonnet"]) def discover_models(runs_dir: Path) -> List[dict]: models = [] if not runs_dir.is_dir(): return models for child in sorted(runs_dir.iterdir()): if not child.is_dir(): continue tokenizer_path = child / "tokenizer.json" if not tokenizer_path.exists(): continue name = child.name series = None for ckpt_name in ("model.pt", "pretrain.pt"): ckpt_path = child / ckpt_name if ckpt_path.exists(): series = _fast_series_from_checkpoint(ckpt_path) break if series is None: series = series_from_name(name) or "Sonnet" found = False for ckpt_name in ("model.pt", "model_rep.pt", "pretrain.pt"): ckpt_path = child / ckpt_name if ckpt_path.exists(): models.append( { "name": name, "checkpoint": ckpt_name, "series": series, "model_path": ckpt_path, "tokenizer_path": tokenizer_path, } ) found = True if not found: step_ckpts = sorted( child.glob("checkpoint_step_*.pt"), key=lambda p: int(p.stem.rsplit("_", 1)[-1]), ) if step_ckpts: ckpt_path = step_ckpts[-1] models.append( { "name": name, "checkpoint": ckpt_path.name, "series": series, "model_path": ckpt_path, "tokenizer_path": tokenizer_path, } ) return models def _detect_engram(state_dict): for key in state_dict: if ".engram." in key: if ".embeddings." in key: return state_dict[key].shape[-1] return 0 def _detect_mhc(state_dict): for key, val in state_dict.items(): if ".mhc_attn.bias_pre" in key and val.dim() == 2: return val.shape[-1] # (1, expansion) return 1 def _detect_sleep_gate(state_dict) -> Tuple[int, int]: for key, val in state_dict.items(): if key == "sleep_gate.mem_emb" and val.dim() == 2: cap = val.shape[0] return cap, 4 return 0, 4 def _detect_latent_think(state_dict) -> int: indices = { int(k.split(".")[1]) for k in state_dict if k.startswith("think_blocks.") and k.split(".")[1].isdigit() } return max(indices) + 1 if indices else 0 def _detect_prelude_layers(state_dict) -> int: indices = { int(k.split(".")[1]) for k in state_dict if k.startswith("prelude.") and k.split(".")[1].isdigit() } return max(indices) + 1 if indices else 0 def _detect_coda_layers(state_dict) -> int: indices = { int(k.split(".")[1]) for k in state_dict if k.startswith("coda.") and k.split(".")[1].isdigit() } return max(indices) + 1 if indices else 0 def _detect_recurrent_loops(state_dict) -> int: if "recurrent.norm.weight" in state_dict or "recurrent.block.attn.wq.weight" in state_dict: if "recurrent.lora.scale.weight" in state_dict: return state_dict["recurrent.lora.scale.weight"].shape[0] return 1 return 0 def _detect_recurrent_lora_rank(state_dict) -> int: for key in ("recurrent.lora.B", "recurrent.lora.down.weight"): if key in state_dict: shape = state_dict[key].shape if len(shape) == 2: return int(shape[0]) return 0 def _infer_series_from_lora_rank(rank: int) -> str | None: if rank == 0: return None if rank <= 8: return "haiku" if rank <= 16: return "sonnet" return "opus" def _fast_series_from_checkpoint(ckpt_path: Path) -> str | None: try: cp = torch.load(ckpt_path, map_location="cpu", weights_only=False) sd = cp.get("model_state", cp.get("state_dict", {})) rank = 0 for key in ("recurrent.lora.B", "recurrent.lora.down.weight"): if key in sd: rank = int(sd[key].shape[0]) break if rank == 0: return None if rank <= 8: return "Haiku" if rank <= 16: return "Sonnet" return "Opus" except Exception: pass return None def _infer_arch_from_state_dict(state_dict, cfg): """Infer architecture hyper-parameters directly from checkpoint weights, falling back to *cfg* (series config) when a key is not found.""" overrides = {} has_prelude = any(k.startswith("prelude.") for k in state_dict) has_blocks = any(k.startswith("blocks.") for k in state_dict) has_recurrent = any(k.startswith("recurrent.") for k in state_dict) uses_recurrent_arch = has_prelude and has_recurrent and not has_blocks # dim from embed_tokens.weight [vocab, dim] if "embed_tokens.weight" in state_dict: overrides["dim"] = state_dict["embed_tokens.weight"].shape[1] if uses_recurrent_arch: if "prelude.0.ffn.gate.weight" in state_dict: overrides["ffn_dim"] = state_dict["prelude.0.ffn.gate.weight"].shape[0] overrides["n_unique_layers"] = 0 src = "prelude.0" else: if "blocks.0.ffn.gate.weight" in state_dict: overrides["ffn_dim"] = state_dict["blocks.0.ffn.gate.weight"].shape[0] block_ids = { int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.") and k.split(".")[1].isdigit() } if block_ids: overrides["n_unique_layers"] = max(block_ids) + 1 src = "blocks.0" dim = overrides.get("dim", int(cfg.get("dim", model_config.dim))) if f"{src}.attn.wq.weight" in state_dict: wq_rows = state_dict[f"{src}.attn.wq.weight"].shape[0] if f"{src}.attn.q_norm.weight" in state_dict: head_dim = state_dict[f"{src}.attn.q_norm.weight"].shape[0] overrides["n_heads"] = wq_rows // head_dim if f"{src}.attn.wk.weight" in state_dict: wk_rows = state_dict[f"{src}.attn.wk.weight"].shape[0] if f"{src}.attn.k_norm.weight" in state_dict: head_dim = state_dict[f"{src}.attn.k_norm.weight"].shape[0] overrides["n_kv_heads"] = wk_rows // head_dim # engram params for key, val in state_dict.items(): if ".engram.embeddings." in key and key.endswith("_0") and val.dim() == 2: overrides["engram_table_size"] = val.shape[0] overrides["engram_dim"] = val.shape[1] break engram_dim = overrides.get("engram_dim", int(cfg.get("engram_dim", 0))) engram_max_ngram = int(cfg.get("engram_max_ngram", 2)) if engram_dim > 0: for key, val in state_dict.items(): if ".engram.branch_conv.weight" in key and val.dim() == 3: total_branch_dim = val.shape[0] denom = engram_dim * (engram_max_ngram - 1) if denom > 0 and total_branch_dim % denom == 0: overrides["engram_heads"] = total_branch_dim // denom break merged = dict(cfg) merged.update(overrides) return merged def load_local_model(model_path: Path, tokenizer_path: Path, series: str) -> dict: tokenizer = WordTokenizer.load(tokenizer_path) ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False) cfg = series_config(series) vocab_size = int(ckpt.get("vocab_size", tokenizer.vocab_size)) state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt cfg = _infer_arch_from_state_dict(state_dict, cfg) engram_dim = int(cfg.get("engram_dim", 0)) if _detect_engram(state_dict) == 0: engram_dim = 0 mhc_expansion = _detect_mhc(state_dict) if mhc_expansion == 1: mhc_expansion = int(cfg.get("mhc_expansion", 1)) ckpt_sleep_cap, ckpt_sleep_heads = _detect_sleep_gate(state_dict) sleep_gate_cap = ckpt_sleep_cap if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_cap", 0)) sleep_gate_heads = ckpt_sleep_heads if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_heads", 4)) sleep_retention_enabled = bool(cfg.get("sleep_retention_enabled", True)) sleep_retention_hidden = int(cfg.get("sleep_retention_hidden", 0)) latent_think_layers = _detect_latent_think(state_dict) if latent_think_layers == 0: latent_think_layers = int(cfg.get("latent_think_layers", 0)) prelude_layers = _detect_prelude_layers(state_dict) coda_layers = _detect_coda_layers(state_dict) recurrent_loops = _detect_recurrent_loops(state_dict) ckpt_lora_rank = _detect_recurrent_lora_rank(state_dict) if ckpt_lora_rank > 0: inferred_series = _infer_series_from_lora_rank(ckpt_lora_rank) if inferred_series and inferred_series != series.lower(): series = inferred_series.capitalize() cfg = series_config(series) recurrent_lora_rank = ckpt_lora_rank else: recurrent_lora_rank = int(cfg.get("recurrent_lora_rank", 0)) recurrent_act_threshold = float(cfg.get("recurrent_act_threshold", 0.99)) recurrent_loop_embed_dim = int(cfg.get("recurrent_loop_embed_dim", 0)) n_unique = int(cfg.get("n_unique_layers", model_config.n_unique_layers)) model = TinyMemoryLM( vocab_size=vocab_size, dim=int(cfg.get("dim", model_config.dim)), n_unique_layers=n_unique, n_logical_layers=int(cfg.get("n_logical_layers", model_config.n_logical_layers)), n_heads=int(cfg.get("n_heads", model_config.n_heads)), n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)), ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)), dropout=float(cfg.get("dropout", model_config.dropout)), mtp_horizons=tuple(int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)), grad_checkpoint=False, sliding_window=int(cfg.get("sliding_window_size", getattr(model_config, "sliding_window_size", 512))), rope_fraction=float(cfg.get("rope_fraction", getattr(model_config, "rope_fraction", 0.25))), embed_scale=bool(cfg.get("embed_scale", getattr(model_config, "embed_scale", True))), engram_dim=engram_dim, engram_heads=int(cfg.get("engram_heads", 4)), engram_table_size=int(cfg.get("engram_table_size", 8192)), engram_max_ngram=int(cfg.get("engram_max_ngram", 3)), mhc_expansion=mhc_expansion, sleep_gate_cap=sleep_gate_cap, sleep_gate_heads=sleep_gate_heads, sleep_retention_enabled=sleep_retention_enabled, sleep_retention_hidden=sleep_retention_hidden, latent_think_layers=latent_think_layers, prelude_layers=prelude_layers, coda_layers=coda_layers, recurrent_loops=recurrent_loops, recurrent_act_threshold=recurrent_act_threshold, recurrent_lora_rank=recurrent_lora_rank, recurrent_loop_embed_dim=recurrent_loop_embed_dim, ) model.load_state_dict(state_dict, strict=False) model.eval() if tokenizer.vocab_size > vocab_size: model.resize_token_embeddings(tokenizer.vocab_size) device = "cuda" if torch.cuda.is_available() else "cpu" model = model.to(device) return { "model": model, "tokenizer": tokenizer, "device": device, "series": series, "sft_mode": ckpt.get("sft_mode", None), "phase": ckpt.get("phase", None), } # --------------------------------------------------------------------------- # HuggingFace Model Download & Loading # --------------------------------------------------------------------------- def download_huggingface_model(hf_id: str, cache_dir: Path) -> dict: try: from huggingface_hub import snapshot_download except ImportError: print("huggingface_hub not installed. Install with: pip install huggingface_hub") sys.exit(1) print(f"Downloading {hf_id}...") try: local_dir = Path(snapshot_download(repo_id=hf_id, cache_dir=str(cache_dir))) except Exception as e: print(f"Failed to download {hf_id}: {e}") return None print(f"Using cached {hf_id} from {local_dir}") # Check common subdirectory names: "models/", "model/" if (local_dir / "models").exists(): model_dir = local_dir / "models" elif (local_dir / "model").exists(): model_dir = local_dir / "model" else: model_dir = local_dir model_path = model_dir / "model.pt" pretrain_path = model_dir / "pretrain.pt" tokenizer_path = model_dir / "tokenizer.json" ckpt_path = None for p in [model_path, pretrain_path]: if p.exists(): ckpt_path = p break if ckpt_path is None or not tokenizer_path.exists(): print(f"Missing model files in {model_dir}") print(f" model.pt exists: {model_path.exists()}") print(f" pretrain.pt exists: {pretrain_path.exists()}") print(f" tokenizer.json exists: {tokenizer_path.exists()}") return None return { "model_path": ckpt_path, "tokenizer_path": tokenizer_path, "model_name": ckpt_path.stem, } def load_huggingface_model(hf_id: str, cache_dir: Path) -> dict: files = download_huggingface_model(hf_id, cache_dir) if files is None: return None return load_local_model(files["model_path"], files["tokenizer_path"], "Haiku") # --------------------------------------------------------------------------- # Compare All Models # --------------------------------------------------------------------------- _hf_model_cache: Dict[str, dict] = {} def prefetch_huggingface_models() -> None: root = Path(__file__).resolve().parent cache_dir = root / "cache" / "huggingface" cache_dir.mkdir(parents=True, exist_ok=True) print("Downloading/preparing HuggingFace models...") for name, hf_id in HUGGINGFACE_MODELS.items(): print(f" {name}...") bundle = load_huggingface_model(hf_id, cache_dir) if bundle: _hf_model_cache[name] = bundle print(f"Prepared {len(_hf_model_cache)} HuggingFace models") def compare_all_models(prompt: str, cfg: dict) -> None: root = Path(__file__).resolve().parent runs_dir = root / "runs" all_models = discover_models(runs_dir) is_pretrain = not cfg.get("sft_mode", True) local_models = [ m for m in all_models if ("pretrain" in m["checkpoint"]) == is_pretrain ] if not local_models: print("No models found matching mode.") return results: List[dict] = [] for m in local_models: print(f"\n{'='*60}") print(f"Loading local {m['name']}/{m['checkpoint']}...") try: bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) except Exception as e: print(f"Failed to load {m['name']}: {e}") continue model = bundle["model"] tokenizer = bundle["tokenizer"] device = bundle["device"] print(f"Generating on '{prompt}'...") output = generate( model=model, tokenizer=tokenizer, prompt=prompt, max_new_tokens=cfg["max_new_tokens"], temperature=cfg["temperature"], top_k=cfg["top_k"], top_p=cfg["top_p"], min_p=cfg["min_p"], no_repeat_ngram_size=cfg["no_repeat_ngram_size"], repetition_penalty=cfg["repetition_penalty"], logit_soft_cap=cfg["logit_soft_cap"], loop_penalty=cfg["loop_penalty"], device=str(device), sft_mode=cfg["sft_mode"], stream=True, context_window=cfg["context_window"], ) results.append({ "name": f"[LOCAL] {m['name']}/{m['checkpoint']}", "output": output, "device": device, }) for name, bundle in _hf_model_cache.items(): print(f"\n{'='*60}") print(f"Loading {name} (cached)...") model = bundle["model"] tokenizer = bundle["tokenizer"] device = bundle["device"] print(f"Generating on '{prompt}'...") output = generate( model=model, tokenizer=tokenizer, prompt=prompt, max_new_tokens=cfg["max_new_tokens"], temperature=cfg["temperature"], top_k=cfg["top_k"], top_p=cfg["top_p"], min_p=cfg["min_p"], no_repeat_ngram_size=cfg["no_repeat_ngram_size"], repetition_penalty=cfg["repetition_penalty"], logit_soft_cap=cfg["logit_soft_cap"], loop_penalty=cfg["loop_penalty"], device=str(device), sft_mode=cfg["sft_mode"], stream=True, context_window=cfg["context_window"], ) results.append({ "name": name, "output": output, "device": device, }) print(f"\n{'='*60}") print("=" * 60) print("SIDE-BY-SIDE COMPARISON") print("=" * 60) for r in results: print(f"\n--- {r['name']} ---") print(r["output"]) print(f"\n{'='*60}") # --------------------------------------------------------------------------- # Benchmark # --------------------------------------------------------------------------- BENCHMARKS = { "blimp": { "label": "BLiMP", "desc": "Grammaticality minimal pairs (67 paradigms). Accuracy = % grammatical < ungrammatical perplexity.", "hf_dataset": ("nyu-mll/blimp", None), "metric": "accuracy", }, "wikitext2": { "label": "WikiText-2", "desc": "LM perplexity on Wikipedia test split. Lower is better.", "hf_dataset": ("Salesforce/wikitext", "wikitext-2-raw-v1"), "metric": "perplexity", }, "arc_easy": { "label": "ARC-Easy", "desc": "Multiple-choice science QA (~2.4K). Perplexity-based answer selection.", "hf_dataset": ("allenai/ai2_arc", "ARC-Easy"), "metric": "accuracy", }, } def _score_text(model: TinyMemoryLM, tokenizer: WordTokenizer, text: str, device: str) -> float: ids = tokenizer.encode(text, add_bos=True, add_eos=False) if len(ids) < 2: return float("nan") ids_t = torch.tensor([ids], dtype=torch.long, device=device) with torch.no_grad(): logits, *_ = model(ids_t) log_probs = F.log_softmax(logits[0], dim=-1) targets = ids_t[0, 1:] nll = -log_probs[range(len(targets)), targets].mean().item() return nll def _score_completion(model: TinyMemoryLM, tokenizer: WordTokenizer, context: str, completion: str, device: str) -> float: full_ids = tokenizer.encode(context + completion, add_bos=True, add_eos=False) ctx_ids = tokenizer.encode(context, add_bos=True, add_eos=False) n_ctx = len(ctx_ids) n_ref = len(full_ids) - n_ctx if n_ref <= 0: return float("nan") ids_t = torch.tensor([full_ids], dtype=torch.long, device=device) with torch.no_grad(): logits, *_ = model(ids_t) log_probs = F.log_softmax(logits[0], dim=-1) targets = ids_t[0, 1:] ref_start = n_ctx - 1 ref_end = min(ref_start + n_ref, log_probs.shape[0]) if ref_start >= ref_end: return float("nan") nll = -log_probs[ref_start:ref_end][range(ref_end - ref_start), targets[ref_start:ref_end]].mean().item() return nll BLIMP_PARADIGMS = [ "adjunct_island", "anaphor_gender_agreement", "anaphor_number_agreement", "animate_subject_passive", "animate_subject_trans", "causative", "complex_NP_island", "coordinate_structure_constraint_complex_left_branch", "coordinate_structure_constraint_object_extraction", "determiner_noun_agreement_1", "determiner_noun_agreement_2", "determiner_noun_agreement_irregular_1", "determiner_noun_agreement_irregular_2", "determiner_noun_agreement_with_adj_2", "determiner_noun_agreement_with_adj_irregular_1", "determiner_noun_agreement_with_adj_irregular_2", "determiner_noun_agreement_with_adjective_1", "distractor_agreement_relational_noun", "distractor_agreement_relative_clause", "drop_argument", "ellipsis_n_bar_1", "ellipsis_n_bar_2", "existential_there_object_raising", "existential_there_quantifiers_1", "existential_there_quantifiers_2", "existential_there_subject_raising", "expletive_it_object_raising", "inchoative", "intransitive", "irregular_past_participle_adjectives", "irregular_past_participle_verbs", "irregular_plural_subject_verb_agreement_1", "irregular_plural_subject_verb_agreement_2", "left_branch_island_echo_question", "left_branch_island_simple_question", "matrix_question_npi_licensor_present", "npi_present_1", "npi_present_2", "only_npi_licensor_present", "only_npi_scope", "passive_1", "passive_2", "principle_A_c_command", "principle_A_case_1", "principle_A_case_2", "principle_A_domain_1", "principle_A_domain_2", "principle_A_domain_3", "principle_A_reconstruction", "regular_plural_subject_verb_agreement_1", "regular_plural_subject_verb_agreement_2", "sentential_negation_npi_licensor_present", "sentential_negation_npi_scope", "sentential_subject_island", "superlative_quantifiers_1", "superlative_quantifiers_2", "tough_vs_raising_1", "tough_vs_raising_2", "transitive", "wh_island", "wh_questions_object_gap", "wh_questions_subject_gap", "wh_questions_subject_gap_long_distance", "wh_vs_that_no_gap", "wh_vs_that_no_gap_long_distance", "wh_vs_that_with_gap", "wh_vs_that_with_gap_long_distance", ] def _run_blimp(model, tokenizer, device, n_samples: int = 200) -> Tuple[List[str], List[float]]: from datasets import load_dataset # type: ignore accuracies: List[float] = [] for paradigm in BLIMP_PARADIGMS: try: ds = load_dataset("nyu-mll/blimp", paradigm, split="train") except Exception as e: print(f" {paradigm}: skip ({e})") accuracies.append(float("nan")) continue items = list(ds)[:n_samples] correct = 0 for ex in items: good_nll = _score_text(model, tokenizer, ex["sentence_good"], device) bad_nll = _score_text(model, tokenizer, ex["sentence_bad"], device) if math.isnan(good_nll) or math.isnan(bad_nll): continue if good_nll < bad_nll: correct += 1 acc = correct / len(items) if items else float("nan") accuracies.append(acc) print(f" {paradigm:50s} acc={acc:.3f}") return BLIMP_PARADIGMS, accuracies def _run_wikitext2(model, tokenizer, device, chunk_chars: int = 512, max_chunks: int = 100) -> Tuple[List[str], List[float]]: from datasets import load_dataset # type: ignore ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") full_text = "\n".join(ex["text"] for ex in ds if ex["text"].strip()) chunks = [full_text[i:i + chunk_chars] for i in range(0, len(full_text), chunk_chars)] chunks = [c for c in chunks if len(c) > 20][:max_chunks] labels: List[str] = [] ppls: List[float] = [] for i, chunk in enumerate(chunks): nll = _score_text(model, tokenizer, chunk, device) ppl = math.exp(nll) if not math.isnan(nll) else float("nan") labels.append(f"chunk {i + 1}") ppls.append(ppl) if (i + 1) % 10 == 0: valid = [v for v in ppls if not math.isnan(v)] mean = sum(valid) / len(valid) if valid else float("nan") print(f" chunk {i + 1}/{len(chunks)} running mean ppl={mean:.2f}") return labels, ppls def _run_arc_easy(model, tokenizer, device, max_samples: int = 200) -> Tuple[List[str], List[float]]: from datasets import load_dataset # type: ignore ds = load_dataset("allenai/ai2_arc", "ARC-Easy", split="test") items = list(ds)[:max_samples] labels: List[str] = [] scores: List[float] = [] for i, ex in enumerate(items): question = ex["question"] choices = ex["choices"]["text"] choice_labels = ex["choices"]["label"] answer_key = ex["answerKey"] context = f"Question: {question}\nAnswer:" nlls = [_score_completion(model, tokenizer, context, f" {c}", device) for c in choices] if all(math.isnan(v) for v in nlls): scores.append(float("nan")) else: best_idx = min(range(len(nlls)), key=lambda j: nlls[j] if not math.isnan(nlls[j]) else float("inf")) predicted = choice_labels[best_idx] scores.append(1.0 if predicted == answer_key else 0.0) labels.append(f"Q{i + 1}") n_valid = sum(1 for s in scores if not math.isnan(s)) acc = sum(s for s in scores if not math.isnan(s)) / n_valid if n_valid else float("nan") print(f" {n_valid} questions evaluated, accuracy={acc:.3f}") return labels, scores def run_benchmark_mode() -> None: try: import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt except ImportError: print("matplotlib not installed. pip install matplotlib") return bench_keys = list(BENCHMARKS.keys()) print("\nBenchmarks:") for i, k in enumerate(bench_keys): b = BENCHMARKS[k] print(f" [{i + 1}] {b['label']} — {b['desc']}") print("Select benchmark [1]:", end=" ", flush=True) try: b_choice = input().strip() or "1" except (EOFError, KeyboardInterrupt): print() return if not (b_choice.isdigit() and 1 <= int(b_choice) <= len(bench_keys)): print("Invalid selection.") return bench_key = bench_keys[int(b_choice) - 1] bench = BENCHMARKS[bench_key] print(f"Benchmark: {bench['label']}") root = Path(__file__).resolve().parent runs_dir = root / "runs" all_models = discover_models(runs_dir) model_entries: List[dict] = [] for m in all_models: model_entries.append({"label": f"[LOCAL] {m['name']}/{m['checkpoint']}", "type": "local", "meta": m}) for hf_name, hf_id in HUGGINGFACE_MODELS.items(): model_entries.append({"label": f"[HF] {hf_name}", "type": "hf", "hf_id": hf_id, "hf_name": hf_name}) if not model_entries: print("No models found.") return print("\nAvailable models:") for i, e in enumerate(model_entries): print(f" [{i + 1}] {e['label']}") print(" [a] All models") print("Select models (comma-separated or 'a'):", end=" ", flush=True) try: raw = input().strip() except (EOFError, KeyboardInterrupt): print() return if raw.lower() == "a": selected = list(range(len(model_entries))) else: selected = [] for tok in raw.split(","): tok = tok.strip() if tok.isdigit() and 1 <= int(tok) <= len(model_entries): selected.append(int(tok) - 1) if not selected: print("No valid selection.") return all_results: List[dict] = [] shared_x_labels: Optional[List[str]] = None for idx in selected: entry = model_entries[idx] print(f"\n{'='*60}\nLoading {entry['label']}...") try: if entry["type"] == "local": m = entry["meta"] bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) else: bundle = load_huggingface_model(entry["hf_id"], root / ".hf_cache") except Exception as e: print(f" Failed: {e}") continue model = bundle["model"] tokenizer = bundle["tokenizer"] device = str(bundle["device"]) model.eval() if bench_key == "blimp": x_labels, y_vals = _run_blimp(model, tokenizer, device) elif bench_key == "wikitext2": x_labels, y_vals = _run_wikitext2(model, tokenizer, device) else: x_labels, y_vals = _run_arc_easy(model, tokenizer, device) if shared_x_labels is None: shared_x_labels = x_labels valid = [v for v in y_vals if not math.isnan(v)] summary = sum(valid) / len(valid) if valid else float("nan") all_results.append({"label": entry["label"], "y": y_vals, "summary": summary}) if not all_results or shared_x_labels is None: print("No results to plot.") return metric = bench["metric"] paired = sorted(zip([r["summary"] for r in all_results], [r["label"] for r in all_results]), reverse=(metric != "perplexity")) summaries, model_labels = zip(*paired) if paired else ([], []) n = len(summaries) colors = [plt.cm.RdYlGn(i / max(n - 1, 1)) for i in range(n)] fig, ax = plt.subplots(figsize=(max(6, n * 1.4), 6)) bars = ax.bar(range(n), summaries, color=colors, edgecolor="black") for bar, val in zip(bars, summaries): ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005, f"{val:.3f}", ha="center", va="bottom", fontsize=9, fontweight="bold") ylabel = "Mean Perplexity (↓ better)" if metric == "perplexity" else "Mean Accuracy (↑ better)" ax.set_ylabel(ylabel) ax.set_title(f"{bench['label']} Benchmark — Model Comparison") ax.set_xticks(range(n)) ax.set_xticklabels(model_labels, rotation=20, ha="right", fontsize=9) if metric == "accuracy": ax.set_ylim(0, 1.05) ax.grid(True, axis="y", alpha=0.3) plt.tight_layout() out_path = root / f"benchmark_{bench_key}.png" plt.savefig(str(out_path), dpi=150) print(f"\nChart saved to {out_path}") try: import subprocess subprocess.Popen(["xdg-open", str(out_path)]) except Exception: pass # --------------------------------------------------------------------------- # Interactive CLI # --------------------------------------------------------------------------- def _pick_series(detected: str) -> str: series_list = list(MODEL_SERIES.keys()) detected_lower = detected.lower() default_idx = next( (i + 1 for i, s in enumerate(series_list) if s == detected_lower), 1 ) # Skip selection if only one series available if len(series_list) == 1: return series_list[0].capitalize() print("Series:") for i, s in enumerate(series_list): marker = " (detected)" if s == detected_lower else "" print(f" [{i + 1}] {s.capitalize()}{marker}") while True: try: choice = input(f"Select series [{default_idx}]: ").strip() except (EOFError, KeyboardInterrupt): print() sys.exit(0) if not choice: choice = str(default_idx) if choice.isdigit() and 1 <= int(choice) <= len(series_list): return series_list[int(choice) - 1].capitalize() print(f"Enter a number 1-{len(series_list)}") def pick_model(runs_dir: Path) -> tuple[dict, str]: models = discover_models(runs_dir) if not models: print(f"No models found in {runs_dir}") print("Expected layout: runs//model.pt (or pretrain.pt) + tokenizer.json") sys.exit(1) if len(models) == 1: m = models[0] print(f"Loading {m['name']}/{m['checkpoint']} ({m['series']})...") bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) return bundle, m["checkpoint"] print("Available models:") for i, m in enumerate(models): print(f" [{i + 1}] {m['name']}/{m['checkpoint']} ({m['series']})") while True: try: choice = input("Select model [1]: ").strip() except (EOFError, KeyboardInterrupt): print() sys.exit(0) if not choice: choice = "1" if choice.isdigit() and 1 <= int(choice) <= len(models): m = models[int(choice) - 1] print(f"Loading {m['name']}/{m['checkpoint']} ({m['series']})...") bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) return bundle, m["checkpoint"] print(f"Enter a number 1-{len(models)}") # --------------------------------------------------------------------------- # Generation mode configs # --------------------------------------------------------------------------- MODES = { "chat-coherent": { "label": "Chat — Coherent", "desc": "structured, consistent, strong repetition control", "sft_mode": "chat", "temperature": 0.35, "top_k": 20, "top_p": 0.88, "min_p": 0.10, "no_repeat_ngram_size": 4, "repetition_penalty": 1.22, "logit_soft_cap": 20.0, "loop_penalty": 20.0, "max_new_tokens": 4096, "context_window": 2048, }, "chat-variants": { "label": "Chat — Variants", "desc": "creative, diverse, more surprising outputs", "sft_mode": "chat", "temperature": 0.65, "top_k": 60, "top_p": 0.92, "min_p": 0.05, "no_repeat_ngram_size": 4, "repetition_penalty": 1.12, "logit_soft_cap": 20.0, "loop_penalty": 14.0, "max_new_tokens": 4096, "context_window": 2048, }, "pretrain-coherent": { "label": "Pretrain — Coherent", "desc": "grounded continuation, low temperature, tight sampling", "sft_mode": False, "temperature": 0.3, "top_k": 20, "top_p": 0.85, "min_p": 0.10, "no_repeat_ngram_size": 4, "repetition_penalty": 1.2, "logit_soft_cap": 20.0, "loop_penalty": 20.0, "max_new_tokens": 4096, "context_window": 2048, }, "pretrain-variants": { "label": "Pretrain — Variants", "desc": "free-form continuation, higher temperature, more exploration", "sft_mode": False, "temperature": 0.7, "top_k": 60, "top_p": 0.93, "min_p": 0.04, "no_repeat_ngram_size": 4, "repetition_penalty": 1.12, "logit_soft_cap": 20.0, "loop_penalty": 12.0, "max_new_tokens": 4096, "context_window": 2048, }, } _MODE_LIST = list(MODES.keys()) def pick_mode(is_pretrain: bool) -> dict: """Prompt the user to choose a generation mode. Returns a config dict.""" # Filter to relevant modes based on checkpoint type candidates = [k for k in _MODE_LIST if ("pretrain" in k) == is_pretrain] print("\nGeneration mode:") for i, key in enumerate(candidates): cfg = MODES[key] print(f" [{i + 1}] {cfg['label']} — {cfg['desc']}") while True: try: choice = input("Select mode [1]: ").strip() except (EOFError, KeyboardInterrupt): print() sys.exit(0) if not choice: choice = "1" if choice.isdigit() and 1 <= int(choice) <= len(candidates): key = candidates[int(choice) - 1] cfg = MODES[key] print(f"Mode: {cfg['label']}") return cfg print(f"Enter a number 1-{len(candidates)}") def _run_loop(bundle: dict, cfg: dict) -> None: model = bundle["model"] tokenizer = bundle["tokenizer"] device = bundle["device"] sft = cfg["sft_mode"] prompt_label = "You" if sft else "Prompt" print(f"\nModel ready on {device}. Type your message, or /quit to exit.") print(f" temp={cfg['temperature']} top_k={cfg['top_k']} top_p={cfg['top_p']}") print(f" min_p={cfg['min_p']} ng={cfg['no_repeat_ngram_size']} rp={cfg['repetition_penalty']}") print(f" cap={cfg['logit_soft_cap']} loop_penalty={cfg['loop_penalty']}\n") while True: try: prompt = input(f"{prompt_label}: ").strip() except (EOFError, KeyboardInterrupt): print() break if not prompt: continue if prompt in ("/quit", "/exit", "/q"): break if prompt == "/help": print("Commands: /quit /exit /q /help /mode") if sft: print("Anything else is sent as a chat prompt.") else: print("Anything else is sent as a raw continuation prompt.") continue if prompt == "/mode": print(f"Current: {cfg['label']} — {cfg['desc']}") continue print("AI: ", end="", flush=True) generate( model=model, tokenizer=tokenizer, prompt=prompt, max_new_tokens=cfg["max_new_tokens"], temperature=cfg["temperature"], top_k=cfg["top_k"], top_p=cfg["top_p"], min_p=cfg["min_p"], no_repeat_ngram_size=cfg["no_repeat_ngram_size"], repetition_penalty=cfg["repetition_penalty"], logit_soft_cap=cfg["logit_soft_cap"], loop_penalty=cfg["loop_penalty"], device=str(device), sft_mode=cfg["sft_mode"], stream=True, context_window=cfg["context_window"], ) # --------------------------------------------------------------------------- # Dynamic collection discovery # --------------------------------------------------------------------------- _COLLECTION_SLUG = "CompactAI-O/tmlm-haiku-series" _AUTHOR = "CompactAI-O" _SEARCH = "TMLM-Haiku" _FALLBACK_COLLECTION = [ {"version": "TMLM-Haiku-2.3", "hf_id": "CompactAI-O/TMLM-Haiku-2.3"}, {"version": "TMLM-Haiku-2", "hf_id": "CompactAI-O/TMLM-Haiku-2"}, {"version": "TMLM-Haiku-1.3", "hf_id": "CompactAI-O/TMLM-Haiku-1.3"}, {"version": "TMLM-Haiku-1", "hf_id": "CompactAI-O/TMLM-Haiku-1"}, {"version": "Glint-1", "hf_id": "CompactAI-O/Glint-1"}, ] _EXTRA_REPOS = ["CompactAI-O/Glint-1"] def _probe_repo(hf_id: str) -> dict | None: """Return entry dict for one repo, or None if no usable checkpoints found.""" from huggingface_hub import list_repo_files try: files = set(list_repo_files(hf_id)) except Exception: return None # Detect which subdirectory holds the checkpoints subdir: str | None = None for candidate in ("models", "model"): if any(f.startswith(f"{candidate}/") for f in files): subdir = candidate break prefix = f"{subdir}/" if subdir else "" # Collect all .pt files in the checkpoint directory pt_files = sorted( f[len(prefix):] for f in files if f.startswith(prefix) and f.endswith(".pt") ) _LABELS = { "model.pt": ("Chat (SFT)", False), "model_rep.pt": ("Chat (anti-repetition)", False), "pretrain.pt": ("Pretrain (base)", True), } checkpoints = [] for fname in pt_files: label, is_pretrain = _LABELS.get(fname, (fname.removesuffix(".pt"), "pretrain" in fname)) checkpoints.append((label, fname, is_pretrain)) if not checkpoints: return None return { "version": hf_id.split("/")[-1], "hf_id": hf_id, "subdir": subdir, "checkpoints": checkpoints, "desc": "", } def fetch_collection() -> list[dict]: """Query HF for all CompactAI-O TMLM-Haiku models, newest first.""" from huggingface_hub import HfApi print("Checking HuggingFace collection for available models...") try: api = HfApi() infos = list( api.list_models( author=_AUTHOR, search=_SEARCH, sort="lastModified", ) ) infos.sort(key=lambda m: getattr(m, "lastModified", ""), reverse=True) except Exception as exc: print(f" Could not reach HuggingFace ({exc}); using fallback list.") infos = [type("M", (), {"id": e["hf_id"]})() for e in _FALLBACK_COLLECTION] entries = [] seen_ids: set = set() for info in infos: repo_id = info.id if _SEARCH.lower() not in repo_id.lower(): continue entry = _probe_repo(repo_id) if entry: entries.append(entry) seen_ids.add(repo_id) # Always include extra repos (e.g. Glint-1) not caught by TMLM-Haiku search for repo_id in _EXTRA_REPOS: if repo_id not in seen_ids: entry = _probe_repo(repo_id) if entry: entries.append(entry) seen_ids.add(repo_id) if not entries: print(" No models found; using fallback list.") for fb in _FALLBACK_COLLECTION: e = _probe_repo(fb["hf_id"]) if e: entries.append(e) return entries # --------------------------------------------------------------------------- # Download helper # --------------------------------------------------------------------------- def _download_version(entry: dict, cache_dir: Path) -> Path: """Download full repo snapshot; return the directory containing model files.""" try: from huggingface_hub import snapshot_download except ImportError: print("huggingface_hub not installed. Run: pip install huggingface_hub") sys.exit(1) hf_id = entry["hf_id"] print(f"Fetching {hf_id} ...") try: local_dir = Path(snapshot_download(repo_id=hf_id, cache_dir=str(cache_dir))) except Exception as exc: print(f"Download failed: {exc}") sys.exit(1) subdir = entry.get("subdir") model_dir = (local_dir / subdir) if subdir else local_dir if not model_dir.exists(): # Fallback to root model_dir = local_dir return model_dir # --------------------------------------------------------------------------- # Selection prompts # --------------------------------------------------------------------------- def _prompt_int(prompt: str, lo: int, hi: int, default: int = 1) -> int: while True: try: raw = input(f"{prompt} [{default}]: ").strip() except (EOFError, KeyboardInterrupt): print() sys.exit(0) if not raw: return default if raw.isdigit() and lo <= int(raw) <= hi: return int(raw) print(f" Enter a number {lo}–{hi}.") def pick_version(collection: list[dict]) -> dict: print("\nTMLM-Haiku series (CompactAI-O)\n") for i, entry in enumerate(collection): desc = f" — {entry['desc']}" if entry["desc"] else "" print(f" [{i + 1}] {entry['version']}{desc}") idx = _prompt_int("Select version", 1, len(collection)) return collection[idx - 1] def pick_checkpoint(entry: dict) -> tuple[str, bool]: """Return (filename, is_pretrain).""" ckpts = entry["checkpoints"] if len(ckpts) == 1: label, fname, is_pretrain = ckpts[0] print(f" Using: {label} ({fname})") return fname, is_pretrain print(f"\nCheckpoints for {entry['version']}:") for i, (label, fname, _) in enumerate(ckpts): print(f" [{i + 1}] {label} ({fname})") idx = _prompt_int("Select checkpoint", 1, len(ckpts)) label, fname, is_pretrain = ckpts[idx - 1] return fname, is_pretrain # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> None: import argparse parser = argparse.ArgumentParser() parser.add_argument("--compare", "-c", action="store_true") parser.add_argument("--prompt", "-p", type=str, default="Hello") mode_group = parser.add_mutually_exclusive_group() mode_group.add_argument("--pretrain", action="store_true") mode_group.add_argument("--sft", action="store_true") args, _ = parser.parse_known_args() print("=" * 56) print(" CompactAI-O Interactive Chat") print(" Models: huggingface.co/CompactAI-O") print("=" * 56) if args.compare: prefetch_huggingface_models() cfg = pick_mode(is_pretrain=args.pretrain) prompt_label = "You" if cfg["sft_mode"] else "Prompt" while True: print(f"{prompt_label}:", end=" ", flush=True) prompt = sys.stdin.readline().strip() if not prompt or prompt in ("/quit", "/exit", "/q"): break compare_all_models(prompt, cfg) return collection = fetch_collection() if not collection: print("No models found. Check your internet connection.") sys.exit(1) entry = pick_version(collection) fname, is_pretrain = pick_checkpoint(entry) if args.pretrain: is_pretrain = True elif args.sft: is_pretrain = False root = Path(__file__).resolve().parent cache_dir = root / "cache" / "huggingface" cache_dir.mkdir(parents=True, exist_ok=True) model_dir = _download_version(entry, cache_dir) model_path = model_dir / fname tokenizer_path = model_dir / "tokenizer.json" if not model_path.exists(): print(f"File not found: {model_path}") sys.exit(1) if not tokenizer_path.exists(): print(f"Tokenizer not found: {tokenizer_path}") sys.exit(1) print(f"Loading {entry['version']} / {fname} ...") bundle = load_local_model(model_path, tokenizer_path, "Haiku") # Use checkpoint-embedded sft_mode/phase if available sft_mode_flag = bundle.get("sft_mode") phase_flag = bundle.get("phase") if sft_mode_flag is not None and not args.pretrain and not args.sft: is_pretrain = not sft_mode_flag elif phase_flag is not None and not args.pretrain and not args.sft: is_pretrain = phase_flag == "pretrain" print("\nChoose action:") print(" [1] Chat with this model") print(" [2] Compare ALL models (local + HuggingFace)") print(" [3] Run Benchmark (BLiMP / WikiText-2 / ARC-Easy)") print("Select [1]:", end=" ", flush=True) choice = sys.stdin.readline().strip() or "1" if choice == "1": cfg = pick_mode(is_pretrain) _run_loop(bundle, cfg) elif choice == "2": print("\nDownloading/preparing HuggingFace models...") prefetch_huggingface_models() cfg = pick_mode(is_pretrain) prompt_label = "You" if cfg["sft_mode"] else "Prompt" while True: print(f"{prompt_label}:", end=" ", flush=True) prompt = sys.stdin.readline().strip() if not prompt or prompt in ("/quit", "/exit", "/q"): break compare_all_models(prompt, cfg) elif choice == "3": run_benchmark_mode() else: print("Enter 1, 2, or 3") if __name__ == "__main__": main()