| #!/usr/bin/env python3 | |
| """Public-facing TMLM-Haiku interactive CLI. | |
| Pulls models from the CompactAI-O HuggingFace collection: | |
| https://huggingface.co/collections/CompactAI-O/tmlm-haiku-series | |
| """ | |
| from __future__ import annotations | |
| import hashlib | |
| import json | |
| import math | |
| import os | |
| import string | |
| import sys | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Dict, Iterator, List, Optional, Sequence, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.checkpoint import checkpoint | |
| HUGGINGFACE_MODELS = { | |
| "TMLM-Haiku-1": "CompactAI-O/TMLM-Haiku-1", | |
| "TMLM-Haiku-1.3": "CompactAI-O/TMLM-Haiku-1.3", | |
| "TMLM-Haiku-2": "CompactAI-O/TMLM-Haiku-2", | |
| "Glint-1": "CompactAI-O/Glint-1", | |
| } | |
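| # Illustrative note (not part of the original script): checkpoints from the collection | |
| # above can be fetched with huggingface_hub before loading locally, e.g. | |
| #   from huggingface_hub import hf_hub_download | |
| #   path = hf_hub_download(repo_id=HUGGINGFACE_MODELS["TMLM-Haiku-1"], filename="model.pt") | |
| # The filename inside each repo is an assumption; adjust it to the repo's actual contents. | |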
| # --------------------------------------------------------------------------- | |
| # Config | |
| # --------------------------------------------------------------------------- | |
| @dataclass | |
| class ModelConfig: | |
| dim: int = 128 | |
| n_unique_layers: int = 8 | |
| n_logical_layers: int = 16 | |
| n_heads: int = 4 | |
| n_kv_heads: int = 2 | |
| ffn_dim: int = 224 | |
| dropout: float = 0.0 | |
| seq_len: int = 2048 | |
| sliding_window_size: int = 512 | |
| mtp_horizons: Tuple[int, ...] = (2, 3, 4) | |
| rope_fraction: float = 0.5 | |
| embed_scale: bool = True | |
| logit_soft_cap: float = -1.0 | |
| quantization: str = "nvfp4" | |
| def head_dim(self) -> int: | |
| return self.dim // self.n_heads | |
| model_config = ModelConfig() | |
| MODEL_SERIES = { | |
| "haiku": { | |
| "dim": 64, | |
| "n_unique_layers": 12, | |
| "n_logical_layers": 24, | |
| "n_heads": 4, | |
| "n_kv_heads": 2, | |
| "ffn_dim": 384, | |
| "dropout": 0.0, | |
| "seq_len": 2048, | |
| "sliding_window_size": 2048, | |
| "mtp_horizons": (), | |
| "rope_fraction": 0.5, | |
| "engram_dim": 8, | |
| "engram_heads": 2, | |
| "engram_table_size": 64, | |
| "engram_max_ngram": 2, | |
| "mhc_expansion": 2, | |
| "sleep_gate_cap": 0, | |
| "sleep_gate_heads": 4, | |
| "latent_think_layers": 0, | |
| "prelude_layers": 0, | |
| "coda_layers": 0, | |
| "recurrent_loops": 0, | |
| "recurrent_act_threshold": 0.9, | |
| "recurrent_lora_rank": 0, | |
| "recurrent_loop_embed_dim": 0, | |
| }, | |
| "sonnet": { | |
| "dim": 1024, | |
| "n_unique_layers": 20, | |
| "n_logical_layers": 40, | |
| "n_heads": 16, | |
| "n_kv_heads": 4, | |
| "ffn_dim": 4096, | |
| "dropout": 0.0, | |
| "seq_len": 2048, | |
| "mtp_horizons": (2,), | |
| "engram_dim": 32, | |
| "engram_heads": 8, | |
| "engram_table_size": 4096, | |
| "engram_max_ngram": 2, | |
| "mhc_expansion": 2, | |
| "sleep_gate_cap": 0, | |
| "sleep_gate_heads": 8, | |
| "latent_think_layers": 0, | |
| "prelude_layers": 0, | |
| "coda_layers": 0, | |
| "recurrent_loops": 0, | |
| "recurrent_act_threshold": 0.99, | |
| "recurrent_lora_rank": 0, | |
| "recurrent_loop_embed_dim": 0, | |
| }, | |
| "opus": { | |
| "dim": 1536, | |
| "n_unique_layers": 18, | |
| "n_logical_layers": 36, | |
| "n_heads": 16, | |
| "n_kv_heads": 4, | |
| "ffn_dim": 5888, | |
| "dropout": 0.0, | |
| "seq_len": 2048, | |
| "mtp_horizons": (2,), | |
| "engram_dim": 64, | |
| "engram_heads": 8, | |
| "engram_table_size": 8192, | |
| "engram_max_ngram": 2, | |
| "mhc_expansion": 4, | |
| "sleep_gate_cap": 0, | |
| "sleep_gate_heads": 8, | |
| "latent_think_layers": 0, | |
| "prelude_layers": 0, | |
| "coda_layers": 0, | |
| "recurrent_loops": 0, | |
| "recurrent_act_threshold": 0.99, | |
| "recurrent_lora_rank": 0, | |
| "recurrent_loop_embed_dim": 0, | |
| }, | |
| } | |
| # --------------------------------------------------------------------------- | |
| # Tokenizer | |
| # --------------------------------------------------------------------------- | |
| FORMAT_TOKENS = [ | |
| "<|user|>", | |
| "<|assistant|>", | |
| "<|system|>", | |
| "<|start_header_id|>", | |
| "<|end_header_id|>", | |
| "<|begin_of_thought|>", | |
| "<|end_of_thought|>", | |
| "<|begin_of_solution|>", | |
| "<|end_of_solution|>", | |
| ] | |
| class WordTokenizer: | |
| def __init__( | |
| self, extra_chars: str = "", format_tokens: Optional[List[str]] = None | |
| ) -> None: | |
| base = string.ascii_letters + string.digits + string.punctuation + " \n\t\r" | |
| fallback_chars = sorted(set(base + extra_chars)) | |
| self.core_special = ["<PAD>", "<BOS>", "<EOS>", "<UNK>"] | |
| self.format_tokens = ( | |
| list(format_tokens) if format_tokens else list(FORMAT_TOKENS) | |
| ) | |
| self.special = list(self.core_special) + list(self.format_tokens) | |
| self.id_to_token: List[str] = ( | |
| list(self.core_special) + self.format_tokens + fallback_chars | |
| ) | |
| self.token_to_id: Dict[str, int] = { | |
| t: i for i, t in enumerate(self.id_to_token) | |
| } | |
| self.special_multi_tokens = sorted( | |
| [t for t in self.special if len(t) > 1], key=len, reverse=True | |
| ) | |
| self.multi_char_tokens = self.special_multi_tokens | |
| self.dynamic_additions = 0 | |
| @property | |
| def pad_id(self) -> int: | |
| return self.token_to_id["<PAD>"] | |
| @property | |
| def bos_id(self) -> int: | |
| return self.token_to_id["<BOS>"] | |
| @property | |
| def eos_id(self) -> int: | |
| return self.token_to_id["<EOS>"] | |
| @property | |
| def unk_id(self) -> int: | |
| return self.token_to_id["<UNK>"] | |
| @property | |
| def vocab_size(self) -> int: | |
| return len(self.id_to_token) | |
| def maybe_add_char(self, ch: str) -> bool: | |
| if ch in self.token_to_id: | |
| return False | |
| self.token_to_id[ch] = len(self.id_to_token) | |
| self.id_to_token.append(ch) | |
| self.dynamic_additions += 1 | |
| return True | |
| def iter_lexical_tokens(self, text: str) -> Iterator[str]: | |
| i = 0 | |
| n = len(text) | |
| while i < n: | |
| matched_special = False | |
| for token in self.special_multi_tokens: | |
| if text.startswith(token, i): | |
| yield token | |
| i += len(token) | |
| matched_special = True | |
| break | |
| if matched_special: | |
| continue | |
| yield text[i] | |
| i += 1 | |
| def encode( | |
| self, text: str, add_bos: bool = False, add_eos: bool = False | |
| ) -> List[int]: | |
| out: List[int] = [] | |
| if add_bos: | |
| out.append(self.bos_id) | |
| unk = self.unk_id | |
| t2i = self.token_to_id | |
| for tok in self.iter_lexical_tokens(text): | |
| out.append(t2i.get(tok, unk)) | |
| if add_eos: | |
| out.append(self.eos_id) | |
| return out | |
| def decode(self, ids: Sequence[int], skip_special: bool = True) -> str: | |
| pieces: List[str] = [] | |
| for idx in ids: | |
| if int(idx) < 0 or int(idx) >= len(self.id_to_token): | |
| continue | |
| tok = self.id_to_token[int(idx)] | |
| if skip_special and tok in self.special: | |
| continue | |
| pieces.append(tok) | |
| return "".join(pieces) | |
| @classmethod | |
| def load(cls, path: Path) -> WordTokenizer: | |
| with path.open("r", encoding="utf-8") as f: | |
| data = json.load(f) | |
| format_tokens = data.get("format_tokens", FORMAT_TOKENS) | |
| tokenizer = cls(extra_chars="", format_tokens=format_tokens) | |
| tokenizer.id_to_token = data["id_to_token"] | |
| tokenizer.token_to_id = {t: i for i, t in enumerate(tokenizer.id_to_token)} | |
| tokenizer.special = list(tokenizer.core_special) + list(tokenizer.format_tokens) | |
| tokenizer.special_multi_tokens = sorted( | |
| [t for t in tokenizer.special if len(t) > 1], key=len, reverse=True | |
| ) | |
| tokenizer.multi_char_tokens = tokenizer.special_multi_tokens | |
| return tokenizer | |
| LetterTokenizer = WordTokenizer | |
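| # Usage sketch (illustrative): the tokenizer is character-level, with multi-character | |
| # special tokens matched greedily, so encode/decode round-trips plain text: | |
| #   tok = WordTokenizer() | |
| #   ids = tok.encode("<|user|>\nhi\n<|assistant|>\n", add_bos=True) | |
| #   tok.decode(ids)                        # -> "\nhi\n\n" (special tokens dropped) | |
| #   tok.decode(ids, skip_special=False)    # keeps "<BOS>", "<|user|>", etc. | |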
| # --------------------------------------------------------------------------- | |
| # Model | |
| # --------------------------------------------------------------------------- | |
| class RMSNorm(nn.Module): | |
| def __init__(self, dim: int, eps: float = 1e-6) -> None: | |
| super().__init__() | |
| self.weight = nn.Parameter(torch.ones(dim)) | |
| self.eps = eps | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| if hasattr(torch.nn.functional, "rms_norm"): | |
| return torch.nn.functional.rms_norm( | |
| x, self.weight.shape, self.weight, self.eps | |
| ) | |
| x_fp = x.float() | |
| rms = torch.rsqrt(x_fp.pow(2).mean(dim=-1, keepdim=True) + self.eps) | |
| return (x_fp * rms).to(dtype=x.dtype) * self.weight | |
| class RotaryEmbedding(nn.Module): | |
| def __init__(self, dim: int, base: float = 10000.0) -> None: | |
| super().__init__() | |
| inv = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) | |
| self.register_buffer("inv_freq", inv, persistent=False) | |
| def cos_sin( | |
| self, seq_len: int, device: torch.device, dtype: torch.dtype | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) | |
| freqs = torch.outer(t, self.inv_freq) | |
| emb = torch.cat([freqs, freqs], dim=-1) | |
| cos = emb.cos()[None, None, :, :].to(dtype=dtype) | |
| sin = emb.sin()[None, None, :, :].to(dtype=dtype) | |
| return cos, sin | |
| def _rotate_half(x: torch.Tensor) -> torch.Tensor: | |
| x1 = x[..., : x.shape[-1] // 2] | |
| x2 = x[..., x.shape[-1] // 2 :] | |
| return torch.cat((-x2, x1), dim=-1) | |
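| # Note (illustrative): the attention below uses partial rotary embeddings -- only the | |
| # first `rope_dim` channels of each head are rotated with the standard RoPE identity | |
| # q' = q * cos + rotate_half(q) * sin, while the remaining channels pass through | |
| # position-free (see the q_rope / q_pass split in CausalSelfAttention.forward). | |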
| class CausalSelfAttention(nn.Module): | |
| def __init__( | |
| self, | |
| dim: int, | |
| n_heads: int, | |
| n_kv_heads: int, | |
| head_dim: int, | |
| dropout: float, | |
| sliding_window: int, | |
| rope_fraction: float, | |
| ) -> None: | |
| super().__init__() | |
| self.dim = dim | |
| self.n_heads = n_heads | |
| self.n_kv_heads = n_kv_heads | |
| self.head_dim = head_dim | |
| self.n_rep = n_heads // n_kv_heads | |
| self.dropout = dropout | |
| self.sliding_window = sliding_window | |
| self.wq = nn.Linear(dim, n_heads * head_dim, bias=False) | |
| self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False) | |
| self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False) | |
| self.wo = nn.Linear(n_heads * head_dim, dim, bias=False) | |
| self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2) | |
| self.rope = RotaryEmbedding(self.rope_dim) | |
| self.q_norm = RMSNorm(head_dim) | |
| self.k_norm = RMSNorm(head_dim) | |
| self.output_gate = nn.Parameter(torch.ones(n_heads)) | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| is_global: bool, | |
| past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, | |
| use_cache: bool = False, | |
| ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: | |
| B, T, _ = x.shape | |
| q = self.wq(x).view(B, T, self.n_heads, self.head_dim) | |
| k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim) | |
| v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim) | |
| q = self.q_norm(q) | |
| k = self.k_norm(k) | |
| q = q.transpose(1, 2) | |
| k = k.transpose(1, 2) | |
| v = v.transpose(1, 2) | |
| past_len = past_kv[0].shape[2] if past_kv is not None else 0 | |
| cos, sin = self.rope.cos_sin(T + past_len, x.device, q.dtype) | |
| cos_slice = cos[:, :, past_len : past_len + T, :] | |
| sin_slice = sin[:, :, past_len : past_len + T, :] | |
| q_rope = q[..., : self.rope_dim] | |
| q_pass = q[..., self.rope_dim :] | |
| k_rope = k[..., : self.rope_dim] | |
| k_pass = k[..., self.rope_dim :] | |
| q_rope = (q_rope * cos_slice) + (_rotate_half(q_rope) * sin_slice) | |
| k_rope = (k_rope * cos_slice) + (_rotate_half(k_rope) * sin_slice) | |
| q = torch.cat([q_rope, q_pass], dim=-1) | |
| k = torch.cat([k_rope, k_pass], dim=-1) | |
| if past_kv is not None: | |
| k = torch.cat([past_kv[0], k], dim=2) | |
| v = torch.cat([past_kv[1], v], dim=2) | |
| new_kv = (k, v) if use_cache else None | |
| S = k.shape[2] | |
| if self.n_rep > 1: | |
| k = ( | |
| k[:, :, None, :, :] | |
| .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim) | |
| .reshape(B, self.n_heads, S, self.head_dim) | |
| ) | |
| v = ( | |
| v[:, :, None, :, :] | |
| .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim) | |
| .reshape(B, self.n_heads, S, self.head_dim) | |
| ) | |
| drop_p = self.dropout if (self.training and torch.is_grad_enabled()) else 0.0 | |
| if is_global: | |
| if past_kv is None and T > 1: | |
| out = F.scaled_dot_product_attention( | |
| q, k, v, is_causal=True, dropout_p=drop_p | |
| ) | |
| else: | |
| out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p) | |
| else: | |
| T_q = q.shape[2] | |
| q_pos = torch.arange(past_len, past_len + T_q, device=q.device).unsqueeze(1) | |
| k_pos = torch.arange(S, device=q.device).unsqueeze(0) | |
| mask = (q_pos >= k_pos) & ((q_pos - k_pos) < self.sliding_window) | |
| out = F.scaled_dot_product_attention( | |
| q, k, v, attn_mask=mask.unsqueeze(0).unsqueeze(0), dropout_p=drop_p | |
| ) | |
| gate = torch.sigmoid(self.output_gate).view(1, self.n_heads, 1, 1) | |
| out = out * gate | |
| out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) | |
| out = self.wo(out) | |
| return out, new_kv | |
| class SwiGLU(nn.Module): | |
| def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None: | |
| super().__init__() | |
| self.gate = nn.Linear(dim, hidden_dim, bias=False) | |
| self.up = nn.Linear(dim, hidden_dim, bias=False) | |
| self.down = nn.Linear(hidden_dim, dim, bias=False) | |
| self.drop = nn.Dropout(dropout) | |
| nn.init.normal_(self.gate.weight, std=dim**-0.5) | |
| nn.init.normal_(self.up.weight, std=dim**-0.5) | |
| nn.init.normal_(self.down.weight, std=hidden_dim**-0.5) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| h = F.silu(self.gate(x)) * self.up(x) | |
| out = self.down(h) | |
| if self.training and torch.is_grad_enabled(): | |
| out = self.drop(out) | |
| return out | |
| def loop_index_embedding(h: torch.Tensor, loop_t: int, loop_dim: int, theta: float = 10000.0) -> torch.Tensor: | |
| if loop_dim <= 0: | |
| return h | |
| loop_dim = min(loop_dim, h.shape[-1]) | |
| if loop_dim % 2 == 1: | |
| loop_dim -= 1 | |
| if loop_dim <= 0: | |
| return h | |
| inv_freq = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim)) | |
| phase = torch.tensor(float(loop_t), device=h.device, dtype=h.dtype) * inv_freq | |
| loop_embed = torch.cat([phase.sin(), phase.cos()], dim=0).view(1, 1, loop_dim) | |
| out = h.clone() | |
| out[..., :loop_dim] = out[..., :loop_dim] + loop_embed | |
| return out | |
| class DepthLoRAAdapter(nn.Module): | |
| def __init__(self, dim: int, rank: int, max_loops: int) -> None: | |
| super().__init__() | |
| self.rank = max(0, rank) | |
| if self.rank <= 0: | |
| self.down = None | |
| self.B = None | |
| self.scale = None | |
| return | |
| self.down = nn.Linear(dim, self.rank, bias=False) | |
| self.B = nn.Parameter(torch.randn(self.rank, dim) * 0.02) | |
| self.scale = nn.Embedding(max(1, max_loops), self.rank) | |
| nn.init.zeros_(self.scale.weight) | |
| def forward(self, x: torch.Tensor, loop_t: int) -> torch.Tensor: | |
| if self.rank <= 0 or self.down is None or self.B is None or self.scale is None: | |
| return torch.zeros_like(x) | |
| t_idx = min(loop_t, self.scale.num_embeddings - 1) | |
| scale = self.scale(torch.tensor(t_idx, device=x.device)) | |
| return (self.down(x) * scale) @ self.B | |
| class StableRecurrentInjection(nn.Module): | |
| def __init__(self, dim: int) -> None: | |
| super().__init__() | |
| self.log_A = nn.Parameter(torch.full((dim,), -2.0)) | |
| self.log_dt = nn.Parameter(torch.full((dim,), -2.0)) | |
| self.input_gate = nn.Parameter(torch.zeros(dim)) | |
| def forward(self, h: torch.Tensor, e: torch.Tensor, transformer_out: torch.Tensor) -> torch.Tensor: | |
| A = torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20))).view(1, 1, -1) | |
| B = torch.sigmoid(self.input_gate).view(1, 1, -1) | |
| return A * h + B * e + transformer_out | |
| class AdaptiveHalting(nn.Module): | |
| def __init__(self, dim: int) -> None: | |
| super().__init__() | |
| self.halt = nn.Linear(dim, 1, bias=True) | |
| nn.init.zeros_(self.halt.weight) | |
| nn.init.constant_(self.halt.bias, -2.0) | |
| def forward(self, h: torch.Tensor) -> torch.Tensor: | |
| return torch.sigmoid(self.halt(h)).squeeze(-1) | |
| class EngramBlock(nn.Module): | |
| """DeepSeek Engram: conditional memory via O(1) hashed N-gram lookup. | |
| Stores common token-pair/triplet patterns in an embedding table and | |
| retrieves them with multi-head hashing. A context-aware gate (using the | |
| current hidden state as query) decides how much of the retrieved memory | |
| to inject into the residual stream. | |
| Reference: DeepSeek-AI, "Conditional Memory via Scalable Lookup" (2025). | |
| """ | |
| def __init__( | |
| self, | |
| dim: int, | |
| engram_dim: int, | |
| n_heads: int = 4, | |
| table_size: int = 8192, | |
| max_ngram: int = 3, | |
| ) -> None: | |
| super().__init__() | |
| self.dim = dim | |
| self.engram_dim = engram_dim | |
| self.n_heads = n_heads | |
| self.table_size = table_size | |
| self.max_ngram = max_ngram | |
| # One embedding table per (ngram_order, hash_head) | |
| self.embeddings = nn.ParameterDict() | |
| for n in range(2, max_ngram + 1): | |
| for k in range(n_heads): | |
| self.embeddings[f"{n}_{k}"] = nn.Parameter( | |
| torch.randn(table_size, engram_dim) * (engram_dim**-0.5) | |
| ) | |
| # Fixed hash parameters (non-learnable, deterministic) | |
| for n in range(2, max_ngram + 1): | |
| for k in range(n_heads): | |
| seed = int(hashlib.md5(f"engram_{n}_{k}".encode()).hexdigest()[:8], 16) | |
| rng = torch.Generator().manual_seed(seed) | |
| a = torch.randint(1, 2**31, (1,), generator=rng).item() | |
| b = torch.randint(0, 2**31, (1,), generator=rng).item() | |
| self.register_buffer( | |
| f"hash_a_{n}_{k}", torch.tensor(a), persistent=False | |
| ) | |
| self.register_buffer( | |
| f"hash_b_{n}_{k}", torch.tensor(b), persistent=False | |
| ) | |
| # Causal convolution over N-gram branch outputs (kernel=4, dilation=max_ngram) | |
| total_branch_dim = engram_dim * n_heads * (max_ngram - 1) | |
| self.branch_conv = nn.Conv1d( | |
| total_branch_dim, | |
| total_branch_dim, | |
| kernel_size=4, | |
| dilation=max_ngram, | |
| padding=0, | |
| groups=total_branch_dim, | |
| bias=True, | |
| ) | |
| nn.init.zeros_(self.branch_conv.weight) | |
| nn.init.zeros_(self.branch_conv.bias) | |
| # Context-aware gating: hidden state as query, memory as key/value | |
| self.gate_query = nn.Linear(dim, engram_dim, bias=False) | |
| self.gate_key = nn.Linear(total_branch_dim, engram_dim, bias=False) | |
| self.gate_value = nn.Linear(total_branch_dim, dim, bias=False) | |
| self.gate_scale = engram_dim**-0.5 | |
| def _hash_ngram(self, token_ids: torch.Tensor, n: int, k: int) -> torch.Tensor: | |
| """Hash n-gram token sequences into table indices. | |
| Args: | |
| token_ids: (B, T) token IDs | |
| n: n-gram order (2 = bigram, 3 = trigram) | |
| k: hash head index | |
| Returns: | |
| indices: (B, T) integer indices into embedding table | |
| """ | |
| a = getattr(self, f"hash_a_{n}_{k}") | |
| b = getattr(self, f"hash_b_{n}_{k}") | |
| B, T = token_ids.shape | |
| # Pad left with zeros so every position has a valid n-gram | |
| padded = F.pad(token_ids, (n - 1, 0), value=0) # (B, T+n-1) | |
| # Polynomial rolling hash | |
| combined = torch.zeros(B, T, dtype=torch.long, device=token_ids.device) | |
| for i in range(n): | |
| combined = combined * 31 + padded[:, i : i + T].long() | |
| indices = ((a * combined) ^ b) % self.table_size | |
| return indices | |
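| # Worked example (illustrative): for a bigram ending at position t with tokens | |
| # (x[t-1], x[t]), the rolling hash is h = x[t-1] * 31 + x[t] and the table index is | |
| # ((a * h) ^ b) % table_size, so each (n-gram order, hash head) pair addresses its own | |
| # row of its embedding table in O(1). | |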
| def forward( | |
| self, hidden: torch.Tensor, token_ids: Optional[torch.Tensor] = None | |
| ) -> torch.Tensor: | |
| """Forward pass. | |
| Args: | |
| hidden: (B, T, dim) current hidden state | |
| token_ids: (B, T) input token IDs for n-gram hashing. | |
| If None, uses argmax of hidden projections as proxy. | |
| Returns: | |
| output: (B, T, dim) memory injection for residual stream | |
| """ | |
| B, T, _ = hidden.shape | |
| if token_ids is None: | |
| # Fallback: derive pseudo-token-ids from hidden state | |
| token_ids = hidden.mean(dim=-1).long() % self.table_size | |
| # Retrieve and concatenate across n-gram orders and hash heads | |
| branch_outputs = [] | |
| for n in range(2, self.max_ngram + 1): | |
| for k in range(self.n_heads): | |
| indices = self._hash_ngram(token_ids, n, k) # (B, T) | |
| table = self.embeddings[f"{n}_{k}"] # (table_size, engram_dim) | |
| retrieved = table[indices] # (B, T, engram_dim) | |
| branch_outputs.append(retrieved) | |
| # (B, T, engram_dim * n_heads * (max_ngram - 1)) | |
| memory = torch.cat(branch_outputs, dim=-1) | |
| # Causal convolution over sequence dimension | |
| # Pad left for causality: (kernel_size - 1) * dilation positions | |
| conv_in = memory.transpose(1, 2) # (B, C, T) | |
| conv_in = F.pad( | |
| conv_in, | |
| ((self.branch_conv.kernel_size[0] - 1) * self.branch_conv.dilation[0], 0), | |
| ) | |
| conv_out = self.branch_conv(conv_in) # (B, C, T) | |
| memory = conv_out.transpose(1, 2) # (B, T, C) | |
| # Context-aware gating | |
| query = self.gate_query(hidden) # (B, T, engram_dim) | |
| key = self.gate_key(memory) # (B, T, engram_dim) | |
| gate = torch.sigmoid( | |
| (query * key).sum(dim=-1, keepdim=True) * self.gate_scale | |
| ) # (B, T, 1) | |
| value = self.gate_value(memory) # (B, T, dim) | |
| return gate * value | |
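| # Shape sketch (illustrative): with dim=64, engram_dim=8, n_heads=2, max_ngram=2, | |
| #   engram = EngramBlock(dim=64, engram_dim=8, n_heads=2, table_size=64, max_ngram=2) | |
| #   out = engram(hidden, token_ids)   # hidden: (B, T, 64), token_ids: (B, T) | |
| # `out` is a gated (B, T, 64) memory term that callers add to the residual stream. | |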
| class SleepGate(nn.Module): | |
| """Persistent memory + periodic consolidation gate.""" | |
| def __init__( | |
| self, | |
| dim: int, | |
| cap: int = 128, | |
| n_heads: int = 4, | |
| retention_enabled: bool = True, | |
| retention_hidden: int = 0, | |
| ) -> None: | |
| super().__init__() | |
| self.dim = dim | |
| self.cap = cap | |
| self.n_heads = n_heads | |
| self.head_dim = dim // n_heads | |
| self.scale = self.head_dim ** -0.5 | |
| self.retention_enabled = retention_enabled | |
| self.register_buffer("mem_emb", torch.zeros(cap, dim, dtype=torch.bfloat16)) | |
| self.register_buffer("mem_age", torch.zeros(cap, dtype=torch.long)) | |
| self.register_buffer("mem_beta", torch.ones(cap, dtype=torch.float32)) | |
| self.register_buffer("mem_count", torch.zeros((), dtype=torch.long)) | |
| self.register_buffer("mem_head", torch.zeros((), dtype=torch.long)) | |
| self.register_buffer("global_step", torch.zeros((), dtype=torch.long)) | |
| self.q_proj = nn.Linear(dim, dim, bias=False) | |
| self.k_proj = nn.Linear(dim, dim, bias=False) | |
| self.v_proj = nn.Linear(dim, dim, bias=False) | |
| self.o_proj = nn.Linear(dim, dim, bias=False) | |
| nn.init.zeros_(self.o_proj.weight) | |
| self.gate_scale = nn.Parameter(torch.zeros(())) | |
| if retention_enabled: | |
| if retention_hidden > 0: | |
| self.retention_gate: Optional[nn.Module] = nn.Sequential( | |
| nn.Linear(dim, retention_hidden, bias=False), | |
| nn.GELU(), | |
| nn.Linear(retention_hidden, 1, bias=True), | |
| ) | |
| nn.init.constant_(self.retention_gate[-1].bias, 2.2) | |
| else: | |
| self.retention_gate = nn.Linear(dim, 1, bias=True) | |
| nn.init.constant_(self.retention_gate.bias, 2.2) | |
| else: | |
| self.retention_gate = None | |
| self._last_beta: Optional[torch.Tensor] = None | |
| def write(self, hidden: torch.Tensor) -> None: | |
| B, T, _ = hidden.shape | |
| tail_full = hidden[:, max(0, T - 16):, :].float().mean(dim=1) | |
| if self.retention_gate is not None: | |
| beta_live = torch.sigmoid(self.retention_gate(tail_full).squeeze(-1)) | |
| self._last_beta = beta_live if self.training else None | |
| beta_store = beta_live.detach().float() | |
| else: | |
| self._last_beta = None | |
| beta_store = torch.ones(B, device=hidden.device, dtype=torch.float32) | |
| tail = tail_full.to(self.mem_emb.dtype).detach() | |
| with torch.no_grad(): | |
| head = int(self.mem_head.item()) | |
| count = int(self.mem_count.item()) | |
| step = int(self.global_step.item()) | |
| for b in range(B): | |
| self.mem_emb[head] = tail[b] | |
| self.mem_age[head] = step | |
| self.mem_beta[head] = beta_store[b] | |
| head = (head + 1) % self.cap | |
| if count < self.cap: | |
| count += 1 | |
| self.mem_head.fill_(head) | |
| self.mem_count.fill_(count) | |
| def read(self, x: torch.Tensor) -> torch.Tensor: | |
| count = int(self.mem_count.item()) | |
| if count == 0: | |
| return torch.zeros_like(x) | |
| B, T, D = x.shape | |
| mem = self.mem_emb[:count].clone().to(x.dtype) | |
| q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) | |
| k = self.k_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1) | |
| v = self.v_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1) | |
| attn = torch.einsum("bhtd,hmd->bhtm", q, k) * self.scale | |
| attn = F.softmax(attn, dim=-1) | |
| if self.retention_enabled: | |
| step = int(self.global_step.item()) | |
| ages = self.mem_age[:count].to(x.device) | |
| delta = (step - ages).clamp(min=0).to(x.dtype) | |
| betas = self.mem_beta[:count].to(x.dtype).clamp(min=1e-6, max=1.0) | |
| weights = betas.pow(delta) | |
| attn = attn * weights.view(1, 1, 1, count) | |
| attn = attn / attn.sum(dim=-1, keepdim=True).clamp_min(1e-9) | |
| out = torch.einsum("bhtm,hmd->bhtd", attn, v) | |
| out = out.transpose(1, 2).contiguous().view(B, T, D) | |
| out = self.o_proj(out) | |
| return torch.sigmoid(self.gate_scale) * out | |
| def reset(self) -> None: | |
| self.mem_emb.zero_() | |
| self.mem_age.zero_() | |
| self.mem_beta.fill_(1.0) | |
| self.mem_count.zero_() | |
| self.mem_head.zero_() | |
| self.global_step.zero_() | |
| self._last_beta = None | |
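| # Behaviour sketch (illustrative): write() stores a mean of the last 16 positions of each | |
| # sequence into a ring buffer of `cap` slots; read() cross-attends the current hidden | |
| # states over those stored summaries, down-weighting older entries by beta ** (step - age) | |
| # when retention is enabled. TinyMemoryLM.forward adds the result to the residual stream. | |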
| def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor: | |
| M = torch.exp(logits.clamp(-10, 10)) | |
| for _ in range(n_iters): | |
| M = M / M.sum(dim=-1, keepdim=True).clamp(min=1e-10) | |
| M = M / M.sum(dim=-2, keepdim=True).clamp(min=1e-10) | |
| return M | |
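| # Note (illustrative): alternating row/column normalisation (Sinkhorn-Knopp) pushes the | |
| # exponentiated logits toward a doubly-stochastic matrix, which is how H_res below mixes | |
| # the expanded residual streams without letting any single stream dominate. | |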
| class ManifoldHyperConnection(nn.Module): | |
| def __init__(self, dim: int, expansion: int = 2) -> None: | |
| super().__init__() | |
| self.dim = dim | |
| self.expansion = expansion | |
| n = expansion | |
| self.expand_fn = "duplicate" | |
| self.collapse_fn = "mean" | |
| self.bias_pre = nn.Parameter(torch.zeros(1, n)) | |
| self.bias_post = nn.Parameter(torch.zeros(1, n)) | |
| self.bias_res = nn.Parameter(torch.zeros(n, n)) | |
| self.theta_pre = nn.Linear(n * dim, n, bias=False) | |
| self.theta_post = nn.Linear(n * dim, n, bias=False) | |
| self.theta_res = nn.Linear(n * dim, n * n, bias=False) | |
| self.alpha_pre = nn.Parameter(torch.tensor(0.0)) | |
| self.alpha_post = nn.Parameter(torch.tensor(0.0)) | |
| self.alpha_res = nn.Parameter(torch.tensor(0.0)) | |
| def _compute_mappings( | |
| self, x_expanded: torch.Tensor | |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| B, T, _ = x_expanded.shape | |
| n = self.expansion | |
| x_norm = F.rms_norm(x_expanded, [x_expanded.shape[-1]]) | |
| d_pre = torch.tanh(self.theta_pre(x_norm)) | |
| d_post = torch.tanh(self.theta_post(x_norm)) | |
| d_res = self.theta_res(x_norm) | |
| H_pre_raw = torch.sigmoid(self.alpha_pre * d_pre + self.bias_pre) | |
| H_post_raw = 2.0 * torch.sigmoid(self.alpha_post * d_post + self.bias_post) | |
| H_res_raw = (self.alpha_res * d_res + self.bias_res.reshape(1, 1, -1)).reshape( | |
| B, T, n, n | |
| ) | |
| H_res = _sinkhorn_knopp(H_res_raw) | |
| return H_pre_raw.unsqueeze(-2), H_post_raw.unsqueeze(-2), H_res | |
| def expand_stream(self, x: torch.Tensor) -> torch.Tensor: | |
| return x.repeat(1, 1, self.expansion) | |
| def collapse_stream(self, x_expanded: torch.Tensor) -> torch.Tensor: | |
| B, T, _ = x_expanded.shape | |
| n = self.expansion | |
| C = self.dim | |
| return x_expanded.view(B, T, n, C).mean(dim=-2) | |
| def pre_mix(self, x_expanded: torch.Tensor, H_pre: torch.Tensor) -> torch.Tensor: | |
| B, T, _ = x_expanded.shape | |
| n = self.expansion | |
| x_streams = x_expanded.view(B, T, n, self.dim) | |
| return (H_pre @ x_streams).squeeze(-2) | |
| def post_res_mix( | |
| self, | |
| layer_output: torch.Tensor, | |
| x_expanded: torch.Tensor, | |
| H_post: torch.Tensor, | |
| H_res: torch.Tensor, | |
| ) -> torch.Tensor: | |
| B, T, _ = x_expanded.shape | |
| n = self.expansion | |
| C = self.dim | |
| x_streams = x_expanded.view(B, T, n, C) | |
| mixed = torch.matmul(H_res, x_streams) | |
| post_out = torch.matmul(H_post.transpose(-2, -1), layer_output.unsqueeze(-2)) | |
| result = mixed + post_out | |
| return result.reshape(B, T, n * C) | |
| class TransformerBlock(nn.Module): | |
| def __init__( | |
| self, | |
| dim: int, | |
| n_heads: int, | |
| n_kv_heads: int, | |
| head_dim: int, | |
| ffn_dim: int, | |
| dropout: float, | |
| sliding_window: int, | |
| rope_fraction: float, | |
| engram_dim: int = 0, | |
| engram_heads: int = 4, | |
| engram_table_size: int = 8192, | |
| engram_max_ngram: int = 3, | |
| mhc_expansion: int = 1, | |
| ) -> None: | |
| super().__init__() | |
| self.norm1 = RMSNorm(dim) | |
| self.attn = CausalSelfAttention( | |
| dim=dim, | |
| n_heads=n_heads, | |
| n_kv_heads=n_kv_heads, | |
| head_dim=head_dim, | |
| dropout=dropout, | |
| sliding_window=sliding_window, | |
| rope_fraction=rope_fraction, | |
| ) | |
| self.norm2 = RMSNorm(dim) | |
| self.ffn = SwiGLU(dim, ffn_dim, dropout) | |
| self.use_engram = engram_dim > 0 | |
| if self.use_engram: | |
| self.engram = EngramBlock( | |
| dim=dim, | |
| engram_dim=engram_dim, | |
| n_heads=engram_heads, | |
| table_size=engram_table_size, | |
| max_ngram=engram_max_ngram, | |
| ) | |
| self.engram_norm = RMSNorm(dim) | |
| self.use_mhc = mhc_expansion > 1 | |
| if self.use_mhc: | |
| self.mhc_attn = ManifoldHyperConnection(dim, expansion=mhc_expansion) | |
| self.mhc_ffn = ManifoldHyperConnection(dim, expansion=mhc_expansion) | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| is_global: bool, | |
| past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, | |
| use_cache: bool = False, | |
| token_ids: Optional[torch.Tensor] = None, | |
| ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: | |
| if self.use_mhc: | |
| x_exp = self.mhc_attn.expand_stream(x) | |
| H_pre, H_post, H_res = self.mhc_attn._compute_mappings(x_exp) | |
| attn_in = self.mhc_attn.pre_mix(x_exp, H_pre) | |
| attn_out, new_kv = self.attn( | |
| self.norm1(attn_in), is_global, past_kv, use_cache | |
| ) | |
| x_exp = self.mhc_attn.post_res_mix(attn_out, x_exp, H_post, H_res) | |
| if self.use_engram: | |
| collapsed = self.mhc_attn.collapse_stream(x_exp) | |
| collapsed = collapsed + self.engram( | |
| self.engram_norm(collapsed), token_ids=token_ids | |
| ) | |
| x_exp = self.mhc_attn.expand_stream(collapsed) | |
| H_pre2, H_post2, H_res2 = self.mhc_ffn._compute_mappings(x_exp) | |
| ffn_in = self.mhc_ffn.pre_mix(x_exp, H_pre2) | |
| ffn_out = self.ffn(self.norm2(ffn_in)) | |
| x_exp = self.mhc_ffn.post_res_mix(ffn_out, x_exp, H_post2, H_res2) | |
| x = self.mhc_attn.collapse_stream(x_exp) | |
| else: | |
| attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache) | |
| x = x + attn_out | |
| if self.use_engram: | |
| x = x + self.engram(self.engram_norm(x), token_ids=token_ids) | |
| x = x + self.ffn(self.norm2(x)) | |
| return x, new_kv | |
| class RecurrentDepthBlock(nn.Module): | |
| def __init__( | |
| self, | |
| dim: int, | |
| n_heads: int, | |
| n_kv_heads: int, | |
| head_dim: int, | |
| ffn_dim: int, | |
| dropout: float, | |
| sliding_window: int, | |
| rope_fraction: float, | |
| n_loops: int, | |
| act_threshold: float, | |
| lora_rank: int, | |
| loop_embed_dim: int, | |
| ) -> None: | |
| super().__init__() | |
| self.n_loops = max(1, n_loops) | |
| self.act_threshold = act_threshold | |
| self.loop_embed_dim = max(0, loop_embed_dim) | |
| self.norm = RMSNorm(dim) | |
| self.block = TransformerBlock( | |
| dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, | |
| ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window, | |
| rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1, | |
| ) | |
| self.injection = StableRecurrentInjection(dim) | |
| self.act = AdaptiveHalting(dim) | |
| self.lora = DepthLoRAAdapter(dim, lora_rank, self.n_loops) | |
| def forward( | |
| self, | |
| h: torch.Tensor, | |
| e: torch.Tensor, | |
| token_ids: Optional[torch.Tensor] = None, | |
| past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None, | |
| use_cache: bool = False, | |
| n_loops: Optional[int] = None, | |
| ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: | |
| loops = max(1, n_loops or self.n_loops) | |
| B, T, _ = h.shape | |
| halted = torch.zeros(B, T, device=h.device, dtype=torch.bool) | |
| cumulative_p = torch.zeros(B, T, device=h.device, dtype=h.dtype) | |
| output = torch.zeros_like(h) | |
| new_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None | |
| current = h | |
| final_halt = None | |
| for t in range(loops): | |
| h_loop = loop_index_embedding(current, t, self.loop_embed_dim) | |
| combined = self.norm(h_loop + e) | |
| past_kv = None | |
| if past_key_values is not None and t < len(past_key_values): | |
| past_kv = past_key_values[t] | |
| trans_out, layer_kv = self.block(combined, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=token_ids) | |
| trans_out = trans_out + self.lora(trans_out, t) | |
| next_h = self.injection(current, e, trans_out) | |
| p = self.act(next_h) | |
| p = p * (~halted).to(p.dtype) | |
| final_halt = p | |
| should_halt = (~halted) & ((cumulative_p + p) >= self.act_threshold) | |
| update_weight = torch.where(should_halt, (1.0 - cumulative_p).clamp(min=0.0), p) | |
| output = output + next_h * update_weight.unsqueeze(-1) | |
| cumulative_p = cumulative_p + update_weight | |
| current = torch.where(halted.unsqueeze(-1), current, next_h) | |
| halted = halted | should_halt | |
| if new_past is not None: | |
| new_past.append(layer_kv) | |
| if not use_cache and bool(halted.all()): | |
| break | |
| remainder = (1.0 - cumulative_p).clamp(min=0.0) | |
| output = output + current * remainder.unsqueeze(-1) | |
| aux: Dict[str, torch.Tensor] = {} | |
| if final_halt is not None: | |
| aux["recurrent_halt_mean"] = final_halt.mean() | |
| return output, aux, new_past | |
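| # Halting sketch (illustrative): each loop produces a per-token halting probability p; | |
| # tokens whose cumulative p crosses `act_threshold` stop being updated, and the block | |
| # returns a p-weighted mixture of intermediate states (ACT-style adaptive depth), so easy | |
| # tokens exit after a few loops while hard ones use up to `n_loops`. | |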
| class TinyMemoryLM(nn.Module): | |
| def __init__( | |
| self, | |
| vocab_size: int, | |
| dim: int, | |
| n_unique_layers: int, | |
| n_logical_layers: int, | |
| n_heads: int, | |
| n_kv_heads: int, | |
| ffn_dim: int, | |
| dropout: float, | |
| mtp_horizons: Sequence[int], | |
| grad_checkpoint: bool, | |
| sliding_window: int = 512, | |
| rope_fraction: float = 0.5, | |
| embed_scale: bool = True, | |
| engram_dim: int = 0, | |
| engram_heads: int = 4, | |
| engram_table_size: int = 8192, | |
| engram_max_ngram: int = 3, | |
| mhc_expansion: int = 1, | |
| sleep_gate_cap: int = 0, | |
| sleep_gate_heads: int = 4, | |
| sleep_retention_enabled: bool = True, | |
| sleep_retention_hidden: int = 0, | |
| latent_think_layers: int = 0, | |
| prelude_layers: int = 0, | |
| coda_layers: int = 0, | |
| recurrent_loops: int = 0, | |
| recurrent_act_threshold: float = 0.99, | |
| recurrent_lora_rank: int = 0, | |
| recurrent_loop_embed_dim: int = 0, | |
| ) -> None: | |
| super().__init__() | |
| self.dim = dim | |
| self.n_unique_layers = n_unique_layers | |
| self.n_logical_layers = n_logical_layers | |
| self.grad_checkpoint = grad_checkpoint | |
| self.embed_scale_factor = math.sqrt(dim) if embed_scale else 1.0 | |
| head_dim = dim // n_heads | |
| self.embed_tokens = nn.Embedding(vocab_size, dim) | |
| self.head = nn.Linear(dim, vocab_size, bias=False) | |
| self.head.weight = self.embed_tokens.weight | |
| self.output_bias = nn.Parameter(torch.zeros(vocab_size)) | |
| self.use_recurrent_depth = recurrent_loops > 0 | |
| self.prelude_layers = max(0, prelude_layers) | |
| self.coda_layers = max(0, coda_layers) | |
| self.recurrent_loops = max(0, recurrent_loops) | |
| self.blocks: Optional[nn.ModuleList] = None | |
| self.prelude: Optional[nn.ModuleList] = None | |
| self.recurrent: Optional[RecurrentDepthBlock] = None | |
| self.coda: Optional[nn.ModuleList] = None | |
| def _make_blocks(n: int) -> nn.ModuleList: | |
| return nn.ModuleList([ | |
| TransformerBlock( | |
| dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, | |
| ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window, | |
| rope_fraction=rope_fraction, engram_dim=engram_dim, | |
| engram_heads=engram_heads, engram_table_size=engram_table_size, | |
| engram_max_ngram=engram_max_ngram, mhc_expansion=mhc_expansion, | |
| ) | |
| for _ in range(n) | |
| ]) | |
| if self.use_recurrent_depth: | |
| if self.prelude_layers > 0: | |
| self.prelude = _make_blocks(self.prelude_layers) | |
| self.recurrent = RecurrentDepthBlock( | |
| dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, | |
| ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window, | |
| rope_fraction=rope_fraction, n_loops=self.recurrent_loops, | |
| act_threshold=recurrent_act_threshold, lora_rank=recurrent_lora_rank, | |
| loop_embed_dim=recurrent_loop_embed_dim or max(2, dim // 8), | |
| ) | |
| if self.coda_layers > 0: | |
| self.coda = _make_blocks(self.coda_layers) | |
| else: | |
| self.blocks = _make_blocks(max(1, n_unique_layers)) | |
| self.norm = RMSNorm(dim) | |
| self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1}) | |
| self.mtp_adapters = nn.ModuleDict( | |
| {str(h): nn.Linear(dim, dim, bias=False) for h in self.mtp_horizons} | |
| ) | |
| self.mtp_norms = nn.ModuleDict( | |
| {str(h): RMSNorm(dim) for h in self.mtp_horizons} | |
| ) | |
| res_scale = (2 * max(1, n_logical_layers)) ** -0.5 | |
| for group in (self.blocks, self.prelude, self.coda): | |
| if group is None: | |
| continue | |
| for block in group: | |
| block.attn.wo.weight.data.mul_(res_scale) | |
| block.ffn.down.weight.data.mul_(res_scale) | |
| if self.recurrent is not None: | |
| self.recurrent.block.attn.wo.weight.data.mul_(res_scale) | |
| self.recurrent.block.ffn.down.weight.data.mul_(res_scale) | |
| self.sleep_gate: Optional[SleepGate] = None | |
| if sleep_gate_cap > 0: | |
| self.sleep_gate = SleepGate( | |
| dim=dim, cap=sleep_gate_cap, n_heads=sleep_gate_heads, | |
| retention_enabled=sleep_retention_enabled, | |
| retention_hidden=sleep_retention_hidden, | |
| ) | |
| self.think_blocks: Optional[nn.ModuleList] = None | |
| self.think_norm: Optional[RMSNorm] = None | |
| if latent_think_layers > 0: | |
| self.think_blocks = nn.ModuleList([ | |
| TransformerBlock( | |
| dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, | |
| ffn_dim=ffn_dim, dropout=0.0, sliding_window=2048, | |
| rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1, | |
| ) | |
| for _ in range(latent_think_layers) | |
| ]) | |
| self.think_norm = RMSNorm(dim) | |
| def resize_token_embeddings(self, new_vocab_size: int) -> None: | |
| old_vocab_size = self.embed_tokens.num_embeddings | |
| if new_vocab_size == old_vocab_size: | |
| return | |
| device = self.embed_tokens.weight.device | |
| old_embed_weight = self.embed_tokens.weight.data.clone() | |
| self.embed_tokens = nn.Embedding(new_vocab_size, self.embed_tokens.embedding_dim).to(device) | |
| self.head = nn.Linear(self.embed_tokens.embedding_dim, new_vocab_size, bias=False).to(device) | |
| self.head.weight = self.embed_tokens.weight | |
| old_bias = self.output_bias.data.clone() | |
| self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device)) | |
| copy_size = min(old_vocab_size, new_vocab_size) | |
| self.output_bias.data[:copy_size] = old_bias[:copy_size] | |
| self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size] | |
| def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]: | |
| if self.blocks is None: | |
| return [] | |
| blocks_list = list(self.blocks) | |
| full_sequence = blocks_list + blocks_list | |
| return [(block, i) for i, block in enumerate(full_sequence[: self.n_logical_layers])] | |
| def forward( | |
| self, | |
| ids: torch.Tensor, | |
| use_cache: bool = False, | |
| past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None, | |
| return_hidden: bool = False, | |
| ) -> Tuple[torch.Tensor, Dict[int, torch.Tensor], Dict[str, torch.Tensor], Optional[torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: | |
| B, T = ids.shape | |
| x = self.embed_tokens(ids) * self.embed_scale_factor | |
| new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None | |
| aux: Dict[str, torch.Tensor] = {} | |
| if self.use_recurrent_depth: | |
| offset = 0 | |
| if self.prelude is not None: | |
| for block in self.prelude: | |
| past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None | |
| x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids) | |
| if new_past_key_values is not None: | |
| new_past_key_values.append(layer_kv) | |
| offset += 1 | |
| encoded = x | |
| recurrent_past = past_key_values[offset: offset + self.recurrent_loops] if past_key_values is not None else None | |
| x, recurrent_aux, recurrent_kv = self.recurrent( | |
| x, encoded, token_ids=ids, past_key_values=recurrent_past, use_cache=use_cache, | |
| ) | |
| aux.update(recurrent_aux) | |
| if new_past_key_values is not None and recurrent_kv is not None: | |
| new_past_key_values.extend(recurrent_kv) | |
| offset += self.recurrent_loops | |
| if self.coda is not None: | |
| for block in self.coda: | |
| past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None | |
| x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids) | |
| if new_past_key_values is not None: | |
| new_past_key_values.append(layer_kv) | |
| offset += 1 | |
| else: | |
| logical_layers = self._build_logical_layers() | |
| last_logical_idx = len(logical_layers) - 1 | |
| for layer_idx, (block, logical_idx) in enumerate(logical_layers): | |
| is_global = logical_idx % 2 == 0 or layer_idx == last_logical_idx | |
| past_kv = past_key_values[layer_idx] if past_key_values is not None and layer_idx < len(past_key_values) else None | |
| if self.grad_checkpoint and self.training and not use_cache: | |
| x, layer_kv = checkpoint(block, x, is_global, past_kv, use_cache, ids, use_reentrant=True) | |
| else: | |
| x, layer_kv = block(x, is_global, past_kv, use_cache, ids) | |
| if new_past_key_values is not None: | |
| new_past_key_values.append(layer_kv) | |
| x = self.norm(x) | |
| if self.sleep_gate is not None: | |
| x = x + self.sleep_gate.read(x) | |
| if self.training: | |
| self.sleep_gate.write(x) | |
| if self.think_blocks is not None: | |
| for think_block in self.think_blocks: | |
| x, _ = think_block(x, is_global=True) | |
| x = self.think_norm(x) | |
| h_out = x if return_hidden else None | |
| logits = self.head(x) | |
| if self.embed_scale_factor != 1.0: | |
| logits = logits / self.embed_scale_factor | |
| logits = logits + self.output_bias | |
| mtp: Dict[int, torch.Tensor] = {} | |
| if self.mtp_horizons and self.training: | |
| for horizon in self.mtp_horizons: | |
| if horizon > 1 and horizon <= T - 1: | |
| shifted_h = x[:, :-horizon, :] | |
| adapted_h = self.mtp_adapters[str(horizon)](shifted_h) | |
| adapted_h = self.mtp_norms[str(horizon)](adapted_h) | |
| mtp_logits = self.head(adapted_h) | |
| if self.embed_scale_factor != 1.0: | |
| mtp_logits = mtp_logits / self.embed_scale_factor | |
| mtp_logits = mtp_logits + self.output_bias | |
| mtp[horizon] = mtp_logits | |
| return logits, mtp, aux, h_out, new_past_key_values | |
| # --------------------------------------------------------------------------- | |
| # Generation | |
| # --------------------------------------------------------------------------- | |
| def build_stop_token_ids(tokenizer: WordTokenizer) -> set: | |
| stop_tokens = {tokenizer.eos_id} | |
| for tok in ("<|user|>", "<|system|>", "<|assistant|>"): | |
| tid = tokenizer.token_to_id.get(tok) | |
| if tid is not None: | |
| stop_tokens.add(int(tid)) | |
| return stop_tokens | |
| def apply_no_repeat_ngram( | |
| logits: torch.Tensor, | |
| token_history: Sequence[int], | |
| ngram_size: int, | |
| ) -> torch.Tensor: | |
| if ngram_size <= 1 or len(token_history) < max(0, ngram_size - 1): | |
| return logits | |
| prefix = tuple(token_history[-(ngram_size - 1) :]) if ngram_size > 1 else tuple() | |
| banned: set = set() | |
| for i in range(len(token_history) - ngram_size + 1): | |
| if tuple(token_history[i : i + ngram_size - 1]) == prefix: | |
| banned.add(int(token_history[i + ngram_size - 1])) | |
| if not banned: | |
| return logits | |
| out = logits.clone() | |
| banned_ids = torch.tensor(sorted(banned), device=logits.device, dtype=torch.long) | |
| out[banned_ids] = float("-inf") | |
| return out | |
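| # Worked example (illustrative): with ngram_size=3 and history [..., 7, 4, 9, 7, 4], the | |
| # current prefix is (7, 4); token 9 already completed (7, 4, 9) earlier, so its logit is | |
| # set to -inf and that trigram cannot repeat. | |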
| def apply_loop_penalty( | |
| logits: torch.Tensor, | |
| tokenizer: WordTokenizer, | |
| generated_text: str, | |
| penalty: float = 5.0, | |
| ) -> torch.Tensor: | |
| """Detect repeated substring loops and penalise continuation tokens.""" | |
| if len(generated_text) < 16: | |
| return logits | |
| out = logits.clone() | |
| for span_len in [24, 16, 12, 8]: | |
| if len(generated_text) < span_len * 2: | |
| continue | |
| suffix = generated_text[-span_len:] | |
| prev = generated_text[:-span_len].rfind(suffix) | |
| if prev == -1: | |
| continue | |
| next_pos = prev + span_len | |
| if next_pos < len(generated_text): | |
| next_char = generated_text[next_pos] | |
| tid = tokenizer.token_to_id.get(next_char) | |
| if tid is not None: | |
| out[tid] -= penalty | |
| break | |
| return out | |
| def apply_min_p(logits: torch.Tensor, min_p: float) -> torch.Tensor: | |
| """Filter tokens below min_p fraction of the top token probability.""" | |
| if min_p <= 0.0: | |
| return logits | |
| probs = torch.softmax(logits, dim=-1) | |
| threshold = probs.max() * min_p | |
| out = logits.clone() | |
| out[probs < threshold] = float("-inf") | |
| return out | |
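| # Example (illustrative): with min_p=0.05, any token whose probability is below 5% of the | |
| # single most likely token's probability is masked to -inf before sampling. | |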
| def generate( | |
| model: TinyMemoryLM, | |
| tokenizer: WordTokenizer, | |
| prompt: str, | |
| max_new_tokens: int = 256, | |
| temperature: float = 0.8, | |
| top_k: int = 16, | |
| top_p: float = 0.95, | |
| repetition_penalty: float = 1.0, | |
| device: str = "cuda", | |
| sft_mode: bool = True, | |
| stream: bool = True, | |
| no_repeat_ngram_size: int = 0, | |
| context_window: int = 2048, | |
| logit_soft_cap: float = 15.0, | |
| min_p: float = 0.05, | |
| loop_penalty: float = 5.0, | |
| ) -> str: | |
| if sft_mode: | |
| full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n" | |
| else: | |
| full_prompt = prompt | |
| input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False) | |
| input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device) | |
| visible_tokens: List[str] = [] | |
| stop_token_ids = build_stop_token_ids(tokenizer) | |
| generated_text = "" | |
| generated_ids: List[int] = [] | |
| # Full history (prompt + generated) for ngram blocking — prevents echoing prompt | |
| full_ids_history: List[int] = list(input_ids) | |
| with torch.no_grad(): | |
| for _ in range(max_new_tokens): | |
| ctx_ids = ( | |
| input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t | |
| ) | |
| logits, *_ = model(ctx_ids) | |
| next_logits = logits[0, -1, :].clone() | |
| # Logit soft-capping (Gemma-style) — prevents overconfident collapse | |
| if logit_soft_cap > 0: | |
| next_logits = logit_soft_cap * torch.tanh(next_logits / logit_soft_cap) | |
| raw_next_logits = next_logits.clone() | |
| # Repetition penalty on previously generated tokens | |
| if repetition_penalty != 1.0 and generated_ids: | |
| for tok_id in set(generated_ids): | |
| if next_logits[tok_id] > 0: | |
| next_logits[tok_id] /= repetition_penalty | |
| else: | |
| next_logits[tok_id] *= repetition_penalty | |
| # No-repeat n-gram blocking on generated tokens only | |
| if no_repeat_ngram_size > 0 and generated_ids: | |
| next_logits = apply_no_repeat_ngram(next_logits, generated_ids, no_repeat_ngram_size) | |
| # Substring loop detection | |
| next_logits = apply_loop_penalty(next_logits, tokenizer, generated_text, penalty=loop_penalty) | |
| # Temperature scaling | |
| if temperature != 1.0: | |
| next_logits = next_logits / max(temperature, 1e-6) | |
| # Min-p filtering — remove tokens below min_p * max_prob | |
| if min_p > 0: | |
| next_logits = apply_min_p(next_logits, min_p) | |
| # Top-k filtering | |
| if top_k > 0: | |
| v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0))) | |
| next_logits[next_logits < v[-1]] = float("-inf") | |
| # Top-p (nucleus) filtering | |
| if 0.0 < top_p < 1.0: | |
| sorted_logits, sorted_indices = torch.sort(next_logits, descending=True) | |
| sorted_probs = torch.softmax(sorted_logits, dim=-1) | |
| cumulative_probs = torch.cumsum(sorted_probs, dim=-1) | |
| remove_mask = cumulative_probs > top_p | |
| remove_mask[0] = False | |
| indices_to_remove = sorted_indices[remove_mask] | |
| next_logits[indices_to_remove] = float("-inf") | |
| # Fallback if all tokens masked | |
| if not torch.isfinite(next_logits).any(): | |
| next_logits = raw_next_logits | |
| if temperature != 1.0: | |
| next_logits = next_logits / max(temperature, 1e-6) | |
| if temperature == 0: | |
| next_id = torch.argmax(next_logits).item() | |
| else: | |
| probs = torch.softmax(next_logits, dim=-1) | |
| next_id = torch.multinomial(probs, num_samples=1).item() | |
| if next_id in stop_token_ids: | |
| break | |
| token_str = ( | |
| tokenizer.id_to_token[next_id] | |
| if next_id < len(tokenizer.id_to_token) | |
| else "" | |
| ) | |
| generated_ids.append(next_id) | |
| full_ids_history.append(next_id) | |
| if token_str not in tokenizer.special: | |
| visible_tokens.append(token_str) | |
| generated_text += token_str | |
| if stream: | |
| print(token_str, end="", flush=True) | |
| input_ids_t = torch.cat( | |
| [input_ids_t, torch.tensor([[next_id]], device=device)], dim=1 | |
| ) | |
| if stream: | |
| print() | |
| return "".join(visible_tokens) | |
| # --------------------------------------------------------------------------- | |
| # Local model loading | |
| # --------------------------------------------------------------------------- | |
| def series_from_name(name: str) -> str | None: | |
| lower = (name or "").lower() | |
| if "haiku" in lower: | |
| return "Haiku" | |
| if "sonnet" in lower: | |
| return "Sonnet" | |
| if "opus" in lower: | |
| return "Opus" | |
| return None | |
| def series_config(series: str) -> dict[str, object]: | |
| return MODEL_SERIES.get(series.lower(), MODEL_SERIES["sonnet"]) | |
| def discover_models(runs_dir: Path) -> List[dict]: | |
| models = [] | |
| if not runs_dir.is_dir(): | |
| return models | |
| for child in sorted(runs_dir.iterdir()): | |
| if not child.is_dir(): | |
| continue | |
| tokenizer_path = child / "tokenizer.json" | |
| if not tokenizer_path.exists(): | |
| continue | |
| name = child.name | |
| series = None | |
| for ckpt_name in ("model.pt", "pretrain.pt"): | |
| ckpt_path = child / ckpt_name | |
| if ckpt_path.exists(): | |
| series = _fast_series_from_checkpoint(ckpt_path) | |
| break | |
| if series is None: | |
| series = series_from_name(name) or "Sonnet" | |
| found = False | |
| for ckpt_name in ("model.pt", "model_rep.pt", "pretrain.pt"): | |
| ckpt_path = child / ckpt_name | |
| if ckpt_path.exists(): | |
| models.append( | |
| { | |
| "name": name, | |
| "checkpoint": ckpt_name, | |
| "series": series, | |
| "model_path": ckpt_path, | |
| "tokenizer_path": tokenizer_path, | |
| } | |
| ) | |
| found = True | |
| if not found: | |
| step_ckpts = sorted( | |
| child.glob("checkpoint_step_*.pt"), | |
| key=lambda p: int(p.stem.rsplit("_", 1)[-1]), | |
| ) | |
| if step_ckpts: | |
| ckpt_path = step_ckpts[-1] | |
| models.append( | |
| { | |
| "name": name, | |
| "checkpoint": ckpt_path.name, | |
| "series": series, | |
| "model_path": ckpt_path, | |
| "tokenizer_path": tokenizer_path, | |
| } | |
| ) | |
| return models | |
| def _detect_engram(state_dict): | |
| for key in state_dict: | |
| if ".engram." in key: | |
| if ".embeddings." in key: | |
| return state_dict[key].shape[-1] | |
| return 0 | |
| def _detect_mhc(state_dict): | |
| for key, val in state_dict.items(): | |
| if ".mhc_attn.bias_pre" in key and val.dim() == 2: | |
| return val.shape[-1] # (1, expansion) | |
| return 1 | |
| def _detect_sleep_gate(state_dict) -> Tuple[int, int]: | |
| for key, val in state_dict.items(): | |
| if key == "sleep_gate.mem_emb" and val.dim() == 2: | |
| cap = val.shape[0] | |
| return cap, 4 | |
| return 0, 4 | |
| def _detect_latent_think(state_dict) -> int: | |
| indices = { | |
| int(k.split(".")[1]) | |
| for k in state_dict | |
| if k.startswith("think_blocks.") and k.split(".")[1].isdigit() | |
| } | |
| return max(indices) + 1 if indices else 0 | |
| def _detect_prelude_layers(state_dict) -> int: | |
| indices = { | |
| int(k.split(".")[1]) | |
| for k in state_dict | |
| if k.startswith("prelude.") and k.split(".")[1].isdigit() | |
| } | |
| return max(indices) + 1 if indices else 0 | |
| def _detect_coda_layers(state_dict) -> int: | |
| indices = { | |
| int(k.split(".")[1]) | |
| for k in state_dict | |
| if k.startswith("coda.") and k.split(".")[1].isdigit() | |
| } | |
| return max(indices) + 1 if indices else 0 | |
| def _detect_recurrent_loops(state_dict) -> int: | |
| if "recurrent.norm.weight" in state_dict or "recurrent.block.attn.wq.weight" in state_dict: | |
| if "recurrent.lora.scale.weight" in state_dict: | |
| return state_dict["recurrent.lora.scale.weight"].shape[0] | |
| return 1 | |
| return 0 | |
| def _detect_recurrent_lora_rank(state_dict) -> int: | |
| for key in ("recurrent.lora.B", "recurrent.lora.down.weight"): | |
| if key in state_dict: | |
| shape = state_dict[key].shape | |
| if len(shape) == 2: | |
| return int(shape[0]) | |
| return 0 | |
| def _infer_series_from_lora_rank(rank: int) -> str | None: | |
| if rank == 0: | |
| return None | |
| if rank <= 8: | |
| return "haiku" | |
| if rank <= 16: | |
| return "sonnet" | |
| return "opus" | |
| def _fast_series_from_checkpoint(ckpt_path: Path) -> str | None: | |
| try: | |
| cp = torch.load(ckpt_path, map_location="cpu", weights_only=False) | |
| sd = cp.get("model_state", cp.get("state_dict", {})) | |
| rank = 0 | |
| for key in ("recurrent.lora.B", "recurrent.lora.down.weight"): | |
| if key in sd: | |
| rank = int(sd[key].shape[0]) | |
| break | |
| if rank == 0: | |
| return None | |
| if rank <= 8: | |
| return "Haiku" | |
| if rank <= 16: | |
| return "Sonnet" | |
| return "Opus" | |
| except Exception: | |
| pass | |
| return None | |
| def _infer_arch_from_state_dict(state_dict, cfg): | |
| """Infer architecture hyper-parameters directly from checkpoint weights, | |
| falling back to *cfg* (series config) when a key is not found.""" | |
| overrides = {} | |
| has_prelude = any(k.startswith("prelude.") for k in state_dict) | |
| has_blocks = any(k.startswith("blocks.") for k in state_dict) | |
| has_recurrent = any(k.startswith("recurrent.") for k in state_dict) | |
| uses_recurrent_arch = has_prelude and has_recurrent and not has_blocks | |
| # dim from embed_tokens.weight [vocab, dim] | |
| if "embed_tokens.weight" in state_dict: | |
| overrides["dim"] = state_dict["embed_tokens.weight"].shape[1] | |
| if uses_recurrent_arch: | |
| if "prelude.0.ffn.gate.weight" in state_dict: | |
| overrides["ffn_dim"] = state_dict["prelude.0.ffn.gate.weight"].shape[0] | |
| overrides["n_unique_layers"] = 0 | |
| src = "prelude.0" | |
| else: | |
| if "blocks.0.ffn.gate.weight" in state_dict: | |
| overrides["ffn_dim"] = state_dict["blocks.0.ffn.gate.weight"].shape[0] | |
| block_ids = { | |
| int(k.split(".")[1]) | |
| for k in state_dict | |
| if k.startswith("blocks.") and k.split(".")[1].isdigit() | |
| } | |
| if block_ids: | |
| overrides["n_unique_layers"] = max(block_ids) + 1 | |
| src = "blocks.0" | |
| dim = overrides.get("dim", int(cfg.get("dim", model_config.dim))) | |
| if f"{src}.attn.wq.weight" in state_dict: | |
| wq_rows = state_dict[f"{src}.attn.wq.weight"].shape[0] | |
| if f"{src}.attn.q_norm.weight" in state_dict: | |
| head_dim = state_dict[f"{src}.attn.q_norm.weight"].shape[0] | |
| overrides["n_heads"] = wq_rows // head_dim | |
| if f"{src}.attn.wk.weight" in state_dict: | |
| wk_rows = state_dict[f"{src}.attn.wk.weight"].shape[0] | |
| if f"{src}.attn.k_norm.weight" in state_dict: | |
| head_dim = state_dict[f"{src}.attn.k_norm.weight"].shape[0] | |
| overrides["n_kv_heads"] = wk_rows // head_dim | |
| # engram params | |
| for key, val in state_dict.items(): | |
| if ".engram.embeddings." in key and key.endswith("_0") and val.dim() == 2: | |
| overrides["engram_table_size"] = val.shape[0] | |
| overrides["engram_dim"] = val.shape[1] | |
| break | |
| engram_dim = overrides.get("engram_dim", int(cfg.get("engram_dim", 0))) | |
| engram_max_ngram = int(cfg.get("engram_max_ngram", 2)) | |
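| # branch_conv produces engram_heads * engram_dim * (max_ngram - 1) channels, so divide to recover the head count. | |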
| if engram_dim > 0: | |
| for key, val in state_dict.items(): | |
| if ".engram.branch_conv.weight" in key and val.dim() == 3: | |
| total_branch_dim = val.shape[0] | |
| denom = engram_dim * (engram_max_ngram - 1) | |
| if denom > 0 and total_branch_dim % denom == 0: | |
| overrides["engram_heads"] = total_branch_dim // denom | |
| break | |
| merged = dict(cfg) | |
| merged.update(overrides) | |
| return merged | |
| def load_local_model(model_path: Path, tokenizer_path: Path, series: str) -> dict: | |
| tokenizer = WordTokenizer.load(tokenizer_path) | |
| ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False) | |
| cfg = series_config(series) | |
| vocab_size = int(ckpt.get("vocab_size", tokenizer.vocab_size)) | |
| state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt | |
| cfg = _infer_arch_from_state_dict(state_dict, cfg) | |
| engram_dim = int(cfg.get("engram_dim", 0)) | |
| if _detect_engram(state_dict) == 0: | |
| engram_dim = 0 | |
| mhc_expansion = _detect_mhc(state_dict) | |
| if mhc_expansion == 1: | |
| mhc_expansion = int(cfg.get("mhc_expansion", 1)) | |
| ckpt_sleep_cap, ckpt_sleep_heads = _detect_sleep_gate(state_dict) | |
| sleep_gate_cap = ckpt_sleep_cap if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_cap", 0)) | |
| sleep_gate_heads = ckpt_sleep_heads if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_heads", 4)) | |
| sleep_retention_enabled = bool(cfg.get("sleep_retention_enabled", True)) | |
| sleep_retention_hidden = int(cfg.get("sleep_retention_hidden", 0)) | |
| latent_think_layers = _detect_latent_think(state_dict) | |
| if latent_think_layers == 0: | |
| latent_think_layers = int(cfg.get("latent_think_layers", 0)) | |
| prelude_layers = _detect_prelude_layers(state_dict) | |
| coda_layers = _detect_coda_layers(state_dict) | |
| recurrent_loops = _detect_recurrent_loops(state_dict) | |
| ckpt_lora_rank = _detect_recurrent_lora_rank(state_dict) | |
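| # A LoRA rank embedded in the checkpoint identifies the series; when it disagrees with the caller's choice, re-resolve the series config. | |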
| if ckpt_lora_rank > 0: | |
| inferred_series = _infer_series_from_lora_rank(ckpt_lora_rank) | |
| if inferred_series and inferred_series != series.lower(): | |
| series = inferred_series.capitalize() | |
| cfg = series_config(series) | |
| recurrent_lora_rank = ckpt_lora_rank | |
| else: | |
| recurrent_lora_rank = int(cfg.get("recurrent_lora_rank", 0)) | |
| recurrent_act_threshold = float(cfg.get("recurrent_act_threshold", 0.99)) | |
| recurrent_loop_embed_dim = int(cfg.get("recurrent_loop_embed_dim", 0)) | |
| n_unique = int(cfg.get("n_unique_layers", model_config.n_unique_layers)) | |
| model = TinyMemoryLM( | |
| vocab_size=vocab_size, | |
| dim=int(cfg.get("dim", model_config.dim)), | |
| n_unique_layers=n_unique, | |
| n_logical_layers=int(cfg.get("n_logical_layers", model_config.n_logical_layers)), | |
| n_heads=int(cfg.get("n_heads", model_config.n_heads)), | |
| n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)), | |
| ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)), | |
| dropout=float(cfg.get("dropout", model_config.dropout)), | |
| mtp_horizons=tuple(int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)), | |
| grad_checkpoint=False, | |
| sliding_window=int(cfg.get("sliding_window_size", getattr(model_config, "sliding_window_size", 512))), | |
| rope_fraction=float(cfg.get("rope_fraction", getattr(model_config, "rope_fraction", 0.25))), | |
| embed_scale=bool(cfg.get("embed_scale", getattr(model_config, "embed_scale", True))), | |
| engram_dim=engram_dim, | |
| engram_heads=int(cfg.get("engram_heads", 4)), | |
| engram_table_size=int(cfg.get("engram_table_size", 8192)), | |
| engram_max_ngram=int(cfg.get("engram_max_ngram", 3)), | |
| mhc_expansion=mhc_expansion, | |
| sleep_gate_cap=sleep_gate_cap, | |
| sleep_gate_heads=sleep_gate_heads, | |
| sleep_retention_enabled=sleep_retention_enabled, | |
| sleep_retention_hidden=sleep_retention_hidden, | |
| latent_think_layers=latent_think_layers, | |
| prelude_layers=prelude_layers, | |
| coda_layers=coda_layers, | |
| recurrent_loops=recurrent_loops, | |
| recurrent_act_threshold=recurrent_act_threshold, | |
| recurrent_lora_rank=recurrent_lora_rank, | |
| recurrent_loop_embed_dim=recurrent_loop_embed_dim, | |
| ) | |
| model.load_state_dict(state_dict, strict=False) | |
| model.eval() | |
| if tokenizer.vocab_size > vocab_size: | |
| model.resize_token_embeddings(tokenizer.vocab_size) | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| model = model.to(device) | |
| return { | |
| "model": model, | |
| "tokenizer": tokenizer, | |
| "device": device, | |
| "series": series, | |
| "sft_mode": ckpt.get("sft_mode", None), | |
| "phase": ckpt.get("phase", None), | |
| } | |
| # --------------------------------------------------------------------------- | |
| # HuggingFace Model Download & Loading | |
| # --------------------------------------------------------------------------- | |
| def download_huggingface_model(hf_id: str, cache_dir: Path) -> dict | None: | |
| try: | |
| from huggingface_hub import snapshot_download | |
| except ImportError: | |
| print("huggingface_hub not installed. Install with: pip install huggingface_hub") | |
| sys.exit(1) | |
| print(f"Downloading {hf_id}...") | |
| try: | |
| local_dir = Path(snapshot_download(repo_id=hf_id, cache_dir=str(cache_dir))) | |
| except Exception as e: | |
| print(f"Failed to download {hf_id}: {e}") | |
| return None | |
| print(f"Using cached {hf_id} from {local_dir}") | |
| # Check common subdirectory names: "models/", "model/" | |
| if (local_dir / "models").exists(): | |
| model_dir = local_dir / "models" | |
| elif (local_dir / "model").exists(): | |
| model_dir = local_dir / "model" | |
| else: | |
| model_dir = local_dir | |
| model_path = model_dir / "model.pt" | |
| pretrain_path = model_dir / "pretrain.pt" | |
| tokenizer_path = model_dir / "tokenizer.json" | |
| ckpt_path = None | |
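| # Prefer the SFT checkpoint (model.pt) over the base pretrain checkpoint when both exist. | |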
| for p in [model_path, pretrain_path]: | |
| if p.exists(): | |
| ckpt_path = p | |
| break | |
| if ckpt_path is None or not tokenizer_path.exists(): | |
| print(f"Missing model files in {model_dir}") | |
| print(f" model.pt exists: {model_path.exists()}") | |
| print(f" pretrain.pt exists: {pretrain_path.exists()}") | |
| print(f" tokenizer.json exists: {tokenizer_path.exists()}") | |
| return None | |
| return { | |
| "model_path": ckpt_path, | |
| "tokenizer_path": tokenizer_path, | |
| "model_name": ckpt_path.stem, | |
| } | |
| def load_huggingface_model(hf_id: str, cache_dir: Path) -> dict | None: | |
| files = download_huggingface_model(hf_id, cache_dir) | |
| if files is None: | |
| return None | |
| return load_local_model(files["model_path"], files["tokenizer_path"], "Haiku") | |
| # --------------------------------------------------------------------------- | |
| # Compare All Models | |
| # --------------------------------------------------------------------------- | |
| _hf_model_cache: Dict[str, dict] = {} | |
| def prefetch_huggingface_models() -> None: | |
| root = Path(__file__).resolve().parent | |
| cache_dir = root / "cache" / "huggingface" | |
| cache_dir.mkdir(parents=True, exist_ok=True) | |
| print("Downloading/preparing HuggingFace models...") | |
| for name, hf_id in HUGGINGFACE_MODELS.items(): | |
| print(f" {name}...") | |
| bundle = load_huggingface_model(hf_id, cache_dir) | |
| if bundle: | |
| _hf_model_cache[name] = bundle | |
| print(f"Prepared {len(_hf_model_cache)} HuggingFace models") | |
| def compare_all_models(prompt: str, cfg: dict) -> None: | |
| root = Path(__file__).resolve().parent | |
| runs_dir = root / "runs" | |
| all_models = discover_models(runs_dir) | |
| is_pretrain = not cfg.get("sft_mode", True) | |
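| # Only compare local checkpoints whose phase (pretrain vs. SFT) matches the selected generation mode. | |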
| local_models = [ | |
| m for m in all_models | |
| if ("pretrain" in m["checkpoint"]) == is_pretrain | |
| ] | |
| if not local_models: | |
| print("No models found matching mode.") | |
| return | |
| results: List[dict] = [] | |
| for m in local_models: | |
| print(f"\n{'='*60}") | |
| print(f"Loading local {m['name']}/{m['checkpoint']}...") | |
| try: | |
| bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) | |
| except Exception as e: | |
| print(f"Failed to load {m['name']}: {e}") | |
| continue | |
| model = bundle["model"] | |
| tokenizer = bundle["tokenizer"] | |
| device = bundle["device"] | |
| print(f"Generating on '{prompt}'...") | |
| output = generate( | |
| model=model, | |
| tokenizer=tokenizer, | |
| prompt=prompt, | |
| max_new_tokens=cfg["max_new_tokens"], | |
| temperature=cfg["temperature"], | |
| top_k=cfg["top_k"], | |
| top_p=cfg["top_p"], | |
| min_p=cfg["min_p"], | |
| no_repeat_ngram_size=cfg["no_repeat_ngram_size"], | |
| repetition_penalty=cfg["repetition_penalty"], | |
| logit_soft_cap=cfg["logit_soft_cap"], | |
| loop_penalty=cfg["loop_penalty"], | |
| device=str(device), | |
| sft_mode=cfg["sft_mode"], | |
| stream=True, | |
| context_window=cfg["context_window"], | |
| ) | |
| results.append({ | |
| "name": f"[LOCAL] {m['name']}/{m['checkpoint']}", | |
| "output": output, | |
| "device": device, | |
| }) | |
| for name, bundle in _hf_model_cache.items(): | |
| print(f"\n{'='*60}") | |
| print(f"Loading {name} (cached)...") | |
| model = bundle["model"] | |
| tokenizer = bundle["tokenizer"] | |
| device = bundle["device"] | |
| print(f"Generating on '{prompt}'...") | |
| output = generate( | |
| model=model, | |
| tokenizer=tokenizer, | |
| prompt=prompt, | |
| max_new_tokens=cfg["max_new_tokens"], | |
| temperature=cfg["temperature"], | |
| top_k=cfg["top_k"], | |
| top_p=cfg["top_p"], | |
| min_p=cfg["min_p"], | |
| no_repeat_ngram_size=cfg["no_repeat_ngram_size"], | |
| repetition_penalty=cfg["repetition_penalty"], | |
| logit_soft_cap=cfg["logit_soft_cap"], | |
| loop_penalty=cfg["loop_penalty"], | |
| device=str(device), | |
| sft_mode=cfg["sft_mode"], | |
| stream=True, | |
| context_window=cfg["context_window"], | |
| ) | |
| results.append({ | |
| "name": name, | |
| "output": output, | |
| "device": device, | |
| }) | |
| print(f"\n{'='*60}") | |
| print("=" * 60) | |
| print("SIDE-BY-SIDE COMPARISON") | |
| print("=" * 60) | |
| for r in results: | |
| print(f"\n--- {r['name']} ---") | |
| print(r["output"]) | |
| print(f"\n{'='*60}") | |
| # --------------------------------------------------------------------------- | |
| # Benchmark | |
| # --------------------------------------------------------------------------- | |
| BENCHMARKS = { | |
| "blimp": { | |
| "label": "BLiMP", | |
| "desc": "Grammaticality minimal pairs (67 paradigms). Accuracy = % grammatical < ungrammatical perplexity.", | |
| "hf_dataset": ("nyu-mll/blimp", None), | |
| "metric": "accuracy", | |
| }, | |
| "wikitext2": { | |
| "label": "WikiText-2", | |
| "desc": "LM perplexity on Wikipedia test split. Lower is better.", | |
| "hf_dataset": ("Salesforce/wikitext", "wikitext-2-raw-v1"), | |
| "metric": "perplexity", | |
| }, | |
| "arc_easy": { | |
| "label": "ARC-Easy", | |
| "desc": "Multiple-choice science QA (~2.4K). Perplexity-based answer selection.", | |
| "hf_dataset": ("allenai/ai2_arc", "ARC-Easy"), | |
| "metric": "accuracy", | |
| }, | |
| } | |
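| # Mean per-token negative log-likelihood of *text* under teacher forcing (lower is better). | |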
| def _score_text(model: TinyMemoryLM, tokenizer: WordTokenizer, text: str, device: str) -> float: | |
| ids = tokenizer.encode(text, add_bos=True, add_eos=False) | |
| if len(ids) < 2: | |
| return float("nan") | |
| ids_t = torch.tensor([ids], dtype=torch.long, device=device) | |
| with torch.no_grad(): | |
| logits, *_ = model(ids_t) | |
| log_probs = F.log_softmax(logits[0], dim=-1) | |
| targets = ids_t[0, 1:] | |
| nll = -log_probs[range(len(targets)), targets].mean().item() | |
| return nll | |
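| # NLL averaged over the completion tokens only; the context is conditioned on but not scored. | |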
| def _score_completion(model: TinyMemoryLM, tokenizer: WordTokenizer, context: str, completion: str, device: str) -> float: | |
| full_ids = tokenizer.encode(context + completion, add_bos=True, add_eos=False) | |
| ctx_ids = tokenizer.encode(context, add_bos=True, add_eos=False) | |
| n_ctx = len(ctx_ids) | |
| n_ref = len(full_ids) - n_ctx | |
| if n_ref <= 0: | |
| return float("nan") | |
| ids_t = torch.tensor([full_ids], dtype=torch.long, device=device) | |
| with torch.no_grad(): | |
| logits, *_ = model(ids_t) | |
| log_probs = F.log_softmax(logits[0], dim=-1) | |
| targets = ids_t[0, 1:] | |
| ref_start = n_ctx - 1 | |
| ref_end = min(ref_start + n_ref, log_probs.shape[0]) | |
| if ref_start >= ref_end: | |
| return float("nan") | |
| nll = -log_probs[ref_start:ref_end][range(ref_end - ref_start), targets[ref_start:ref_end]].mean().item() | |
| return nll | |
| BLIMP_PARADIGMS = [ | |
| "adjunct_island", "anaphor_gender_agreement", "anaphor_number_agreement", | |
| "animate_subject_passive", "animate_subject_trans", "causative", | |
| "complex_NP_island", "coordinate_structure_constraint_complex_left_branch", | |
| "coordinate_structure_constraint_object_extraction", | |
| "determiner_noun_agreement_1", "determiner_noun_agreement_2", | |
| "determiner_noun_agreement_irregular_1", "determiner_noun_agreement_irregular_2", | |
| "determiner_noun_agreement_with_adj_2", "determiner_noun_agreement_with_adj_irregular_1", | |
| "determiner_noun_agreement_with_adj_irregular_2", "determiner_noun_agreement_with_adjective_1", | |
| "distractor_agreement_relational_noun", "distractor_agreement_relative_clause", | |
| "drop_argument", "ellipsis_n_bar_1", "ellipsis_n_bar_2", | |
| "existential_there_object_raising", "existential_there_quantifiers_1", | |
| "existential_there_quantifiers_2", "existential_there_subject_raising", | |
| "expletive_it_object_raising", "inchoative", "intransitive", | |
| "irregular_past_participle_adjectives", "irregular_past_participle_verbs", | |
| "irregular_plural_subject_verb_agreement_1", "irregular_plural_subject_verb_agreement_2", | |
| "left_branch_island_echo_question", "left_branch_island_simple_question", | |
| "matrix_question_npi_licensor_present", "npi_present_1", "npi_present_2", | |
| "only_npi_licensor_present", "only_npi_scope", "passive_1", "passive_2", | |
| "principle_A_c_command", "principle_A_case_1", "principle_A_case_2", | |
| "principle_A_domain_1", "principle_A_domain_2", "principle_A_domain_3", | |
| "principle_A_reconstruction", "regular_plural_subject_verb_agreement_1", | |
| "regular_plural_subject_verb_agreement_2", "sentential_negation_npi_licensor_present", | |
| "sentential_negation_npi_scope", "sentential_subject_island", | |
| "superlative_quantifiers_1", "superlative_quantifiers_2", | |
| "tough_vs_raising_1", "tough_vs_raising_2", "transitive", "wh_island", | |
| "wh_questions_object_gap", "wh_questions_subject_gap", | |
| "wh_questions_subject_gap_long_distance", "wh_vs_that_no_gap", | |
| "wh_vs_that_no_gap_long_distance", "wh_vs_that_with_gap", | |
| "wh_vs_that_with_gap_long_distance", | |
| ] | |
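| # BLiMP scoring: a pair counts as correct when the grammatical sentence gets a lower NLL than its ungrammatical counterpart. | |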
| def _run_blimp(model, tokenizer, device, n_samples: int = 200) -> Tuple[List[str], List[float]]: | |
| from datasets import load_dataset # type: ignore | |
| accuracies: List[float] = [] | |
| for paradigm in BLIMP_PARADIGMS: | |
| try: | |
| ds = load_dataset("nyu-mll/blimp", paradigm, split="train") | |
| except Exception as e: | |
| print(f" {paradigm}: skip ({e})") | |
| accuracies.append(float("nan")) | |
| continue | |
| items = list(ds)[:n_samples] | |
| correct = 0 | |
| for ex in items: | |
| good_nll = _score_text(model, tokenizer, ex["sentence_good"], device) | |
| bad_nll = _score_text(model, tokenizer, ex["sentence_bad"], device) | |
| if math.isnan(good_nll) or math.isnan(bad_nll): | |
| continue | |
| if good_nll < bad_nll: | |
| correct += 1 | |
| acc = correct / len(items) if items else float("nan") | |
| accuracies.append(acc) | |
| print(f" {paradigm:50s} acc={acc:.3f}") | |
| return BLIMP_PARADIGMS, accuracies | |
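| # WikiText-2: perplexity over fixed-size character chunks, reported as exp(mean per-token NLL). | |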
| def _run_wikitext2(model, tokenizer, device, chunk_chars: int = 512, max_chunks: int = 100) -> Tuple[List[str], List[float]]: | |
| from datasets import load_dataset # type: ignore | |
| ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") | |
| full_text = "\n".join(ex["text"] for ex in ds if ex["text"].strip()) | |
| chunks = [full_text[i:i + chunk_chars] for i in range(0, len(full_text), chunk_chars)] | |
| chunks = [c for c in chunks if len(c) > 20][:max_chunks] | |
| labels: List[str] = [] | |
| ppls: List[float] = [] | |
| for i, chunk in enumerate(chunks): | |
| nll = _score_text(model, tokenizer, chunk, device) | |
| ppl = math.exp(nll) if not math.isnan(nll) else float("nan") | |
| labels.append(f"chunk {i + 1}") | |
| ppls.append(ppl) | |
| if (i + 1) % 10 == 0: | |
| valid = [v for v in ppls if not math.isnan(v)] | |
| mean = sum(valid) / len(valid) if valid else float("nan") | |
| print(f" chunk {i + 1}/{len(chunks)} running mean ppl={mean:.2f}") | |
| return labels, ppls | |
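| # ARC-Easy: score each answer choice as a completion of the question and pick the one with the lowest NLL. | |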
| def _run_arc_easy(model, tokenizer, device, max_samples: int = 200) -> Tuple[List[str], List[float]]: | |
| from datasets import load_dataset # type: ignore | |
| ds = load_dataset("allenai/ai2_arc", "ARC-Easy", split="test") | |
| items = list(ds)[:max_samples] | |
| labels: List[str] = [] | |
| scores: List[float] = [] | |
| for i, ex in enumerate(items): | |
| question = ex["question"] | |
| choices = ex["choices"]["text"] | |
| choice_labels = ex["choices"]["label"] | |
| answer_key = ex["answerKey"] | |
| context = f"Question: {question}\nAnswer:" | |
| nlls = [_score_completion(model, tokenizer, context, f" {c}", device) for c in choices] | |
| if all(math.isnan(v) for v in nlls): | |
| scores.append(float("nan")) | |
| else: | |
| best_idx = min(range(len(nlls)), key=lambda j: nlls[j] if not math.isnan(nlls[j]) else float("inf")) | |
| predicted = choice_labels[best_idx] | |
| scores.append(1.0 if predicted == answer_key else 0.0) | |
| labels.append(f"Q{i + 1}") | |
| n_valid = sum(1 for s in scores if not math.isnan(s)) | |
| acc = sum(s for s in scores if not math.isnan(s)) / n_valid if n_valid else float("nan") | |
| print(f" {n_valid} questions evaluated, accuracy={acc:.3f}") | |
| return labels, scores | |
| def run_benchmark_mode() -> None: | |
| try: | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| except ImportError: | |
| print("matplotlib not installed. pip install matplotlib") | |
| return | |
| bench_keys = list(BENCHMARKS.keys()) | |
| print("\nBenchmarks:") | |
| for i, k in enumerate(bench_keys): | |
| b = BENCHMARKS[k] | |
| print(f" [{i + 1}] {b['label']} — {b['desc']}") | |
| print("Select benchmark [1]:", end=" ", flush=True) | |
| try: | |
| b_choice = input().strip() or "1" | |
| except (EOFError, KeyboardInterrupt): | |
| print() | |
| return | |
| if not (b_choice.isdigit() and 1 <= int(b_choice) <= len(bench_keys)): | |
| print("Invalid selection.") | |
| return | |
| bench_key = bench_keys[int(b_choice) - 1] | |
| bench = BENCHMARKS[bench_key] | |
| print(f"Benchmark: {bench['label']}") | |
| root = Path(__file__).resolve().parent | |
| runs_dir = root / "runs" | |
| all_models = discover_models(runs_dir) | |
| model_entries: List[dict] = [] | |
| for m in all_models: | |
| model_entries.append({"label": f"[LOCAL] {m['name']}/{m['checkpoint']}", "type": "local", "meta": m}) | |
| for hf_name, hf_id in HUGGINGFACE_MODELS.items(): | |
| model_entries.append({"label": f"[HF] {hf_name}", "type": "hf", "hf_id": hf_id, "hf_name": hf_name}) | |
| if not model_entries: | |
| print("No models found.") | |
| return | |
| print("\nAvailable models:") | |
| for i, e in enumerate(model_entries): | |
| print(f" [{i + 1}] {e['label']}") | |
| print(" [a] All models") | |
| print("Select models (comma-separated or 'a'):", end=" ", flush=True) | |
| try: | |
| raw = input().strip() | |
| except (EOFError, KeyboardInterrupt): | |
| print() | |
| return | |
| if raw.lower() == "a": | |
| selected = list(range(len(model_entries))) | |
| else: | |
| selected = [] | |
| for tok in raw.split(","): | |
| tok = tok.strip() | |
| if tok.isdigit() and 1 <= int(tok) <= len(model_entries): | |
| selected.append(int(tok) - 1) | |
| if not selected: | |
| print("No valid selection.") | |
| return | |
| all_results: List[dict] = [] | |
| shared_x_labels: Optional[List[str]] = None | |
| for idx in selected: | |
| entry = model_entries[idx] | |
| print(f"\n{'='*60}\nLoading {entry['label']}...") | |
| try: | |
| if entry["type"] == "local": | |
| m = entry["meta"] | |
| bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) | |
| else: | |
| bundle = load_huggingface_model(entry["hf_id"], root / "cache" / "huggingface") | |
| except Exception as e: | |
| print(f" Failed: {e}") | |
| continue | |
| model = bundle["model"] | |
| tokenizer = bundle["tokenizer"] | |
| device = str(bundle["device"]) | |
| model.eval() | |
| if bench_key == "blimp": | |
| x_labels, y_vals = _run_blimp(model, tokenizer, device) | |
| elif bench_key == "wikitext2": | |
| x_labels, y_vals = _run_wikitext2(model, tokenizer, device) | |
| else: | |
| x_labels, y_vals = _run_arc_easy(model, tokenizer, device) | |
| if shared_x_labels is None: | |
| shared_x_labels = x_labels | |
| valid = [v for v in y_vals if not math.isnan(v)] | |
| summary = sum(valid) / len(valid) if valid else float("nan") | |
| all_results.append({"label": entry["label"], "y": y_vals, "summary": summary}) | |
| if not all_results or shared_x_labels is None: | |
| print("No results to plot.") | |
| return | |
| metric = bench["metric"] | |
| paired = sorted(zip([r["summary"] for r in all_results], [r["label"] for r in all_results]), | |
| reverse=(metric != "perplexity")) | |
| summaries, model_labels = zip(*paired) if paired else ([], []) | |
| n = len(summaries) | |
| colors = [plt.cm.RdYlGn(i / max(n - 1, 1)) for i in range(n)] | |
| fig, ax = plt.subplots(figsize=(max(6, n * 1.4), 6)) | |
| bars = ax.bar(range(n), summaries, color=colors, edgecolor="black") | |
| for bar, val in zip(bars, summaries): | |
| ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005, | |
| f"{val:.3f}", ha="center", va="bottom", fontsize=9, fontweight="bold") | |
| ylabel = "Mean Perplexity (↓ better)" if metric == "perplexity" else "Mean Accuracy (↑ better)" | |
| ax.set_ylabel(ylabel) | |
| ax.set_title(f"{bench['label']} Benchmark — Model Comparison") | |
| ax.set_xticks(range(n)) | |
| ax.set_xticklabels(model_labels, rotation=20, ha="right", fontsize=9) | |
| if metric == "accuracy": | |
| ax.set_ylim(0, 1.05) | |
| ax.grid(True, axis="y", alpha=0.3) | |
| plt.tight_layout() | |
| out_path = root / f"benchmark_{bench_key}.png" | |
| plt.savefig(str(out_path), dpi=150) | |
| print(f"\nChart saved to {out_path}") | |
| try: | |
| import subprocess | |
| subprocess.Popen(["xdg-open", str(out_path)]) | |
| except Exception: | |
| pass | |
| # --------------------------------------------------------------------------- | |
| # Interactive CLI | |
| # --------------------------------------------------------------------------- | |
| def _pick_series(detected: str) -> str: | |
| series_list = list(MODEL_SERIES.keys()) | |
| detected_lower = detected.lower() | |
| default_idx = next( | |
| (i + 1 for i, s in enumerate(series_list) if s == detected_lower), 1 | |
| ) | |
| # Skip selection if only one series available | |
| if len(series_list) == 1: | |
| return series_list[0].capitalize() | |
| print("Series:") | |
| for i, s in enumerate(series_list): | |
| marker = " (detected)" if s == detected_lower else "" | |
| print(f" [{i + 1}] {s.capitalize()}{marker}") | |
| while True: | |
| try: | |
| choice = input(f"Select series [{default_idx}]: ").strip() | |
| except (EOFError, KeyboardInterrupt): | |
| print() | |
| sys.exit(0) | |
| if not choice: | |
| choice = str(default_idx) | |
| if choice.isdigit() and 1 <= int(choice) <= len(series_list): | |
| return series_list[int(choice) - 1].capitalize() | |
| print(f"Enter a number 1-{len(series_list)}") | |
| def pick_model(runs_dir: Path) -> tuple[dict, str]: | |
| models = discover_models(runs_dir) | |
| if not models: | |
| print(f"No models found in {runs_dir}") | |
| print("Expected layout: runs/<name>/model.pt (or pretrain.pt) + tokenizer.json") | |
| sys.exit(1) | |
| if len(models) == 1: | |
| m = models[0] | |
| print(f"Loading {m['name']}/{m['checkpoint']} ({m['series']})...") | |
| bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) | |
| return bundle, m["checkpoint"] | |
| print("Available models:") | |
| for i, m in enumerate(models): | |
| print(f" [{i + 1}] {m['name']}/{m['checkpoint']} ({m['series']})") | |
| while True: | |
| try: | |
| choice = input("Select model [1]: ").strip() | |
| except (EOFError, KeyboardInterrupt): | |
| print() | |
| sys.exit(0) | |
| if not choice: | |
| choice = "1" | |
| if choice.isdigit() and 1 <= int(choice) <= len(models): | |
| m = models[int(choice) - 1] | |
| print(f"Loading {m['name']}/{m['checkpoint']} ({m['series']})...") | |
| bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) | |
| return bundle, m["checkpoint"] | |
| print(f"Enter a number 1-{len(models)}") | |
| # --------------------------------------------------------------------------- | |
| # Generation mode configs | |
| # --------------------------------------------------------------------------- | |
| MODES = { | |
| "chat-coherent": { | |
| "label": "Chat — Coherent", | |
| "desc": "structured, consistent, strong repetition control", | |
| "sft_mode": "chat", | |
| "temperature": 0.35, | |
| "top_k": 20, | |
| "top_p": 0.88, | |
| "min_p": 0.10, | |
| "no_repeat_ngram_size": 4, | |
| "repetition_penalty": 1.22, | |
| "logit_soft_cap": 20.0, | |
| "loop_penalty": 20.0, | |
| "max_new_tokens": 4096, | |
| "context_window": 2048, | |
| }, | |
| "chat-variants": { | |
| "label": "Chat — Variants", | |
| "desc": "creative, diverse, more surprising outputs", | |
| "sft_mode": "chat", | |
| "temperature": 0.65, | |
| "top_k": 60, | |
| "top_p": 0.92, | |
| "min_p": 0.05, | |
| "no_repeat_ngram_size": 4, | |
| "repetition_penalty": 1.12, | |
| "logit_soft_cap": 20.0, | |
| "loop_penalty": 14.0, | |
| "max_new_tokens": 4096, | |
| "context_window": 2048, | |
| }, | |
| "pretrain-coherent": { | |
| "label": "Pretrain — Coherent", | |
| "desc": "grounded continuation, low temperature, tight sampling", | |
| "sft_mode": False, | |
| "temperature": 0.3, | |
| "top_k": 20, | |
| "top_p": 0.85, | |
| "min_p": 0.10, | |
| "no_repeat_ngram_size": 4, | |
| "repetition_penalty": 1.2, | |
| "logit_soft_cap": 20.0, | |
| "loop_penalty": 20.0, | |
| "max_new_tokens": 4096, | |
| "context_window": 2048, | |
| }, | |
| "pretrain-variants": { | |
| "label": "Pretrain — Variants", | |
| "desc": "free-form continuation, higher temperature, more exploration", | |
| "sft_mode": False, | |
| "temperature": 0.7, | |
| "top_k": 60, | |
| "top_p": 0.93, | |
| "min_p": 0.04, | |
| "no_repeat_ngram_size": 4, | |
| "repetition_penalty": 1.12, | |
| "logit_soft_cap": 20.0, | |
| "loop_penalty": 12.0, | |
| "max_new_tokens": 4096, | |
| "context_window": 2048, | |
| }, | |
| } | |
| _MODE_LIST = list(MODES.keys()) | |
| def pick_mode(is_pretrain: bool) -> dict: | |
| """Prompt the user to choose a generation mode. Returns a config dict.""" | |
| # Filter to relevant modes based on checkpoint type | |
| candidates = [k for k in _MODE_LIST if ("pretrain" in k) == is_pretrain] | |
| print("\nGeneration mode:") | |
| for i, key in enumerate(candidates): | |
| cfg = MODES[key] | |
| print(f" [{i + 1}] {cfg['label']} — {cfg['desc']}") | |
| while True: | |
| try: | |
| choice = input("Select mode [1]: ").strip() | |
| except (EOFError, KeyboardInterrupt): | |
| print() | |
| sys.exit(0) | |
| if not choice: | |
| choice = "1" | |
| if choice.isdigit() and 1 <= int(choice) <= len(candidates): | |
| key = candidates[int(choice) - 1] | |
| cfg = MODES[key] | |
| print(f"Mode: {cfg['label']}") | |
| return cfg | |
| print(f"Enter a number 1-{len(candidates)}") | |
| def _run_loop(bundle: dict, cfg: dict) -> None: | |
| model = bundle["model"] | |
| tokenizer = bundle["tokenizer"] | |
| device = bundle["device"] | |
| sft = cfg["sft_mode"] | |
| prompt_label = "You" if sft else "Prompt" | |
| print(f"\nModel ready on {device}. Type your message, or /quit to exit.") | |
| print(f" temp={cfg['temperature']} top_k={cfg['top_k']} top_p={cfg['top_p']}") | |
| print(f" min_p={cfg['min_p']} ng={cfg['no_repeat_ngram_size']} rp={cfg['repetition_penalty']}") | |
| print(f" cap={cfg['logit_soft_cap']} loop_penalty={cfg['loop_penalty']}\n") | |
| while True: | |
| try: | |
| prompt = input(f"{prompt_label}: ").strip() | |
| except (EOFError, KeyboardInterrupt): | |
| print() | |
| break | |
| if not prompt: | |
| continue | |
| if prompt in ("/quit", "/exit", "/q"): | |
| break | |
| if prompt == "/help": | |
| print("Commands: /quit /exit /q /help /mode") | |
| if sft: | |
| print("Anything else is sent as a chat prompt.") | |
| else: | |
| print("Anything else is sent as a raw continuation prompt.") | |
| continue | |
| if prompt == "/mode": | |
| print(f"Current: {cfg['label']} — {cfg['desc']}") | |
| continue | |
| print("AI: ", end="", flush=True) | |
| generate( | |
| model=model, | |
| tokenizer=tokenizer, | |
| prompt=prompt, | |
| max_new_tokens=cfg["max_new_tokens"], | |
| temperature=cfg["temperature"], | |
| top_k=cfg["top_k"], | |
| top_p=cfg["top_p"], | |
| min_p=cfg["min_p"], | |
| no_repeat_ngram_size=cfg["no_repeat_ngram_size"], | |
| repetition_penalty=cfg["repetition_penalty"], | |
| logit_soft_cap=cfg["logit_soft_cap"], | |
| loop_penalty=cfg["loop_penalty"], | |
| device=str(device), | |
| sft_mode=cfg["sft_mode"], | |
| stream=True, | |
| context_window=cfg["context_window"], | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Dynamic collection discovery | |
| # --------------------------------------------------------------------------- | |
| _COLLECTION_SLUG = "CompactAI-O/tmlm-haiku-series" | |
| _AUTHOR = "CompactAI-O" | |
| _SEARCH = "TMLM-Haiku" | |
| _FALLBACK_COLLECTION = [ | |
| {"version": "TMLM-Haiku-2.3", "hf_id": "CompactAI-O/TMLM-Haiku-2.3"}, | |
| {"version": "TMLM-Haiku-2", "hf_id": "CompactAI-O/TMLM-Haiku-2"}, | |
| {"version": "TMLM-Haiku-1.3", "hf_id": "CompactAI-O/TMLM-Haiku-1.3"}, | |
| {"version": "TMLM-Haiku-1", "hf_id": "CompactAI-O/TMLM-Haiku-1"}, | |
| {"version": "Glint-1", "hf_id": "CompactAI-O/Glint-1"}, | |
| ] | |
| _EXTRA_REPOS = ["CompactAI-O/Glint-1"] | |
| def _probe_repo(hf_id: str) -> dict | None: | |
| """Return entry dict for one repo, or None if no usable checkpoints found.""" | |
| from huggingface_hub import list_repo_files | |
| try: | |
| files = set(list_repo_files(hf_id)) | |
| except Exception: | |
| return None | |
| # Detect which subdirectory holds the checkpoints | |
| subdir: str | None = None | |
| for candidate in ("models", "model"): | |
| if any(f.startswith(f"{candidate}/") for f in files): | |
| subdir = candidate | |
| break | |
| prefix = f"{subdir}/" if subdir else "" | |
| # Collect all .pt files in the checkpoint directory | |
| pt_files = sorted( | |
| f[len(prefix):] for f in files | |
| if f.startswith(prefix) and f.endswith(".pt") | |
| ) | |
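| # Map known checkpoint filenames to display labels; unknown .pt files fall back to their stem and a name-based pretrain guess. | |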
| _LABELS = { | |
| "model.pt": ("Chat (SFT)", False), | |
| "model_rep.pt": ("Chat (anti-repetition)", False), | |
| "pretrain.pt": ("Pretrain (base)", True), | |
| } | |
| checkpoints = [] | |
| for fname in pt_files: | |
| label, is_pretrain = _LABELS.get(fname, (fname.removesuffix(".pt"), "pretrain" in fname)) | |
| checkpoints.append((label, fname, is_pretrain)) | |
| if not checkpoints: | |
| return None | |
| return { | |
| "version": hf_id.split("/")[-1], | |
| "hf_id": hf_id, | |
| "subdir": subdir, | |
| "checkpoints": checkpoints, | |
| "desc": "", | |
| } | |
| def fetch_collection() -> list[dict]: | |
| """Query HF for all CompactAI-O TMLM-Haiku models, newest first.""" | |
| from huggingface_hub import HfApi | |
| print("Checking HuggingFace collection for available models...") | |
| try: | |
| api = HfApi() | |
| infos = list( | |
| api.list_models( | |
| author=_AUTHOR, | |
| search=_SEARCH, | |
| sort="lastModified", | |
| ) | |
| ) | |
| infos.sort(key=lambda m: getattr(m, "lastModified", ""), reverse=True) | |
| except Exception as exc: | |
| print(f" Could not reach HuggingFace ({exc}); using fallback list.") | |
| infos = [type("M", (), {"id": e["hf_id"]})() for e in _FALLBACK_COLLECTION] | |
| entries = [] | |
| seen_ids: set = set() | |
| for info in infos: | |
| repo_id = info.id | |
| if _SEARCH.lower() not in repo_id.lower(): | |
| continue | |
| entry = _probe_repo(repo_id) | |
| if entry: | |
| entries.append(entry) | |
| seen_ids.add(repo_id) | |
| # Always include extra repos (e.g. Glint-1) not caught by TMLM-Haiku search | |
| for repo_id in _EXTRA_REPOS: | |
| if repo_id not in seen_ids: | |
| entry = _probe_repo(repo_id) | |
| if entry: | |
| entries.append(entry) | |
| seen_ids.add(repo_id) | |
| if not entries: | |
| print(" No models found; using fallback list.") | |
| for fb in _FALLBACK_COLLECTION: | |
| e = _probe_repo(fb["hf_id"]) | |
| if e: | |
| entries.append(e) | |
| return entries | |
| # --------------------------------------------------------------------------- | |
| # Download helper | |
| # --------------------------------------------------------------------------- | |
| def _download_version(entry: dict, cache_dir: Path) -> Path: | |
| """Download full repo snapshot; return the directory containing model files.""" | |
| try: | |
| from huggingface_hub import snapshot_download | |
| except ImportError: | |
| print("huggingface_hub not installed. Run: pip install huggingface_hub") | |
| sys.exit(1) | |
| hf_id = entry["hf_id"] | |
| print(f"Fetching {hf_id} ...") | |
| try: | |
| local_dir = Path(snapshot_download(repo_id=hf_id, cache_dir=str(cache_dir))) | |
| except Exception as exc: | |
| print(f"Download failed: {exc}") | |
| sys.exit(1) | |
| subdir = entry.get("subdir") | |
| model_dir = (local_dir / subdir) if subdir else local_dir | |
| if not model_dir.exists(): | |
| # Fallback to root | |
| model_dir = local_dir | |
| return model_dir | |
| # --------------------------------------------------------------------------- | |
| # Selection prompts | |
| # --------------------------------------------------------------------------- | |
| def _prompt_int(prompt: str, lo: int, hi: int, default: int = 1) -> int: | |
| while True: | |
| try: | |
| raw = input(f"{prompt} [{default}]: ").strip() | |
| except (EOFError, KeyboardInterrupt): | |
| print() | |
| sys.exit(0) | |
| if not raw: | |
| return default | |
| if raw.isdigit() and lo <= int(raw) <= hi: | |
| return int(raw) | |
| print(f" Enter a number {lo}–{hi}.") | |
| def pick_version(collection: list[dict]) -> dict: | |
| print("\nTMLM-Haiku series (CompactAI-O)\n") | |
| for i, entry in enumerate(collection): | |
| desc = f" — {entry['desc']}" if entry["desc"] else "" | |
| print(f" [{i + 1}] {entry['version']}{desc}") | |
| idx = _prompt_int("Select version", 1, len(collection)) | |
| return collection[idx - 1] | |
| def pick_checkpoint(entry: dict) -> tuple[str, bool]: | |
| """Return (filename, is_pretrain).""" | |
| ckpts = entry["checkpoints"] | |
| if len(ckpts) == 1: | |
| label, fname, is_pretrain = ckpts[0] | |
| print(f" Using: {label} ({fname})") | |
| return fname, is_pretrain | |
| print(f"\nCheckpoints for {entry['version']}:") | |
| for i, (label, fname, _) in enumerate(ckpts): | |
| print(f" [{i + 1}] {label} ({fname})") | |
| idx = _prompt_int("Select checkpoint", 1, len(ckpts)) | |
| label, fname, is_pretrain = ckpts[idx - 1] | |
| return fname, is_pretrain | |
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
| def main() -> None: | |
| import argparse | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--compare", "-c", action="store_true") | |
| parser.add_argument("--prompt", "-p", type=str, default="Hello") | |
| mode_group = parser.add_mutually_exclusive_group() | |
| mode_group.add_argument("--pretrain", action="store_true") | |
| mode_group.add_argument("--sft", action="store_true") | |
| args, _ = parser.parse_known_args() | |
| print("=" * 56) | |
| print(" CompactAI-O Interactive Chat") | |
| print(" Models: huggingface.co/CompactAI-O") | |
| print("=" * 56) | |
| if args.compare: | |
| prefetch_huggingface_models() | |
| cfg = pick_mode(is_pretrain=args.pretrain) | |
| prompt_label = "You" if cfg["sft_mode"] else "Prompt" | |
| while True: | |
| print(f"{prompt_label}:", end=" ", flush=True) | |
| prompt = sys.stdin.readline().strip() | |
| if not prompt or prompt in ("/quit", "/exit", "/q"): | |
| break | |
| compare_all_models(prompt, cfg) | |
| return | |
| collection = fetch_collection() | |
| if not collection: | |
| print("No models found. Check your internet connection.") | |
| sys.exit(1) | |
| entry = pick_version(collection) | |
| fname, is_pretrain = pick_checkpoint(entry) | |
| if args.pretrain: | |
| is_pretrain = True | |
| elif args.sft: | |
| is_pretrain = False | |
| root = Path(__file__).resolve().parent | |
| cache_dir = root / "cache" / "huggingface" | |
| cache_dir.mkdir(parents=True, exist_ok=True) | |
| model_dir = _download_version(entry, cache_dir) | |
| model_path = model_dir / fname | |
| tokenizer_path = model_dir / "tokenizer.json" | |
| if not model_path.exists(): | |
| print(f"File not found: {model_path}") | |
| sys.exit(1) | |
| if not tokenizer_path.exists(): | |
| print(f"Tokenizer not found: {tokenizer_path}") | |
| sys.exit(1) | |
| print(f"Loading {entry['version']} / {fname} ...") | |
| bundle = load_local_model(model_path, tokenizer_path, "Haiku") | |
| # Use checkpoint-embedded sft_mode/phase if available | |
| sft_mode_flag = bundle.get("sft_mode") | |
| phase_flag = bundle.get("phase") | |
| if sft_mode_flag is not None and not args.pretrain and not args.sft: | |
| is_pretrain = not sft_mode_flag | |
| elif phase_flag is not None and not args.pretrain and not args.sft: | |
| is_pretrain = phase_flag == "pretrain" | |
| print("\nChoose action:") | |
| print(" [1] Chat with this model") | |
| print(" [2] Compare ALL models (local + HuggingFace)") | |
| print(" [3] Run Benchmark (BLiMP / WikiText-2 / ARC-Easy)") | |
| print("Select [1]:", end=" ", flush=True) | |
| choice = sys.stdin.readline().strip() or "1" | |
| if choice == "1": | |
| cfg = pick_mode(is_pretrain) | |
| _run_loop(bundle, cfg) | |
| elif choice == "2": | |
| print("\nDownloading/preparing HuggingFace models...") | |
| prefetch_huggingface_models() | |
| cfg = pick_mode(is_pretrain) | |
| prompt_label = "You" if cfg["sft_mode"] else "Prompt" | |
| while True: | |
| print(f"{prompt_label}:", end=" ", flush=True) | |
| prompt = sys.stdin.readline().strip() | |
| if not prompt or prompt in ("/quit", "/exit", "/q"): | |
| break | |
| compare_all_models(prompt, cfg) | |
| elif choice == "3": | |
| run_benchmark_mode() | |
| else: | |
| print("Enter 1, 2, or 3") | |
| if __name__ == "__main__": | |
| main() | |