# NOTE(review): the four lines below were non-Python residue from the
# HuggingFace file-listing page this file was copied from; kept here as a
# comment so the module remains importable.
# Homepage / downloads/interactive.py — "Update downloads/interactive.py",
# commit 16b04a8 (verified), CompactAI.
#!/usr/bin/env python3
"""Public-facing TMLM-Haiku interactive CLI.
Pulls models from the CompactAI-O HuggingFace collection:
https://huggingface.co/collections/CompactAI-O/tmlm-haiku-series
"""
from __future__ import annotations
import hashlib
import json
import math
import os
import string
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
# Public model names mapped to their HuggingFace Hub repository IDs
# (CompactAI-O TMLM-Haiku collection; see module docstring).
HUGGINGFACE_MODELS = {
    "TMLM-Haiku-1": "CompactAI-O/TMLM-Haiku-1",
    "TMLM-Haiku-1.3": "CompactAI-O/TMLM-Haiku-1.3",
    "TMLM-Haiku-2": "CompactAI-O/TMLM-Haiku-2",
    "Glint-1": "CompactAI-O/Glint-1",
}
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
@dataclass
class ModelConfig:
    """Default architecture hyperparameters; overridden per series by MODEL_SERIES."""

    dim: int = 128                    # residual-stream width
    n_unique_layers: int = 8          # physically distinct transformer blocks
    n_logical_layers: int = 16        # effective depth after block re-use
    n_heads: int = 4                  # query heads
    n_kv_heads: int = 2               # key/value heads (GQA when < n_heads)
    ffn_dim: int = 224                # SwiGLU hidden width
    dropout: float = 0.0
    seq_len: int = 2048               # maximum sequence length
    sliding_window_size: int = 512    # window for local (non-global) attention
    mtp_horizons: Tuple[int, ...] = (2, 3, 4)  # multi-token-prediction offsets
    rope_fraction: float = 0.5        # fraction of head_dim receiving RoPE
    embed_scale: bool = True          # scale embeddings by sqrt(dim)
    logit_soft_cap: float = -1.0      # presumably <= 0 disables capping — TODO confirm
    quantization: str = "nvfp4"

    @property
    def head_dim(self) -> int:
        # Per-head width; assumes dim is divisible by n_heads.
        return self.dim // self.n_heads


model_config = ModelConfig()
# Per-series presets (smallest to largest). Each dict overrides the
# ModelConfig defaults; keys mirror TinyMemoryLM.__init__ parameters.
MODEL_SERIES = {
    "haiku": {
        "dim": 64,
        "n_unique_layers": 12,
        "n_logical_layers": 24,
        "n_heads": 4,
        "n_kv_heads": 2,
        "ffn_dim": 384,
        "dropout": 0.0,
        "seq_len": 2048,
        # Window equals seq_len, so "local" layers effectively see everything.
        "sliding_window_size": 2048,
        "mtp_horizons": (),
        "rope_fraction": 0.5,
        "engram_dim": 8,
        "engram_heads": 2,
        "engram_table_size": 64,
        "engram_max_ngram": 2,
        "mhc_expansion": 2,
        "sleep_gate_cap": 0,
        "sleep_gate_heads": 4,
        "latent_think_layers": 0,
        "prelude_layers": 0,
        "coda_layers": 0,
        "recurrent_loops": 0,
        "recurrent_act_threshold": 0.9,
        "recurrent_lora_rank": 0,
        "recurrent_loop_embed_dim": 0,
    },
    # NOTE(review): "sonnet" and "opus" omit sliding_window_size and
    # rope_fraction, so those fall back to defaults — confirm intended.
    "sonnet": {
        "dim": 1024,
        "n_unique_layers": 20,
        "n_logical_layers": 40,
        "n_heads": 16,
        "n_kv_heads": 4,
        "ffn_dim": 4096,
        "dropout": 0.0,
        "seq_len": 2048,
        "mtp_horizons": (2,),
        "engram_dim": 32,
        "engram_heads": 8,
        "engram_table_size": 4096,
        "engram_max_ngram": 2,
        "mhc_expansion": 2,
        "sleep_gate_cap": 0,
        "sleep_gate_heads": 8,
        "latent_think_layers": 0,
        "prelude_layers": 0,
        "coda_layers": 0,
        "recurrent_loops": 0,
        "recurrent_act_threshold": 0.99,
        "recurrent_lora_rank": 0,
        "recurrent_loop_embed_dim": 0,
    },
    "opus": {
        "dim": 1536,
        "n_unique_layers": 18,
        "n_logical_layers": 36,
        "n_heads": 16,
        "n_kv_heads": 4,
        "ffn_dim": 5888,
        "dropout": 0.0,
        "seq_len": 2048,
        "mtp_horizons": (2,),
        "engram_dim": 64,
        "engram_heads": 8,
        "engram_table_size": 8192,
        "engram_max_ngram": 2,
        "mhc_expansion": 4,
        "sleep_gate_cap": 0,
        "sleep_gate_heads": 8,
        "latent_think_layers": 0,
        "prelude_layers": 0,
        "coda_layers": 0,
        "recurrent_loops": 0,
        "recurrent_act_threshold": 0.99,
        "recurrent_lora_rank": 0,
        "recurrent_loop_embed_dim": 0,
    },
}
# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------
# Multi-character chat-format control tokens; matched greedily by the
# tokenizer before falling back to single characters.
FORMAT_TOKENS = [
    "<|user|>",
    "<|assistant|>",
    "<|system|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|begin_of_thought|>",
    "<|end_of_thought|>",
    "<|begin_of_solution|>",
    "<|end_of_solution|>",
]
class WordTokenizer:
    """Character-level tokenizer with multi-character special/format tokens.

    Vocabulary id layout, in order: core specials (<PAD>, <BOS>, <EOS>,
    <UNK>), format tokens, then the sorted single-character fallback
    alphabet. Characters can be appended at runtime via maybe_add_char.
    """

    def __init__(
        self, extra_chars: str = "", format_tokens: Optional[List[str]] = None
    ) -> None:
        # Printable ASCII plus common whitespace as the fallback alphabet.
        base = string.ascii_letters + string.digits + string.punctuation + " \n\t\r"
        fallback_chars = sorted(set(base + extra_chars))
        self.core_special = ["<PAD>", "<BOS>", "<EOS>", "<UNK>"]
        self.format_tokens = (
            list(format_tokens) if format_tokens else list(FORMAT_TOKENS)
        )
        self.special = list(self.core_special) + list(self.format_tokens)
        self.id_to_token: List[str] = (
            list(self.core_special) + self.format_tokens + fallback_chars
        )
        self.token_to_id: Dict[str, int] = {
            t: i for i, t in enumerate(self.id_to_token)
        }
        # Longest-first so greedy matching prefers the longest special token.
        self.special_multi_tokens = sorted(
            [t for t in self.special if len(t) > 1], key=len, reverse=True
        )
        self.multi_char_tokens = self.special_multi_tokens
        # Count of characters added after construction via maybe_add_char.
        self.dynamic_additions = 0

    @property
    def pad_id(self) -> int:
        return self.token_to_id["<PAD>"]

    @property
    def bos_id(self) -> int:
        return self.token_to_id["<BOS>"]

    @property
    def eos_id(self) -> int:
        return self.token_to_id["<EOS>"]

    @property
    def unk_id(self) -> int:
        return self.token_to_id["<UNK>"]

    @property
    def vocab_size(self) -> int:
        return len(self.id_to_token)

    def maybe_add_char(self, ch: str) -> bool:
        """Append *ch* to the vocabulary; return True if it was new."""
        if ch in self.token_to_id:
            return False
        self.token_to_id[ch] = len(self.id_to_token)
        self.id_to_token.append(ch)
        self.dynamic_additions += 1
        return True

    def iter_lexical_tokens(self, text: str) -> Iterator[str]:
        """Yield special tokens (greedy, longest-first) or single characters."""
        i = 0
        n = len(text)
        while i < n:
            matched_special = False
            for token in self.special_multi_tokens:
                if text.startswith(token, i):
                    yield token
                    i += len(token)
                    matched_special = True
                    break
            if matched_special:
                continue
            yield text[i]
            i += 1

    def encode(
        self, text: str, add_bos: bool = False, add_eos: bool = False
    ) -> List[int]:
        """Encode *text* to token ids; unknown tokens map to <UNK>."""
        out: List[int] = []
        if add_bos:
            out.append(self.bos_id)
        unk = self.unk_id
        t2i = self.token_to_id
        for tok in self.iter_lexical_tokens(text):
            out.append(t2i.get(tok, unk))
        if add_eos:
            out.append(self.eos_id)
        return out

    def decode(self, ids: Sequence[int], skip_special: bool = True) -> str:
        """Decode ids to text, silently dropping out-of-range ids."""
        pieces: List[str] = []
        for idx in ids:
            if int(idx) < 0 or int(idx) >= len(self.id_to_token):
                continue
            tok = self.id_to_token[int(idx)]
            if skip_special and tok in self.special:
                continue
            pieces.append(tok)
        return "".join(pieces)

    @classmethod
    def load(cls, path: Path) -> WordTokenizer:
        """Restore a tokenizer from a JSON file containing "id_to_token"."""
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)
        format_tokens = data.get("format_tokens", FORMAT_TOKENS)
        tokenizer = cls(extra_chars="", format_tokens=format_tokens)
        # Replace the freshly-built vocab with the saved ordering so ids
        # match the checkpoint that produced the file.
        tokenizer.id_to_token = data["id_to_token"]
        tokenizer.token_to_id = {t: i for i, t in enumerate(tokenizer.id_to_token)}
        tokenizer.special = list(tokenizer.core_special) + list(tokenizer.format_tokens)
        tokenizer.special_multi_tokens = sorted(
            [t for t in tokenizer.special if len(t) > 1], key=len, reverse=True
        )
        tokenizer.multi_char_tokens = tokenizer.special_multi_tokens
        return tokenizer


# Backwards-compatible alias for the old class name.
LetterTokenizer = WordTokenizer
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned per-channel gain."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Prefer the fused kernel when this torch build provides it.
        fused = getattr(torch.nn.functional, "rms_norm", None)
        if fused is not None:
            return fused(x, self.weight.shape, self.weight, self.eps)
        # Fallback: normalize in fp32 for stability, cast back, then scale.
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x32 * inv_rms).to(dtype=x.dtype) * self.weight
class RotaryEmbedding(nn.Module):
    """Precomputes rotary position-embedding angles for a head-dim slice."""

    def __init__(self, dim: int, base: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        # Inverse frequencies, one per even channel pair.
        self.register_buffer("inv_freq", 1.0 / (base**exponents), persistent=False)

    def cos_sin(
        self, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        doubled = torch.cat([angles, angles], dim=-1)
        # Shapes broadcast over (batch, heads, seq, rope_dim).
        cos = doubled.cos()[None, None, :, :].to(dtype=dtype)
        sin = doubled.sin()[None, None, :, :].to(dtype=dtype)
        return cos, sin
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
class CausalSelfAttention(nn.Module):
    """Grouped-query causal self-attention with partial RoPE and a KV cache.

    Only the first ``rope_dim`` channels of each head are rotated; the rest
    pass through unrotated. Attention is either fully causal ("global") or
    restricted to a causal sliding window.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        dropout: float,
        sliding_window: int,
        rope_fraction: float,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        # Times each KV head is shared across query heads (GQA).
        self.n_rep = n_heads // n_kv_heads
        self.dropout = dropout
        self.sliding_window = sliding_window
        self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)
        # Rotated slice per head: even-sized and at least 2 channels.
        self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2)
        self.rope = RotaryEmbedding(self.rope_dim)
        # QK-norm on each head before attention.
        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)
        # Learned per-head sigmoid gate applied to the attention output.
        self.output_gate = nn.Parameter(torch.ones(n_heads))

    def forward(
        self,
        x: torch.Tensor,
        is_global: bool,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attend over x of shape (B, T, dim); returns (output, new_kv or None)."""
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Absolute position offset contributed by the cached prefix.
        past_len = past_kv[0].shape[2] if past_kv is not None else 0
        cos, sin = self.rope.cos_sin(T + past_len, x.device, q.dtype)
        cos_slice = cos[:, :, past_len : past_len + T, :]
        sin_slice = sin[:, :, past_len : past_len + T, :]
        # Rotate only the leading rope_dim channels of each head.
        q_rope = q[..., : self.rope_dim]
        q_pass = q[..., self.rope_dim :]
        k_rope = k[..., : self.rope_dim]
        k_pass = k[..., self.rope_dim :]
        q_rope = (q_rope * cos_slice) + (_rotate_half(q_rope) * sin_slice)
        k_rope = (k_rope * cos_slice) + (_rotate_half(k_rope) * sin_slice)
        q = torch.cat([q_rope, q_pass], dim=-1)
        k = torch.cat([k_rope, k_pass], dim=-1)
        if past_kv is not None:
            # Prepend the cached keys/values along the sequence axis.
            k = torch.cat([past_kv[0], k], dim=2)
            v = torch.cat([past_kv[1], v], dim=2)
        new_kv = (k, v) if use_cache else None
        S = k.shape[2]
        # Expand KV heads to match query heads for the SDPA kernel.
        if self.n_rep > 1:
            k = (
                k[:, :, None, :, :]
                .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim)
                .reshape(B, self.n_heads, S, self.head_dim)
            )
            v = (
                v[:, :, None, :, :]
                .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim)
                .reshape(B, self.n_heads, S, self.head_dim)
            )
        # Dropout only during training with autograd enabled.
        drop_p = self.dropout if (self.training and torch.is_grad_enabled()) else 0.0
        if is_global:
            if past_kv is None and T > 1:
                # Prefill: fused causal kernel needs no explicit mask.
                out = F.scaled_dot_product_attention(
                    q, k, v, is_causal=True, dropout_p=drop_p
                )
            else:
                # Decode step: every cached position is attendable.
                out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
        else:
            # Sliding-window causal mask built from absolute positions.
            T_q = q.shape[2]
            q_pos = torch.arange(past_len, past_len + T_q, device=q.device).unsqueeze(1)
            k_pos = torch.arange(S, device=q.device).unsqueeze(0)
            mask = (q_pos >= k_pos) & ((q_pos - k_pos) < self.sliding_window)
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask.unsqueeze(0).unsqueeze(0), dropout_p=drop_p
            )
        # Per-head output gate, then merge heads and project back to dim.
        gate = torch.sigmoid(self.output_gate).view(1, self.n_heads, 1, 1)
        out = out * gate
        out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)
        out = self.wo(out)
        return out, new_kv
class SwiGLU(nn.Module):
    """Gated feed-forward block: down(silu(gate(x)) * up(x))."""

    def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None:
        super().__init__()
        self.gate = nn.Linear(dim, hidden_dim, bias=False)
        self.up = nn.Linear(dim, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, dim, bias=False)
        self.drop = nn.Dropout(dropout)
        # Scaled-normal init; keep this order for RNG reproducibility.
        nn.init.normal_(self.gate.weight, std=dim**-0.5)
        nn.init.normal_(self.up.weight, std=dim**-0.5)
        nn.init.normal_(self.down.weight, std=hidden_dim**-0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.down(F.silu(self.gate(x)) * self.up(x))
        # Dropout is applied only while training with autograd enabled.
        if self.training and torch.is_grad_enabled():
            projected = self.drop(projected)
        return projected
def loop_index_embedding(h: torch.Tensor, loop_t: int, loop_dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Add a sinusoidal encoding of loop index *loop_t* to the first channels of h.

    The encoding occupies the first ``loop_dim`` channels (clamped to an even
    width that fits h's last dimension). Returns h unchanged when the
    effective width is zero or negative.
    """
    width = min(loop_dim, h.shape[-1])
    width -= width % 2  # sin/cos halves need an even split
    if width <= 0:
        return h
    inv_freq = 1.0 / (theta ** (torch.arange(0, width, 2, device=h.device, dtype=h.dtype) / width))
    phase = torch.tensor(float(loop_t), device=h.device, dtype=h.dtype) * inv_freq
    encoding = torch.cat([phase.sin(), phase.cos()], dim=0).view(1, 1, width)
    result = h.clone()
    result[..., :width] = result[..., :width] + encoding
    return result
class DepthLoRAAdapter(nn.Module):
    """Per-loop low-rank residual adapter; a zero-output no-op when rank <= 0."""

    def __init__(self, dim: int, rank: int, max_loops: int) -> None:
        super().__init__()
        self.rank = max(0, rank)
        if self.rank <= 0:
            # Disabled: keep the attributes so forward() can detect the no-op.
            self.down = None
            self.B = None
            self.scale = None
            return
        self.down = nn.Linear(dim, self.rank, bias=False)
        self.B = nn.Parameter(torch.randn(self.rank, dim) * 0.02)
        # One per-rank scale vector per loop index; zero-initialized so the
        # adapter contributes nothing at the start of training.
        self.scale = nn.Embedding(max(1, max_loops), self.rank)
        nn.init.zeros_(self.scale.weight)

    def forward(self, x: torch.Tensor, loop_t: int) -> torch.Tensor:
        disabled = (
            self.rank <= 0 or self.down is None or self.B is None or self.scale is None
        )
        if disabled:
            return torch.zeros_like(x)
        clamped_t = min(loop_t, self.scale.num_embeddings - 1)
        per_loop_scale = self.scale(torch.tensor(clamped_t, device=x.device))
        return (self.down(x) * per_loop_scale) @ self.B
class StableRecurrentInjection(nn.Module):
    """Stable state update h' = A*h + B*e + transformer_out with A in (0, 1)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.log_A = nn.Parameter(torch.full((dim,), -2.0))
        self.log_dt = nn.Parameter(torch.full((dim,), -2.0))
        self.input_gate = nn.Parameter(torch.zeros(dim))

    def forward(self, h: torch.Tensor, e: torch.Tensor, transformer_out: torch.Tensor) -> torch.Tensor:
        # Decay factor exp(-exp(log_dt + log_A)); clamp keeps exp() finite.
        log_rate = (self.log_dt + self.log_A).clamp(-20, 20)
        decay = torch.exp(-torch.exp(log_rate)).view(1, 1, -1)
        gate = torch.sigmoid(self.input_gate).view(1, 1, -1)
        return decay * h + gate * e + transformer_out
class AdaptiveHalting(nn.Module):
    """Predicts a per-position halting probability in (0, 1)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.halt = nn.Linear(dim, 1, bias=True)
        # Zero weights + bias -2 give an initial halt prob of sigmoid(-2).
        nn.init.zeros_(self.halt.weight)
        nn.init.constant_(self.halt.bias, -2.0)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        logits = self.halt(h)
        return torch.sigmoid(logits).squeeze(-1)
class EngramBlock(nn.Module):
    """DeepSeek Engram: conditional memory via O(1) hashed N-gram lookup.

    Stores common token-pair/triplet patterns in an embedding table and
    retrieves them with multi-head hashing. A context-aware gate (using the
    current hidden state as query) decides how much of the retrieved memory
    to inject into the residual stream.

    Reference: DeepSeek-AI, "Conditional Memory via Scalable Lookup" (2025).
    """

    def __init__(
        self,
        dim: int,
        engram_dim: int,
        n_heads: int = 4,
        table_size: int = 8192,
        max_ngram: int = 3,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.engram_dim = engram_dim
        self.n_heads = n_heads
        self.table_size = table_size
        self.max_ngram = max_ngram
        # One embedding table per (ngram_order, hash_head)
        self.embeddings = nn.ParameterDict()
        for n in range(2, max_ngram + 1):
            for k in range(n_heads):
                self.embeddings[f"{n}_{k}"] = nn.Parameter(
                    torch.randn(table_size, engram_dim) * (engram_dim**-0.5)
                )
        # Fixed hash parameters (non-learnable, deterministic): each
        # (order, head) pair gets multiplicative/additive constants seeded
        # from an md5 of its name so runs are reproducible.
        for n in range(2, max_ngram + 1):
            for k in range(n_heads):
                seed = int(hashlib.md5(f"engram_{n}_{k}".encode()).hexdigest()[:8], 16)
                rng = torch.Generator().manual_seed(seed)
                a = torch.randint(1, 2**31, (1,), generator=rng).item()
                b = torch.randint(0, 2**31, (1,), generator=rng).item()
                self.register_buffer(
                    f"hash_a_{n}_{k}", torch.tensor(a), persistent=False
                )
                self.register_buffer(
                    f"hash_b_{n}_{k}", torch.tensor(b), persistent=False
                )
        # Causal convolution over N-gram branch outputs (kernel=4, dilation=max_ngram)
        total_branch_dim = engram_dim * n_heads * (max_ngram - 1)
        self.branch_conv = nn.Conv1d(
            total_branch_dim,
            total_branch_dim,
            kernel_size=4,
            dilation=max_ngram,
            padding=0,
            groups=total_branch_dim,  # depthwise: one filter per channel
            bias=True,
        )
        # Zero init: the conv starts as a no-op contribution.
        nn.init.zeros_(self.branch_conv.weight)
        nn.init.zeros_(self.branch_conv.bias)
        # Context-aware gating: hidden state as query, memory as key/value
        self.gate_query = nn.Linear(dim, engram_dim, bias=False)
        self.gate_key = nn.Linear(total_branch_dim, engram_dim, bias=False)
        self.gate_value = nn.Linear(total_branch_dim, dim, bias=False)
        self.gate_scale = engram_dim**-0.5

    def _hash_ngram(self, token_ids: torch.Tensor, n: int, k: int) -> torch.Tensor:
        """Hash n-gram token sequences into table indices.

        Args:
            token_ids: (B, T) token IDs
            n: n-gram order (2 = bigram, 3 = trigram)
            k: hash head index

        Returns:
            indices: (B, T) integer indices into embedding table
        """
        a = getattr(self, f"hash_a_{n}_{k}")
        b = getattr(self, f"hash_b_{n}_{k}")
        B, T = token_ids.shape
        # Pad left with zeros so every position has a valid n-gram
        padded = F.pad(token_ids, (n - 1, 0), value=0)  # (B, T+n-1)
        # Polynomial rolling hash
        combined = torch.zeros(B, T, dtype=torch.long, device=token_ids.device)
        for i in range(n):
            combined = combined * 31 + padded[:, i : i + T].long()
        indices = ((a * combined) ^ b) % self.table_size
        return indices

    def forward(
        self, hidden: torch.Tensor, token_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            hidden: (B, T, dim) current hidden state
            token_ids: (B, T) input token IDs for n-gram hashing.
                If None, uses argmax of hidden projections as proxy.

        Returns:
            output: (B, T, dim) memory injection for residual stream
        """
        B, T, _ = hidden.shape
        if token_ids is None:
            # Fallback: derive pseudo-token-ids from hidden state
            token_ids = hidden.mean(dim=-1).long() % self.table_size
        # Retrieve and concatenate across n-gram orders and hash heads
        branch_outputs = []
        for n in range(2, self.max_ngram + 1):
            for k in range(self.n_heads):
                indices = self._hash_ngram(token_ids, n, k)  # (B, T)
                table = self.embeddings[f"{n}_{k}"]  # (table_size, engram_dim)
                retrieved = table[indices]  # (B, T, engram_dim)
                branch_outputs.append(retrieved)
        # (B, T, engram_dim * n_heads * (max_ngram - 1))
        memory = torch.cat(branch_outputs, dim=-1)
        # Causal convolution over sequence dimension
        # Pad left for causality (kernel_size - 1 = 3)
        conv_in = memory.transpose(1, 2)  # (B, C, T)
        conv_in = F.pad(
            conv_in,
            ((self.branch_conv.kernel_size[0] - 1) * self.branch_conv.dilation[0], 0),
        )
        conv_out = self.branch_conv(conv_in)  # (B, C, T)
        memory = conv_out.transpose(1, 2)  # (B, T, C)
        # Context-aware gating
        query = self.gate_query(hidden)  # (B, T, engram_dim)
        key = self.gate_key(memory)  # (B, T, engram_dim)
        gate = torch.sigmoid(
            (query * key).sum(dim=-1, keepdim=True) * self.gate_scale
        )  # (B, T, 1)
        value = self.gate_value(memory)  # (B, T, dim)
        return gate * value
class SleepGate(nn.Module):
    """Persistent memory + periodic consolidation gate.

    Keeps a ring buffer of ``cap`` sequence-tail summaries in non-learnable
    buffers (written by :meth:`write`, read back with cross-attention in
    :meth:`read`). An optional retention gate assigns each stored memory a
    decay factor beta in (0, 1); at read time, attention to a memory is
    down-weighted by beta ** age.
    """

    def __init__(
        self,
        dim: int,
        cap: int = 128,
        n_heads: int = 4,
        retention_enabled: bool = True,
        retention_hidden: int = 0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.cap = cap
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5
        self.retention_enabled = retention_enabled
        # Ring-buffer state: embeddings, write-step ages, retention betas,
        # current fill count, next write slot, and a global step counter.
        self.register_buffer("mem_emb", torch.zeros(cap, dim, dtype=torch.bfloat16))
        self.register_buffer("mem_age", torch.zeros(cap, dtype=torch.long))
        self.register_buffer("mem_beta", torch.ones(cap, dtype=torch.float32))
        self.register_buffer("mem_count", torch.zeros((), dtype=torch.long))
        self.register_buffer("mem_head", torch.zeros((), dtype=torch.long))
        self.register_buffer("global_step", torch.zeros((), dtype=torch.long))
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)
        # Zero output projection: reads start as a no-op.
        nn.init.zeros_(self.o_proj.weight)
        self.gate_scale = nn.Parameter(torch.zeros(()))
        if retention_enabled:
            if retention_hidden > 0:
                # Small MLP retention gate; bias 2.2 => initial beta ~ 0.9.
                self.retention_gate: Optional[nn.Module] = nn.Sequential(
                    nn.Linear(dim, retention_hidden, bias=False),
                    nn.GELU(),
                    nn.Linear(retention_hidden, 1, bias=True),
                )
                nn.init.constant_(self.retention_gate[-1].bias, 2.2)
            else:
                self.retention_gate = nn.Linear(dim, 1, bias=True)
                nn.init.constant_(self.retention_gate.bias, 2.2)
        else:
            self.retention_gate = None
        # Live (differentiable) betas from the last write, kept for training.
        self._last_beta: Optional[torch.Tensor] = None

    def write(self, hidden: torch.Tensor) -> None:
        """Summarize the last <=16 positions of each sequence and store them."""
        B, T, _ = hidden.shape
        tail_full = hidden[:, max(0, T - 16):, :].float().mean(dim=1)
        if self.retention_gate is not None:
            beta_live = torch.sigmoid(self.retention_gate(tail_full).squeeze(-1))
            self._last_beta = beta_live if self.training else None
            beta_store = beta_live.detach().float()
        else:
            self._last_beta = None
            beta_store = torch.ones(B, device=hidden.device, dtype=torch.float32)
        tail = tail_full.to(self.mem_emb.dtype).detach()
        with torch.no_grad():
            head = int(self.mem_head.item())
            count = int(self.mem_count.item())
            step = int(self.global_step.item())
            # Write one slot per batch element, advancing the ring head.
            for b in range(B):
                self.mem_emb[head] = tail[b]
                self.mem_age[head] = step
                self.mem_beta[head] = beta_store[b]
                head = (head + 1) % self.cap
                if count < self.cap:
                    count += 1
            self.mem_head.fill_(head)
            self.mem_count.fill_(count)

    def read(self, x: torch.Tensor) -> torch.Tensor:
        """Cross-attend x (B, T, dim) over stored memories; zeros when empty."""
        count = int(self.mem_count.item())
        if count == 0:
            return torch.zeros_like(x)
        B, T, D = x.shape
        mem = self.mem_emb[:count].clone().to(x.dtype)
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1)
        v = self.v_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1)
        attn = torch.einsum("bhtd,hmd->bhtm", q, k) * self.scale
        attn = F.softmax(attn, dim=-1)
        if self.retention_enabled:
            # Re-weight by beta ** age so older memories fade, then renormalize.
            step = int(self.global_step.item())
            ages = self.mem_age[:count].to(x.device)
            delta = (step - ages).clamp(min=0).to(x.dtype)
            betas = self.mem_beta[:count].to(x.dtype).clamp(min=1e-6, max=1.0)
            weights = betas.pow(delta)
            attn = attn * weights.view(1, 1, 1, count)
            attn = attn / attn.sum(dim=-1, keepdim=True).clamp_min(1e-9)
        out = torch.einsum("bhtm,hmd->bhtd", attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        out = self.o_proj(out)
        # Learned scalar gate on the whole read (starts at sigmoid(0) = 0.5).
        return torch.sigmoid(self.gate_scale) * out

    @torch.no_grad()
    def reset(self) -> None:
        """Clear all stored memories and counters."""
        self.mem_emb.zero_()
        self.mem_age.zero_()
        self.mem_beta.fill_(1.0)
        self.mem_count.zero_()
        self.mem_head.zero_()
        self.global_step.zero_()
        self._last_beta = None
def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor:
M = torch.exp(logits.clamp(-10, 10))
for _ in range(n_iters):
M = M / M.sum(dim=-1, keepdim=True).clamp(min=1e-10)
M = M / M.sum(dim=-2, keepdim=True).clamp(min=1e-10)
return M
class ManifoldHyperConnection(nn.Module):
    """Hyper-connection residual wiring with ``expansion`` parallel streams.

    The hidden state is duplicated into n = expansion streams; learned,
    input-dependent mappings mix streams before a layer (pre), re-inject the
    layer output (post), and mix the residual streams with a Sinkhorn-
    normalized (approximately doubly-stochastic) matrix (res).
    """

    def __init__(self, dim: int, expansion: int = 2) -> None:
        super().__init__()
        self.dim = dim
        self.expansion = expansion
        n = expansion
        self.expand_fn = "duplicate"
        self.collapse_fn = "mean"
        self.bias_pre = nn.Parameter(torch.zeros(1, n))
        self.bias_post = nn.Parameter(torch.zeros(1, n))
        self.bias_res = nn.Parameter(torch.zeros(n, n))
        # Input-dependent deltas; alpha_* start at 0 so the biases dominate.
        self.theta_pre = nn.Linear(n * dim, n, bias=False)
        self.theta_post = nn.Linear(n * dim, n, bias=False)
        self.theta_res = nn.Linear(n * dim, n * n, bias=False)
        self.alpha_pre = nn.Parameter(torch.tensor(0.0))
        self.alpha_post = nn.Parameter(torch.tensor(0.0))
        self.alpha_res = nn.Parameter(torch.tensor(0.0))

    def _compute_mappings(
        self, x_expanded: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (H_pre, H_post, H_res) mixing tensors for (B, T, n*dim) input."""
        B, T, _ = x_expanded.shape
        n = self.expansion
        x_norm = F.rms_norm(x_expanded, [x_expanded.shape[-1]])
        d_pre = torch.tanh(self.theta_pre(x_norm))
        d_post = torch.tanh(self.theta_post(x_norm))
        d_res = self.theta_res(x_norm)
        # Pre weights in (0, 1); post weights in (0, 2).
        H_pre_raw = torch.sigmoid(self.alpha_pre * d_pre + self.bias_pre)
        H_post_raw = 2.0 * torch.sigmoid(self.alpha_post * d_post + self.bias_post)
        H_res_raw = (self.alpha_res * d_res + self.bias_res.reshape(1, 1, -1)).reshape(
            B, T, n, n
        )
        H_res = _sinkhorn_knopp(H_res_raw)
        return H_pre_raw.unsqueeze(-2), H_post_raw.unsqueeze(-2), H_res

    def expand_stream(self, x: torch.Tensor) -> torch.Tensor:
        # Duplicate x into n concatenated streams along the channel axis.
        return x.repeat(1, 1, self.expansion)

    def collapse_stream(self, x_expanded: torch.Tensor) -> torch.Tensor:
        # Average the n streams back to a single (B, T, dim) tensor.
        B, T, _ = x_expanded.shape
        n = self.expansion
        C = self.dim
        return x_expanded.view(B, T, n, C).mean(dim=-2)

    def pre_mix(self, x_expanded: torch.Tensor, H_pre: torch.Tensor) -> torch.Tensor:
        # Weighted combination of streams used as the layer's input.
        B, T, _ = x_expanded.shape
        n = self.expansion
        x_streams = x_expanded.view(B, T, n, self.dim)
        return (H_pre @ x_streams).squeeze(-2)

    def post_res_mix(
        self,
        layer_output: torch.Tensor,
        x_expanded: torch.Tensor,
        H_post: torch.Tensor,
        H_res: torch.Tensor,
    ) -> torch.Tensor:
        """Mix residual streams with H_res and add the H_post-weighted layer output."""
        B, T, _ = x_expanded.shape
        n = self.expansion
        C = self.dim
        x_streams = x_expanded.view(B, T, n, C)
        mixed = torch.matmul(H_res, x_streams)
        post_out = torch.matmul(H_post.transpose(-2, -1), layer_output.unsqueeze(-2))
        result = mixed + post_out
        return result.reshape(B, T, n * C)
class TransformerBlock(nn.Module):
    """Pre-norm attention + SwiGLU block with optional Engram memory and
    manifold hyper-connection (mhc) residual streams."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        ffn_dim: int,
        dropout: float,
        sliding_window: int,
        rope_fraction: float,
        engram_dim: int = 0,       # > 0 enables the Engram memory branch
        engram_heads: int = 4,
        engram_table_size: int = 8192,
        engram_max_ngram: int = 3,
        mhc_expansion: int = 1,    # > 1 enables hyper-connection streams
    ) -> None:
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = CausalSelfAttention(
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            head_dim=head_dim,
            dropout=dropout,
            sliding_window=sliding_window,
            rope_fraction=rope_fraction,
        )
        self.norm2 = RMSNorm(dim)
        self.ffn = SwiGLU(dim, ffn_dim, dropout)
        self.use_engram = engram_dim > 0
        if self.use_engram:
            self.engram = EngramBlock(
                dim=dim,
                engram_dim=engram_dim,
                n_heads=engram_heads,
                table_size=engram_table_size,
                max_ngram=engram_max_ngram,
            )
            self.engram_norm = RMSNorm(dim)
        self.use_mhc = mhc_expansion > 1
        if self.use_mhc:
            # Separate hyper-connections around attention and FFN.
            self.mhc_attn = ManifoldHyperConnection(dim, expansion=mhc_expansion)
            self.mhc_ffn = ManifoldHyperConnection(dim, expansion=mhc_expansion)

    def forward(
        self,
        x: torch.Tensor,
        is_global: bool,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        token_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the block; returns (hidden, new_kv_cache_or_None)."""
        if self.use_mhc:
            # Hyper-connection path: expand to n streams, mix, run sublayer,
            # then recombine with the residual mixing matrix.
            x_exp = self.mhc_attn.expand_stream(x)
            H_pre, H_post, H_res = self.mhc_attn._compute_mappings(x_exp)
            attn_in = self.mhc_attn.pre_mix(x_exp, H_pre)
            attn_out, new_kv = self.attn(
                self.norm1(attn_in), is_global, past_kv, use_cache
            )
            x_exp = self.mhc_attn.post_res_mix(attn_out, x_exp, H_post, H_res)
            if self.use_engram:
                # Engram operates on the collapsed (single-stream) state.
                collapsed = self.mhc_attn.collapse_stream(x_exp)
                collapsed = collapsed + self.engram(
                    self.engram_norm(collapsed), token_ids=token_ids
                )
                x_exp = self.mhc_attn.expand_stream(collapsed)
            H_pre2, H_post2, H_res2 = self.mhc_ffn._compute_mappings(x_exp)
            ffn_in = self.mhc_ffn.pre_mix(x_exp, H_pre2)
            ffn_out = self.ffn(self.norm2(ffn_in))
            x_exp = self.mhc_ffn.post_res_mix(ffn_out, x_exp, H_post2, H_res2)
            x = self.mhc_attn.collapse_stream(x_exp)
        else:
            # Standard pre-norm residual path.
            attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache)
            x = x + attn_out
            if self.use_engram:
                x = x + self.engram(self.engram_norm(x), token_ids=token_ids)
            x = x + self.ffn(self.norm2(x))
        return x, new_kv
class RecurrentDepthBlock(nn.Module):
    """Weight-tied transformer block applied up to ``n_loops`` times with
    adaptive computation time (ACT)-style halting.

    Each loop adds a loop-index embedding, runs one TransformerBlock on
    norm(h + e), applies a per-loop LoRA adapter, and updates the state via a
    stable injection. Per-position halting probabilities accumulate until
    they cross ``act_threshold``; the final output is the probability-
    weighted mixture of intermediate states plus the remainder mass on the
    last state.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        ffn_dim: int,
        dropout: float,
        sliding_window: int,
        rope_fraction: float,
        n_loops: int,
        act_threshold: float,
        lora_rank: int,
        loop_embed_dim: int,
    ) -> None:
        super().__init__()
        self.n_loops = max(1, n_loops)
        self.act_threshold = act_threshold
        self.loop_embed_dim = max(0, loop_embed_dim)
        self.norm = RMSNorm(dim)
        self.block = TransformerBlock(
            dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim,
            ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window,
            rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1,
        )
        self.injection = StableRecurrentInjection(dim)
        self.act = AdaptiveHalting(dim)
        self.lora = DepthLoRAAdapter(dim, lora_rank, self.n_loops)

    def forward(
        self,
        h: torch.Tensor,
        e: torch.Tensor,
        token_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        use_cache: bool = False,
        n_loops: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """Run up to ``loops`` recurrent iterations over state h with input e.

        Returns (output, aux_metrics, new_kv_caches_or_None); one KV cache is
        collected per loop iteration when use_cache is True.
        """
        loops = max(1, n_loops or self.n_loops)
        B, T, _ = h.shape
        halted = torch.zeros(B, T, device=h.device, dtype=torch.bool)
        cumulative_p = torch.zeros(B, T, device=h.device, dtype=h.dtype)
        output = torch.zeros_like(h)
        new_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
        current = h
        final_halt = None
        for t in range(loops):
            h_loop = loop_index_embedding(current, t, self.loop_embed_dim)
            combined = self.norm(h_loop + e)
            # Each loop index has its own KV cache entry.
            past_kv = None
            if past_key_values is not None and t < len(past_key_values):
                past_kv = past_key_values[t]
            trans_out, layer_kv = self.block(combined, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=token_ids)
            trans_out = trans_out + self.lora(trans_out, t)
            next_h = self.injection(current, e, trans_out)
            # Halting prob is zeroed for positions that already halted.
            p = self.act(next_h)
            p = p * (~halted).to(p.dtype)
            final_halt = p
            should_halt = (~halted) & ((cumulative_p + p) >= self.act_threshold)
            # Positions crossing the threshold contribute their remainder mass.
            update_weight = torch.where(should_halt, (1.0 - cumulative_p).clamp(min=0.0), p)
            output = output + next_h * update_weight.unsqueeze(-1)
            cumulative_p = cumulative_p + update_weight
            # Halted positions keep their frozen state.
            current = torch.where(halted.unsqueeze(-1), current, next_h)
            halted = halted | should_halt
            if new_past is not None:
                new_past.append(layer_kv)
            # Early exit only when not caching (cache length must be fixed).
            if not use_cache and bool(halted.all()):
                break
        # Assign leftover probability mass to the final state.
        remainder = (1.0 - cumulative_p).clamp(min=0.0)
        output = output + current * remainder.unsqueeze(-1)
        aux: Dict[str, torch.Tensor] = {}
        if final_halt is not None:
            aux["recurrent_halt_mean"] = final_halt.mean()
        return output, aux, new_past
class TinyMemoryLM(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        dim: int,
        n_unique_layers: int,
        n_logical_layers: int,
        n_heads: int,
        n_kv_heads: int,
        ffn_dim: int,
        dropout: float,
        mtp_horizons: Sequence[int],
        grad_checkpoint: bool,
        sliding_window: int = 512,
        rope_fraction: float = 0.5,
        embed_scale: bool = True,
        engram_dim: int = 0,
        engram_heads: int = 4,
        engram_table_size: int = 8192,
        engram_max_ngram: int = 3,
        mhc_expansion: int = 1,
        sleep_gate_cap: int = 0,
        sleep_gate_heads: int = 4,
        sleep_retention_enabled: bool = True,
        sleep_retention_hidden: int = 0,
        latent_think_layers: int = 0,
        prelude_layers: int = 0,
        coda_layers: int = 0,
        recurrent_loops: int = 0,
        recurrent_act_threshold: float = 0.99,
        recurrent_lora_rank: int = 0,
        recurrent_loop_embed_dim: int = 0,
    ) -> None:
        """Build the language model.

        Two topologies are supported: when ``recurrent_loops > 0`` the stack
        is prelude blocks -> one RecurrentDepthBlock -> coda blocks;
        otherwise a flat stack of ``n_unique_layers`` TransformerBlocks that
        is re-used to reach ``n_logical_layers`` logical depth.
        """
        super().__init__()
        self.dim = dim
        self.n_unique_layers = n_unique_layers
        self.n_logical_layers = n_logical_layers
        self.grad_checkpoint = grad_checkpoint
        self.embed_scale_factor = math.sqrt(dim) if embed_scale else 1.0
        head_dim = dim // n_heads
        self.embed_tokens = nn.Embedding(vocab_size, dim)
        # Output head is weight-tied to the input embedding.
        self.head = nn.Linear(dim, vocab_size, bias=False)
        self.head.weight = self.embed_tokens.weight
        self.output_bias = nn.Parameter(torch.zeros(vocab_size))
        self.use_recurrent_depth = recurrent_loops > 0
        self.prelude_layers = max(0, prelude_layers)
        self.coda_layers = max(0, coda_layers)
        self.recurrent_loops = max(0, recurrent_loops)
        self.blocks: Optional[nn.ModuleList] = None
        self.prelude: Optional[nn.ModuleList] = None
        self.recurrent: Optional[RecurrentDepthBlock] = None
        self.coda: Optional[nn.ModuleList] = None

        def _make_blocks(n: int) -> nn.ModuleList:
            # Helper: n identically-configured TransformerBlocks.
            return nn.ModuleList([
                TransformerBlock(
                    dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim,
                    ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window,
                    rope_fraction=rope_fraction, engram_dim=engram_dim,
                    engram_heads=engram_heads, engram_table_size=engram_table_size,
                    engram_max_ngram=engram_max_ngram, mhc_expansion=mhc_expansion,
                )
                for _ in range(n)
            ])

        if self.use_recurrent_depth:
            if self.prelude_layers > 0:
                self.prelude = _make_blocks(self.prelude_layers)
            self.recurrent = RecurrentDepthBlock(
                dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim,
                ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window,
                rope_fraction=rope_fraction, n_loops=self.recurrent_loops,
                act_threshold=recurrent_act_threshold, lora_rank=recurrent_lora_rank,
                loop_embed_dim=recurrent_loop_embed_dim or max(2, dim // 8),
            )
            if self.coda_layers > 0:
                self.coda = _make_blocks(self.coda_layers)
        else:
            self.blocks = _make_blocks(max(1, n_unique_layers))
        self.norm = RMSNorm(dim)
        # Multi-token-prediction heads: one adapter + norm per horizon > 1.
        self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1})
        self.mtp_adapters = nn.ModuleDict(
            {str(h): nn.Linear(dim, dim, bias=False) for h in self.mtp_horizons}
        )
        self.mtp_norms = nn.ModuleDict(
            {str(h): RMSNorm(dim) for h in self.mtp_horizons}
        )
        # Depth-scaled residual init on all output projections.
        res_scale = (2 * max(1, n_logical_layers)) ** -0.5
        for group in (self.blocks, self.prelude, self.coda):
            if group is None:
                continue
            for block in group:
                block.attn.wo.weight.data.mul_(res_scale)
                block.ffn.down.weight.data.mul_(res_scale)
        if self.recurrent is not None:
            self.recurrent.block.attn.wo.weight.data.mul_(res_scale)
            self.recurrent.block.ffn.down.weight.data.mul_(res_scale)
        self.sleep_gate: Optional[SleepGate] = None
        if sleep_gate_cap > 0:
            self.sleep_gate = SleepGate(
                dim=dim, cap=sleep_gate_cap, n_heads=sleep_gate_heads,
                retention_enabled=sleep_retention_enabled,
                retention_hidden=sleep_retention_hidden,
            )
        # Optional extra "latent thinking" blocks (full-window, no dropout).
        self.think_blocks: Optional[nn.ModuleList] = None
        self.think_norm: Optional[RMSNorm] = None
        if latent_think_layers > 0:
            self.think_blocks = nn.ModuleList([
                TransformerBlock(
                    dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim,
                    ffn_dim=ffn_dim, dropout=0.0, sliding_window=2048,
                    rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1,
                )
                for _ in range(latent_think_layers)
            ])
            self.think_norm = RMSNorm(dim)
def resize_token_embeddings(self, new_vocab_size: int) -> None:
old_vocab_size = self.embed_tokens.num_embeddings
if new_vocab_size == old_vocab_size:
return
device = self.embed_tokens.weight.device
old_embed_weight = self.embed_tokens.weight.data.clone()
self.embed_tokens = nn.Embedding(new_vocab_size, self.embed_tokens.embedding_dim).to(device)
self.head = nn.Linear(self.embed_tokens.embedding_dim, new_vocab_size, bias=False).to(device)
self.head.weight = self.embed_tokens.weight
old_bias = self.output_bias.data.clone()
self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device))
copy_size = min(old_vocab_size, new_vocab_size)
self.output_bias.data[:copy_size] = old_bias[:copy_size]
self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size]
def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]:
if self.blocks is None:
return []
blocks_list = list(self.blocks)
full_sequence = blocks_list + blocks_list
return [(block, i) for i, block in enumerate(full_sequence[: self.n_logical_layers])]
    def forward(
        self,
        ids: torch.Tensor,
        use_cache: bool = False,
        past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        return_hidden: bool = False,
    ) -> Tuple[torch.Tensor, Dict[int, torch.Tensor], Dict[str, torch.Tensor], Optional[torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """Run the model over token ids of shape (batch, seq).

        Returns a 5-tuple ``(logits, mtp_by_horizon, aux, hidden, new_kv)``:
        MTP logits are produced only in training mode, ``hidden`` only when
        *return_hidden* is true, and the KV list only when *use_cache* is true.
        """
        B, T = ids.shape
        x = self.embed_tokens(ids) * self.embed_scale_factor
        new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
        aux: Dict[str, torch.Tensor] = {}
        if self.use_recurrent_depth:
            # Prelude -> recurrent core -> coda. `offset` walks the flat
            # past_key_values list across the three stages.
            offset = 0
            if self.prelude is not None:
                for block in self.prelude:
                    past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None
                    x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids)
                    if new_past_key_values is not None:
                        new_past_key_values.append(layer_kv)
                    offset += 1
            # The recurrent core receives the prelude output twice: as the
            # running state and as the fixed "encoded" conditioning input.
            encoded = x
            recurrent_past = past_key_values[offset: offset + self.recurrent_loops] if past_key_values is not None else None
            x, recurrent_aux, recurrent_kv = self.recurrent(
                x, encoded, token_ids=ids, past_key_values=recurrent_past, use_cache=use_cache,
            )
            aux.update(recurrent_aux)
            if new_past_key_values is not None and recurrent_kv is not None:
                new_past_key_values.extend(recurrent_kv)
            offset += self.recurrent_loops
            if self.coda is not None:
                for block in self.coda:
                    past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None
                    x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids)
                    if new_past_key_values is not None:
                        new_past_key_values.append(layer_kv)
                    offset += 1
        else:
            # Weight-shared stack: physical blocks reused as logical layers.
            logical_layers = self._build_logical_layers()
            last_logical_idx = len(logical_layers) - 1
            for layer_idx, (block, logical_idx) in enumerate(logical_layers):
                # Even logical layers attend globally; the final layer is
                # always global regardless of parity.
                is_global = logical_idx % 2 == 0 or layer_idx == last_logical_idx
                past_kv = past_key_values[layer_idx] if past_key_values is not None and layer_idx < len(past_key_values) else None
                if self.grad_checkpoint and self.training and not use_cache:
                    x, layer_kv = checkpoint(block, x, is_global, past_kv, use_cache, ids, use_reentrant=True)
                else:
                    x, layer_kv = block(x, is_global, past_kv, use_cache, ids)
                if new_past_key_values is not None:
                    new_past_key_values.append(layer_kv)
        x = self.norm(x)
        if self.sleep_gate is not None:
            x = x + self.sleep_gate.read(x)
            if self.training:
                # Memory writes only happen during training.
                self.sleep_gate.write(x)
        if self.think_blocks is not None:
            for think_block in self.think_blocks:
                x, _ = think_block(x, is_global=True)
            x = self.think_norm(x)
        h_out = x if return_hidden else None
        logits = self.head(x)
        # Undo the embedding scale on the tied output head.
        if self.embed_scale_factor != 1.0:
            logits = logits / self.embed_scale_factor
        logits = logits + self.output_bias
        mtp: Dict[int, torch.Tensor] = {}
        # Multi-token prediction (training only): predict token t+h from the
        # hidden state at t through a per-horizon linear adapter + norm.
        if self.mtp_horizons and self.training:
            for horizon in self.mtp_horizons:
                if horizon > 1 and horizon <= T - 1:
                    shifted_h = x[:, :-horizon, :]
                    adapted_h = self.mtp_adapters[str(horizon)](shifted_h)
                    adapted_h = self.mtp_norms[str(horizon)](adapted_h)
                    mtp_logits = self.head(adapted_h)
                    if self.embed_scale_factor != 1.0:
                        mtp_logits = mtp_logits / self.embed_scale_factor
                    mtp_logits = mtp_logits + self.output_bias
                    mtp[horizon] = mtp_logits
        return logits, mtp, aux, h_out, new_past_key_values
# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
def build_stop_token_ids(tokenizer: WordTokenizer) -> set:
    """Collect token ids that should terminate generation: EOS plus any
    chat-role markers present in the tokenizer's vocabulary."""
    ids = {tokenizer.eos_id}
    role_markers = ("<|user|>", "<|system|>", "<|assistant|>")
    for marker in role_markers:
        mapped = tokenizer.token_to_id.get(marker)
        if mapped is not None:
            ids.add(int(mapped))
    return ids
def apply_no_repeat_ngram(
    logits: torch.Tensor,
    token_history: Sequence[int],
    ngram_size: int,
) -> torch.Tensor:
    """Ban tokens that would complete an n-gram already in *token_history*.

    If the trailing ``ngram_size - 1`` tokens match the prefix of any
    historical n-gram, the token that completed that n-gram has its logit
    set to -inf. Returns *logits* unchanged when blocking is disabled or
    nothing is banned; otherwise returns a modified clone (the input tensor
    is never mutated in place).
    """
    # Disabled, or history too short to even contain a prefix.
    # (Cleanup: the old `max(0, ngram_size - 1)` and `else tuple()` branches
    # were unreachable once ngram_size <= 1 returns early.)
    if ngram_size <= 1 or len(token_history) < ngram_size - 1:
        return logits
    prefix = tuple(token_history[-(ngram_size - 1):])
    banned: set = set()
    for i in range(len(token_history) - ngram_size + 1):
        if tuple(token_history[i : i + ngram_size - 1]) == prefix:
            banned.add(int(token_history[i + ngram_size - 1]))
    if not banned:
        return logits
    out = logits.clone()
    banned_ids = torch.tensor(sorted(banned), device=logits.device, dtype=torch.long)
    out[banned_ids] = float("-inf")
    return out
def apply_loop_penalty(
    logits: torch.Tensor,
    tokenizer: WordTokenizer,
    generated_text: str,
    penalty: float = 5.0,
) -> torch.Tensor:
    """Detect repeated substring loops and penalise continuation tokens."""
    if len(generated_text) < 16:
        return logits
    penalised = logits.clone()
    # Longest spans first: a long repeated tail is stronger loop evidence.
    for span in (24, 16, 12, 8):
        if len(generated_text) < 2 * span:
            continue
        tail = generated_text[-span:]
        earlier = generated_text[:-span].rfind(tail)
        if earlier == -1:
            continue
        # The character following the earlier occurrence is what the loop
        # would emit next — push its logit down.
        follow_idx = earlier + span
        if follow_idx < len(generated_text):
            follow_id = tokenizer.token_to_id.get(generated_text[follow_idx])
            if follow_id is not None:
                penalised[follow_id] -= penalty
        break
    return penalised
def apply_min_p(logits: torch.Tensor, min_p: float) -> torch.Tensor:
    """Filter tokens below min_p fraction of the top token probability."""
    if min_p <= 0.0:
        return logits
    distribution = torch.softmax(logits, dim=-1)
    cutoff = distribution.max() * min_p
    # Mask everything under the cutoff without mutating the input tensor.
    return torch.where(
        distribution < cutoff,
        torch.full_like(logits, float("-inf")),
        logits,
    )
def generate(
    model: TinyMemoryLM,
    tokenizer: WordTokenizer,
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.8,
    top_k: int = 16,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
    device: str = "cuda",
    sft_mode: bool = True,
    stream: bool = True,
    no_repeat_ngram_size: int = 0,
    context_window: int = 2048,
    logit_soft_cap: float = 15.0,
    min_p: float = 0.05,
    loop_penalty: float = 5.0,
) -> str:
    """Sample up to *max_new_tokens* tokens from *model* for *prompt*.

    When *sft_mode* is true the prompt is wrapped in chat-role markers.
    Per-step filtering order: logit soft-cap -> repetition penalty ->
    no-repeat-ngram -> loop penalty -> temperature -> min-p -> top-k ->
    top-p, with a fallback to the unfiltered logits when everything gets
    masked. Generation stops at EOS or any role-marker token. Returns the
    concatenated visible (non-special) token strings; also streams them to
    stdout when *stream* is true.
    """
    if sft_mode:
        full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
    else:
        full_prompt = prompt
    input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False)
    input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device)
    visible_tokens: List[str] = []
    stop_token_ids = build_stop_token_ids(tokenizer)
    generated_text = ""
    generated_ids: List[int] = []
    # Full history (prompt + generated) for ngram blocking — prevents echoing prompt
    # NOTE(review): full_ids_history is maintained but never read below — the
    # n-gram blocking call uses generated_ids only. Confirm before removing.
    full_ids_history: List[int] = list(input_ids)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Feed only the last `context_window` tokens to the model.
            ctx_ids = (
                input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t
            )
            logits, *_ = model(ctx_ids)
            next_logits = logits[0, -1, :].clone()
            # Logit soft-capping (Gemma-style) — prevents overconfident collapse
            if logit_soft_cap > 0:
                next_logits = logit_soft_cap * torch.tanh(next_logits / logit_soft_cap)
            # Kept unfiltered so we can recover if filtering masks everything.
            raw_next_logits = next_logits.clone()
            # Repetition penalty on previously generated tokens
            if repetition_penalty != 1.0 and generated_ids:
                for tok_id in set(generated_ids):
                    if next_logits[tok_id] > 0:
                        next_logits[tok_id] /= repetition_penalty
                    else:
                        next_logits[tok_id] *= repetition_penalty
            # No-repeat n-gram blocking on generated tokens only
            if no_repeat_ngram_size > 0 and generated_ids:
                next_logits = apply_no_repeat_ngram(next_logits, generated_ids, no_repeat_ngram_size)
            # Substring loop detection
            next_logits = apply_loop_penalty(next_logits, tokenizer, generated_text, penalty=loop_penalty)
            # Temperature scaling
            if temperature != 1.0:
                next_logits = next_logits / max(temperature, 1e-6)
            # Min-p filtering — remove tokens below min_p * max_prob
            if min_p > 0:
                next_logits = apply_min_p(next_logits, min_p)
            # Top-k filtering
            if top_k > 0:
                v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0)))
                next_logits[next_logits < v[-1]] = float("-inf")
            # Top-p (nucleus) filtering
            if 0.0 < top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                sorted_probs = torch.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                remove_mask = cumulative_probs > top_p
                # Always keep at least the most likely token.
                remove_mask[0] = False
                indices_to_remove = sorted_indices[remove_mask]
                next_logits[indices_to_remove] = float("-inf")
            # Fallback if all tokens masked
            if not torch.isfinite(next_logits).any():
                next_logits = raw_next_logits
                if temperature != 1.0:
                    next_logits = next_logits / max(temperature, 1e-6)
            # temperature == 0 means greedy decoding (the division above used
            # a 1e-6 floor, but argmax is scale-invariant anyway).
            if temperature == 0:
                next_id = torch.argmax(next_logits).item()
            else:
                probs = torch.softmax(next_logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1).item()
            # Stop tokens end generation without being emitted.
            if next_id in stop_token_ids:
                break
            token_str = (
                tokenizer.id_to_token[next_id]
                if next_id < len(tokenizer.id_to_token)
                else ""
            )
            generated_ids.append(next_id)
            full_ids_history.append(next_id)
            # Special tokens are fed back to the model but hidden from output.
            if token_str not in tokenizer.special:
                visible_tokens.append(token_str)
                generated_text += token_str
                if stream:
                    print(token_str, end="", flush=True)
            input_ids_t = torch.cat(
                [input_ids_t, torch.tensor([[next_id]], device=device)], dim=1
            )
    if stream:
        print()
    return "".join(visible_tokens)
# ---------------------------------------------------------------------------
# Local model loading
# ---------------------------------------------------------------------------
def series_from_name(name: str) -> str | None:
lower = (name or "").lower()
if "haiku" in lower:
return "Haiku"
if "sonnet" in lower:
return "Sonnet"
if "opus" in lower:
return "Opus"
return None
def series_config(series: str) -> dict[str, object]:
    """Look up the hyper-parameter preset for *series* (case-insensitive),
    defaulting to the "sonnet" preset for unknown names."""
    key = series.lower()
    if key in MODEL_SERIES:
        return MODEL_SERIES[key]
    return MODEL_SERIES["sonnet"]
def discover_models(runs_dir: Path) -> List[dict]:
    """Scan *runs_dir* for run directories that contain a tokenizer plus at
    least one checkpoint, returning one entry per (run, checkpoint) pair."""
    found: List[dict] = []
    if not runs_dir.is_dir():
        return found
    for run_dir in sorted(runs_dir.iterdir()):
        if not run_dir.is_dir():
            continue
        tok_path = run_dir / "tokenizer.json"
        if not tok_path.exists():
            continue
        run_name = run_dir.name
        # Prefer reading the series from checkpoint weights; fall back to
        # the directory name, then to "Sonnet".
        series = None
        for candidate in ("model.pt", "pretrain.pt"):
            candidate_path = run_dir / candidate
            if candidate_path.exists():
                series = _fast_series_from_checkpoint(candidate_path)
                break
        if series is None:
            series = series_from_name(run_name) or "Sonnet"
        before = len(found)
        for candidate in ("model.pt", "model_rep.pt", "pretrain.pt"):
            candidate_path = run_dir / candidate
            if candidate_path.exists():
                found.append(
                    {
                        "name": run_name,
                        "checkpoint": candidate,
                        "series": series,
                        "model_path": candidate_path,
                        "tokenizer_path": tok_path,
                    }
                )
        if len(found) == before:
            # No canonical checkpoint — fall back to the newest step checkpoint.
            step_ckpts = sorted(
                run_dir.glob("checkpoint_step_*.pt"),
                key=lambda p: int(p.stem.rsplit("_", 1)[-1]),
            )
            if step_ckpts:
                latest = step_ckpts[-1]
                found.append(
                    {
                        "name": run_name,
                        "checkpoint": latest.name,
                        "series": series,
                        "model_path": latest,
                        "tokenizer_path": tok_path,
                    }
                )
    return found
def _detect_engram(state_dict):
    """Return the engram embedding width found in *state_dict*, or 0 when
    no engram embedding tables are present."""
    for key, tensor in state_dict.items():
        if ".engram." in key and ".embeddings." in key:
            return tensor.shape[-1]
    return 0
def _detect_mhc(state_dict):
    """Infer the mHC expansion factor from a `bias_pre` tensor of shape
    (1, expansion); defaults to 1 when none is found."""
    expansion = next(
        (
            tensor.shape[-1]
            for name, tensor in state_dict.items()
            if ".mhc_attn.bias_pre" in name and tensor.dim() == 2
        ),
        None,
    )
    return 1 if expansion is None else expansion
def _detect_sleep_gate(state_dict) -> Tuple[int, int]:
    """Return (capacity, n_heads) of a SleepGate if present, else (0, 4).

    Head count is not recoverable from the weights, so 4 is always reported.
    """
    mem = state_dict.get("sleep_gate.mem_emb")
    if mem is not None and mem.dim() == 2:
        return mem.shape[0], 4
    return 0, 4
def _detect_latent_think(state_dict) -> int:
    """Count latent-think layers: highest `think_blocks.<i>.*` index + 1."""
    top = -1
    for key in state_dict:
        if not key.startswith("think_blocks."):
            continue
        layer_id = key.split(".")[1]
        if layer_id.isdigit():
            top = max(top, int(layer_id))
    return top + 1
def _detect_prelude_layers(state_dict) -> int:
    """Count prelude layers: highest `prelude.<i>.*` index + 1."""
    top = -1
    for key in state_dict:
        if not key.startswith("prelude."):
            continue
        layer_id = key.split(".")[1]
        if layer_id.isdigit():
            top = max(top, int(layer_id))
    return top + 1
def _detect_coda_layers(state_dict) -> int:
    """Count coda layers: highest `coda.<i>.*` index + 1."""
    top = -1
    for key in state_dict:
        if not key.startswith("coda."):
            continue
        layer_id = key.split(".")[1]
        if layer_id.isdigit():
            top = max(top, int(layer_id))
    return top + 1
def _detect_recurrent_loops(state_dict) -> int:
    """Number of recurrent-depth loops: 0 when there is no recurrent block,
    the loop-scale embedding rows when present, else 1."""
    has_recurrent = (
        "recurrent.norm.weight" in state_dict
        or "recurrent.block.attn.wq.weight" in state_dict
    )
    if not has_recurrent:
        return 0
    scale = state_dict.get("recurrent.lora.scale.weight")
    return scale.shape[0] if scale is not None else 1
def _detect_recurrent_lora_rank(state_dict) -> int:
    """LoRA rank of the recurrent block, read from the first 2-D LoRA
    matrix found; 0 when neither candidate key has the expected shape."""
    for candidate in ("recurrent.lora.B", "recurrent.lora.down.weight"):
        tensor = state_dict.get(candidate)
        if tensor is not None and len(tensor.shape) == 2:
            return int(tensor.shape[0])
    return 0
def _infer_series_from_lora_rank(rank: int) -> str | None:
if rank == 0:
return None
if rank <= 8:
return "haiku"
if rank <= 16:
return "sonnet"
return "opus"
def _fast_series_from_checkpoint(ckpt_path: Path) -> str | None:
try:
cp = torch.load(ckpt_path, map_location="cpu", weights_only=False)
sd = cp.get("model_state", cp.get("state_dict", {}))
rank = 0
for key in ("recurrent.lora.B", "recurrent.lora.down.weight"):
if key in sd:
rank = int(sd[key].shape[0])
break
if rank == 0:
return None
if rank <= 8:
return "Haiku"
if rank <= 16:
return "Sonnet"
return "Opus"
except Exception:
pass
return None
def _infer_arch_from_state_dict(state_dict, cfg):
    """Infer architecture hyper-parameters directly from checkpoint weights,
    falling back to *cfg* (series config) when a key is not found."""
    overrides = {}
    has_prelude = any(k.startswith("prelude.") for k in state_dict)
    has_blocks = any(k.startswith("blocks.") for k in state_dict)
    has_recurrent = any(k.startswith("recurrent.") for k in state_dict)
    # Recurrent-depth checkpoints carry prelude/recurrent keys instead of blocks.
    uses_recurrent_arch = has_prelude and has_recurrent and not has_blocks
    # dim from embed_tokens.weight [vocab, dim]
    if "embed_tokens.weight" in state_dict:
        overrides["dim"] = state_dict["embed_tokens.weight"].shape[1]
    if uses_recurrent_arch:
        if "prelude.0.ffn.gate.weight" in state_dict:
            overrides["ffn_dim"] = state_dict["prelude.0.ffn.gate.weight"].shape[0]
        overrides["n_unique_layers"] = 0
        src = "prelude.0"
    else:
        if "blocks.0.ffn.gate.weight" in state_dict:
            overrides["ffn_dim"] = state_dict["blocks.0.ffn.gate.weight"].shape[0]
        block_ids = {
            int(k.split(".")[1])
            for k in state_dict
            if k.startswith("blocks.") and k.split(".")[1].isdigit()
        }
        if block_ids:
            overrides["n_unique_layers"] = max(block_ids) + 1
        src = "blocks.0"
    # NOTE(review): `dim` is computed here but not used below (head_dim comes
    # from the q/k norm weights) — kept as-is.
    dim = overrides.get("dim", int(cfg.get("dim", model_config.dim)))
    # Head counts = projection rows / head_dim (head_dim read from the
    # per-head q/k norm weight shapes).
    if f"{src}.attn.wq.weight" in state_dict:
        wq_rows = state_dict[f"{src}.attn.wq.weight"].shape[0]
        if f"{src}.attn.q_norm.weight" in state_dict:
            head_dim = state_dict[f"{src}.attn.q_norm.weight"].shape[0]
            overrides["n_heads"] = wq_rows // head_dim
    if f"{src}.attn.wk.weight" in state_dict:
        wk_rows = state_dict[f"{src}.attn.wk.weight"].shape[0]
        if f"{src}.attn.k_norm.weight" in state_dict:
            head_dim = state_dict[f"{src}.attn.k_norm.weight"].shape[0]
            overrides["n_kv_heads"] = wk_rows // head_dim
    # engram params
    for key, val in state_dict.items():
        if ".engram.embeddings." in key and key.endswith("_0") and val.dim() == 2:
            overrides["engram_table_size"] = val.shape[0]
            overrides["engram_dim"] = val.shape[1]
            break
    engram_dim = overrides.get("engram_dim", int(cfg.get("engram_dim", 0)))
    engram_max_ngram = int(cfg.get("engram_max_ngram", 2))
    if engram_dim > 0:
        for key, val in state_dict.items():
            if ".engram.branch_conv.weight" in key and val.dim() == 3:
                total_branch_dim = val.shape[0]
                # branch_conv outputs heads * engram_dim * (max_ngram - 1) rows,
                # so the head count is recovered by division.
                denom = engram_dim * (engram_max_ngram - 1)
                if denom > 0 and total_branch_dim % denom == 0:
                    overrides["engram_heads"] = total_branch_dim // denom
                break
    merged = dict(cfg)
    merged.update(overrides)
    return merged
def load_local_model(model_path: Path, tokenizer_path: Path, series: str) -> dict:
    """Load tokenizer + checkpoint from disk and rebuild a TinyMemoryLM.

    *series* picks the base hyper-parameter preset; anything detectable from
    the checkpoint's state dict overrides the preset. Returns a bundle dict
    with keys: model, tokenizer, device, series, sft_mode, phase.
    """
    tokenizer = WordTokenizer.load(tokenizer_path)
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load trusted checkpoint files.
    ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False)
    cfg = series_config(series)
    vocab_size = int(ckpt.get("vocab_size", tokenizer.vocab_size))
    state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt
    cfg = _infer_arch_from_state_dict(state_dict, cfg)
    # Feature flags: only enable engram/mhc/sleep-gate/think/recurrent parts
    # when the corresponding weights actually exist in the checkpoint.
    engram_dim = int(cfg.get("engram_dim", 0))
    if _detect_engram(state_dict) == 0:
        engram_dim = 0
    mhc_expansion = _detect_mhc(state_dict)
    if mhc_expansion == 1:
        mhc_expansion = int(cfg.get("mhc_expansion", 1))
    ckpt_sleep_cap, ckpt_sleep_heads = _detect_sleep_gate(state_dict)
    sleep_gate_cap = ckpt_sleep_cap if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_cap", 0))
    sleep_gate_heads = ckpt_sleep_heads if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_heads", 4))
    sleep_retention_enabled = bool(cfg.get("sleep_retention_enabled", True))
    sleep_retention_hidden = int(cfg.get("sleep_retention_hidden", 0))
    latent_think_layers = _detect_latent_think(state_dict)
    if latent_think_layers == 0:
        latent_think_layers = int(cfg.get("latent_think_layers", 0))
    prelude_layers = _detect_prelude_layers(state_dict)
    coda_layers = _detect_coda_layers(state_dict)
    recurrent_loops = _detect_recurrent_loops(state_dict)
    ckpt_lora_rank = _detect_recurrent_lora_rank(state_dict)
    if ckpt_lora_rank > 0:
        # The LoRA rank identifies the series more reliably than the name.
        inferred_series = _infer_series_from_lora_rank(ckpt_lora_rank)
        if inferred_series and inferred_series != series.lower():
            series = inferred_series.capitalize()
            # NOTE(review): re-fetching the preset here discards the
            # checkpoint-derived overrides merged above — confirm intended.
            cfg = series_config(series)
        recurrent_lora_rank = ckpt_lora_rank
    else:
        recurrent_lora_rank = int(cfg.get("recurrent_lora_rank", 0))
    recurrent_act_threshold = float(cfg.get("recurrent_act_threshold", 0.99))
    recurrent_loop_embed_dim = int(cfg.get("recurrent_loop_embed_dim", 0))
    n_unique = int(cfg.get("n_unique_layers", model_config.n_unique_layers))
    model = TinyMemoryLM(
        vocab_size=vocab_size,
        dim=int(cfg.get("dim", model_config.dim)),
        n_unique_layers=n_unique,
        n_logical_layers=int(cfg.get("n_logical_layers", model_config.n_logical_layers)),
        n_heads=int(cfg.get("n_heads", model_config.n_heads)),
        n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)),
        ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)),
        dropout=float(cfg.get("dropout", model_config.dropout)),
        mtp_horizons=tuple(int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)),
        grad_checkpoint=False,
        sliding_window=int(cfg.get("sliding_window_size", getattr(model_config, "sliding_window_size", 512))),
        rope_fraction=float(cfg.get("rope_fraction", getattr(model_config, "rope_fraction", 0.25))),
        embed_scale=bool(cfg.get("embed_scale", getattr(model_config, "embed_scale", True))),
        engram_dim=engram_dim,
        engram_heads=int(cfg.get("engram_heads", 4)),
        engram_table_size=int(cfg.get("engram_table_size", 8192)),
        engram_max_ngram=int(cfg.get("engram_max_ngram", 3)),
        mhc_expansion=mhc_expansion,
        sleep_gate_cap=sleep_gate_cap,
        sleep_gate_heads=sleep_gate_heads,
        sleep_retention_enabled=sleep_retention_enabled,
        sleep_retention_hidden=sleep_retention_hidden,
        latent_think_layers=latent_think_layers,
        prelude_layers=prelude_layers,
        coda_layers=coda_layers,
        recurrent_loops=recurrent_loops,
        recurrent_act_threshold=recurrent_act_threshold,
        recurrent_lora_rank=recurrent_lora_rank,
        recurrent_loop_embed_dim=recurrent_loop_embed_dim,
    )
    # strict=False: tolerate harmless key mismatches across code revisions.
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    # Grow embeddings if the tokenizer gained tokens after this checkpoint.
    if tokenizer.vocab_size > vocab_size:
        model.resize_token_embeddings(tokenizer.vocab_size)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    return {
        "model": model,
        "tokenizer": tokenizer,
        "device": device,
        "series": series,
        "sft_mode": ckpt.get("sft_mode", None),
        "phase": ckpt.get("phase", None),
    }
# ---------------------------------------------------------------------------
# HuggingFace Model Download & Loading
# ---------------------------------------------------------------------------
def download_huggingface_model(hf_id: str, cache_dir: Path) -> Optional[dict]:
    """Snapshot-download *hf_id* into *cache_dir* and locate its model files.

    Returns a dict with ``model_path``, ``tokenizer_path`` and ``model_name``,
    or ``None`` when the download fails or the expected files are missing.
    (Annotation fix: the previous ``-> dict`` hid the ``None`` failure path.)
    Exits the process when ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print("huggingface_hub not installed. Install with: pip install huggingface_hub")
        sys.exit(1)
    print(f"Downloading {hf_id}...")
    try:
        local_dir = Path(snapshot_download(repo_id=hf_id, cache_dir=str(cache_dir)))
    except Exception as e:
        print(f"Failed to download {hf_id}: {e}")
        return None
    print(f"Using cached {hf_id} from {local_dir}")
    # Check common subdirectory names: "models/", "model/"
    if (local_dir / "models").exists():
        model_dir = local_dir / "models"
    elif (local_dir / "model").exists():
        model_dir = local_dir / "model"
    else:
        model_dir = local_dir
    model_path = model_dir / "model.pt"
    pretrain_path = model_dir / "pretrain.pt"
    tokenizer_path = model_dir / "tokenizer.json"
    # Prefer the SFT checkpoint (model.pt) over the pretrain-only one.
    ckpt_path = None
    for p in [model_path, pretrain_path]:
        if p.exists():
            ckpt_path = p
            break
    if ckpt_path is None or not tokenizer_path.exists():
        print(f"Missing model files in {model_dir}")
        print(f" model.pt exists: {model_path.exists()}")
        print(f" pretrain.pt exists: {pretrain_path.exists()}")
        print(f" tokenizer.json exists: {tokenizer_path.exists()}")
        return None
    return {
        "model_path": ckpt_path,
        "tokenizer_path": tokenizer_path,
        "model_name": ckpt_path.stem,
    }
def load_huggingface_model(hf_id: str, cache_dir: Path) -> Optional[dict]:
    """Download *hf_id* and load it as a local model bundle.

    Returns ``None`` when the download/file-discovery step fails (annotation
    fix: the previous ``-> dict`` hid this failure path). "Haiku" is only a
    series hint — load_local_model re-infers the series from the weights
    when a recurrent LoRA rank is present.
    """
    files = download_huggingface_model(hf_id, cache_dir)
    if files is None:
        return None
    return load_local_model(files["model_path"], files["tokenizer_path"], "Haiku")
# ---------------------------------------------------------------------------
# Compare All Models
# ---------------------------------------------------------------------------
_hf_model_cache: Dict[str, dict] = {}
def prefetch_huggingface_models() -> None:
    """Download every registered HF model once and stash the loaded bundles
    in the module-level _hf_model_cache (failures are silently skipped)."""
    script_root = Path(__file__).resolve().parent
    hf_cache = script_root / "cache" / "huggingface"
    hf_cache.mkdir(parents=True, exist_ok=True)
    print("Downloading/preparing HuggingFace models...")
    for display_name, repo_id in HUGGINGFACE_MODELS.items():
        print(f" {display_name}...")
        loaded = load_huggingface_model(repo_id, hf_cache)
        if loaded:
            _hf_model_cache[display_name] = loaded
    print(f"Prepared {len(_hf_model_cache)} HuggingFace models")
def _comparison_generate(bundle: dict, prompt: str, cfg: dict) -> str:
    """Run `generate` on a loaded model bundle with the shared comparison
    settings from *cfg* (extracted to de-duplicate the two identical call
    sites in compare_all_models)."""
    print(f"Generating on '{prompt}'...")
    return generate(
        model=bundle["model"],
        tokenizer=bundle["tokenizer"],
        prompt=prompt,
        max_new_tokens=cfg["max_new_tokens"],
        temperature=cfg["temperature"],
        top_k=cfg["top_k"],
        top_p=cfg["top_p"],
        min_p=cfg["min_p"],
        no_repeat_ngram_size=cfg["no_repeat_ngram_size"],
        repetition_penalty=cfg["repetition_penalty"],
        logit_soft_cap=cfg["logit_soft_cap"],
        loop_penalty=cfg["loop_penalty"],
        device=str(bundle["device"]),
        sft_mode=cfg["sft_mode"],
        stream=True,
        context_window=cfg["context_window"],
    )


def compare_all_models(prompt: str, cfg: dict) -> None:
    """Generate from every discovered local model plus all cached HF models
    on the same *prompt*, then print a side-by-side comparison."""
    root = Path(__file__).resolve().parent
    runs_dir = root / "runs"
    all_models = discover_models(runs_dir)
    # Keep only checkpoints matching the requested mode (pretrain vs SFT).
    is_pretrain = not cfg.get("sft_mode", True)
    local_models = [
        m for m in all_models
        if ("pretrain" in m["checkpoint"]) == is_pretrain
    ]
    if not local_models:
        print("No models found matching mode.")
        return
    results: List[dict] = []
    for m in local_models:
        print(f"\n{'='*60}")
        print(f"Loading local {m['name']}/{m['checkpoint']}...")
        try:
            bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"])
        except Exception as e:
            # A broken checkpoint should not abort the whole comparison.
            print(f"Failed to load {m['name']}: {e}")
            continue
        output = _comparison_generate(bundle, prompt, cfg)
        results.append({
            "name": f"[LOCAL] {m['name']}/{m['checkpoint']}",
            "output": output,
            "device": bundle["device"],
        })
    for name, bundle in _hf_model_cache.items():
        print(f"\n{'='*60}")
        print(f"Loading {name} (cached)...")
        output = _comparison_generate(bundle, prompt, cfg)
        results.append({
            "name": name,
            "output": output,
            "device": bundle["device"],
        })
    print(f"\n{'='*60}")
    print("=" * 60)
    print("SIDE-BY-SIDE COMPARISON")
    print("=" * 60)
    for r in results:
        print(f"\n--- {r['name']} ---")
        print(r["output"])
    print(f"\n{'='*60}")
# ---------------------------------------------------------------------------
# Benchmark
# ---------------------------------------------------------------------------
# Registry of available evaluations, keyed by benchmark id. Each entry holds
# the human-readable label, the description shown in the picker,
# `hf_dataset` = (dataset repo id, config name or None), and the metric kind.
BENCHMARKS = {
    "blimp": {
        "label": "BLiMP",
        "desc": "Grammaticality minimal pairs (67 paradigms). Accuracy = % grammatical < ungrammatical perplexity.",
        "hf_dataset": ("nyu-mll/blimp", None),
        "metric": "accuracy",
    },
    "wikitext2": {
        "label": "WikiText-2",
        "desc": "LM perplexity on Wikipedia test split. Lower is better.",
        "hf_dataset": ("Salesforce/wikitext", "wikitext-2-raw-v1"),
        "metric": "perplexity",
    },
    "arc_easy": {
        "label": "ARC-Easy",
        "desc": "Multiple-choice science QA (~2.4K). Perplexity-based answer selection.",
        "hf_dataset": ("allenai/ai2_arc", "ARC-Easy"),
        "metric": "accuracy",
    },
}
def _score_text(model: TinyMemoryLM, tokenizer: WordTokenizer, text: str, device: str) -> float:
    """Mean next-token NLL of *text* under *model*; NaN when the text is
    too short (< 2 tokens) to score."""
    token_ids = tokenizer.encode(text, add_bos=True, add_eos=False)
    if len(token_ids) < 2:
        return float("nan")
    batch = torch.tensor([token_ids], dtype=torch.long, device=device)
    with torch.no_grad():
        logits, *_ = model(batch)
        log_probs = F.log_softmax(logits[0], dim=-1)
        # Each position predicts the following token.
        targets = batch[0, 1:]
        return -log_probs[range(len(targets)), targets].mean().item()
def _score_completion(model: TinyMemoryLM, tokenizer: WordTokenizer, context: str, completion: str, device: str) -> float:
    """Mean NLL of the *completion* tokens given *context*; NaN when the
    completion contributes no tokens or falls outside the scored range."""
    full_ids = tokenizer.encode(context + completion, add_bos=True, add_eos=False)
    ctx_ids = tokenizer.encode(context, add_bos=True, add_eos=False)
    n_ctx = len(ctx_ids)
    n_ref = len(full_ids) - n_ctx
    if n_ref <= 0:
        return float("nan")
    batch = torch.tensor([full_ids], dtype=torch.long, device=device)
    with torch.no_grad():
        logits, *_ = model(batch)
        log_probs = F.log_softmax(logits[0], dim=-1)
        targets = batch[0, 1:]
        # The logits at position t predict token t+1, so the completion is
        # scored starting one position before its first token.
        ref_start = n_ctx - 1
        ref_end = min(ref_start + n_ref, log_probs.shape[0])
        if ref_start >= ref_end:
            return float("nan")
        picked = log_probs[ref_start:ref_end][range(ref_end - ref_start), targets[ref_start:ref_end]]
        return -picked.mean().item()
# Subset of BLiMP paradigm (sub-dataset) names evaluated by _run_blimp; each
# is a config of the nyu-mll/blimp dataset on the HuggingFace Hub.
BLIMP_PARADIGMS = [
    "adjunct_island", "anaphor_gender_agreement", "anaphor_number_agreement",
    "animate_subject_passive", "animate_subject_trans", "causative",
    "complex_NP_island", "coordinate_structure_constraint_complex_left_branch",
    "coordinate_structure_constraint_object_extraction",
    "determiner_noun_agreement_1", "determiner_noun_agreement_2",
    "determiner_noun_agreement_irregular_1", "determiner_noun_agreement_irregular_2",
    "determiner_noun_agreement_with_adj_2", "determiner_noun_agreement_with_adj_irregular_1",
    "determiner_noun_agreement_with_adj_irregular_2", "determiner_noun_agreement_with_adjective_1",
    "distractor_agreement_relational_noun", "distractor_agreement_relative_clause",
    "drop_argument", "ellipsis_n_bar_1", "ellipsis_n_bar_2",
    "existential_there_object_raising", "existential_there_quantifiers_1",
    "existential_there_quantifiers_2", "existential_there_subject_raising",
    "expletive_it_object_raising", "inchoative", "intransitive",
    "irregular_past_participle_adjectives", "irregular_past_participle_verbs",
    "irregular_plural_subject_verb_agreement_1", "irregular_plural_subject_verb_agreement_2",
    "left_branch_island_echo_question", "left_branch_island_simple_question",
    "matrix_question_npi_licensor_present", "npi_present_1", "npi_present_2",
    "only_npi_licensor_present", "only_npi_scope", "passive_1", "passive_2",
    "principle_A_c_command", "principle_A_case_1", "principle_A_case_2",
    "principle_A_domain_1", "principle_A_domain_2", "principle_A_domain_3",
    "principle_A_reconstruction", "regular_plural_subject_verb_agreement_1",
    "regular_plural_subject_verb_agreement_2", "sentential_negation_npi_licensor_present",
    "sentential_negation_npi_scope", "sentential_subject_island",
    "superlative_quantifiers_1", "superlative_quantifiers_2",
    "tough_vs_raising_1", "tough_vs_raising_2", "transitive", "wh_island",
    "wh_questions_object_gap", "wh_questions_subject_gap",
    "wh_questions_subject_gap_long_distance", "wh_vs_that_no_gap",
    "wh_vs_that_no_gap_long_distance", "wh_vs_that_with_gap",
    "wh_vs_that_with_gap_long_distance",
]
def _run_blimp(model, tokenizer, device, n_samples: int = 200) -> Tuple[List[str], List[float]]:
    """Score every BLiMP paradigm: accuracy = fraction of pairs where the
    grammatical sentence receives a lower NLL than the ungrammatical one.
    Unavailable paradigms are skipped and recorded as NaN."""
    from datasets import load_dataset  # type: ignore
    accuracies: List[float] = []
    for paradigm in BLIMP_PARADIGMS:
        try:
            ds = load_dataset("nyu-mll/blimp", paradigm, split="train")
        except Exception as e:
            print(f" {paradigm}: skip ({e})")
            accuracies.append(float("nan"))
            continue
        examples = list(ds)[:n_samples]
        n_correct = 0
        for example in examples:
            nll_good = _score_text(model, tokenizer, example["sentence_good"], device)
            nll_bad = _score_text(model, tokenizer, example["sentence_bad"], device)
            if math.isnan(nll_good) or math.isnan(nll_bad):
                continue
            if nll_good < nll_bad:
                n_correct += 1
        acc = n_correct / len(examples) if examples else float("nan")
        accuracies.append(acc)
        print(f" {paradigm:50s} acc={acc:.3f}")
    return BLIMP_PARADIGMS, accuracies
def _run_wikitext2(model, tokenizer, device, chunk_chars: int = 512, max_chunks: int = 100) -> Tuple[List[str], List[float]]:
    """Chunked perplexity over the WikiText-2 raw test split (lower = better).
    NaN chunks (too short to score) are kept in the output list."""
    from datasets import load_dataset  # type: ignore
    ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test")
    corpus = "\n".join(row["text"] for row in ds if row["text"].strip())
    raw_chunks = [corpus[pos:pos + chunk_chars] for pos in range(0, len(corpus), chunk_chars)]
    # Drop tiny trailing fragments, then cap the workload.
    chunks = [c for c in raw_chunks if len(c) > 20][:max_chunks]
    labels: List[str] = []
    ppls: List[float] = []
    for i, chunk in enumerate(chunks):
        nll = _score_text(model, tokenizer, chunk, device)
        labels.append(f"chunk {i + 1}")
        ppls.append(math.exp(nll) if not math.isnan(nll) else float("nan"))
        if (i + 1) % 10 == 0:
            finite = [p for p in ppls if not math.isnan(p)]
            mean = sum(finite) / len(finite) if finite else float("nan")
            print(f" chunk {i + 1}/{len(chunks)} running mean ppl={mean:.2f}")
    return labels, ppls
def _run_arc_easy(model, tokenizer, device, max_samples: int = 200) -> Tuple[List[str], List[float]]:
    """ARC-Easy multiple choice: pick the answer with the lowest completion
    NLL; per-question score is 1.0/0.0, or NaN when no choice is scoreable."""
    from datasets import load_dataset  # type: ignore
    ds = load_dataset("allenai/ai2_arc", "ARC-Easy", split="test")
    questions = list(ds)[:max_samples]
    labels: List[str] = []
    scores: List[float] = []
    for i, ex in enumerate(questions):
        question = ex["question"]
        choice_texts = ex["choices"]["text"]
        choice_keys = ex["choices"]["label"]
        answer_key = ex["answerKey"]
        context = f"Question: {question}\nAnswer:"
        nlls = [
            _score_completion(model, tokenizer, context, f" {text}", device)
            for text in choice_texts
        ]
        if all(math.isnan(v) for v in nlls):
            scores.append(float("nan"))
        else:
            # NaN choices are pushed to +inf so they can never win.
            best = min(range(len(nlls)), key=lambda j: nlls[j] if not math.isnan(nlls[j]) else float("inf"))
            scores.append(1.0 if choice_keys[best] == answer_key else 0.0)
        labels.append(f"Q{i + 1}")
    n_valid = sum(1 for s in scores if not math.isnan(s))
    acc = sum(s for s in scores if not math.isnan(s)) / n_valid if n_valid else float("nan")
    print(f" {n_valid} questions evaluated, accuracy={acc:.3f}")
    return labels, scores
def run_benchmark_mode() -> None:
    """Interactively run one benchmark (BLiMP / WikiText-2 / ARC-Easy) over a
    user-selected set of local and HuggingFace models, then save a bar chart
    of the per-model summary metric to ``benchmark_<key>.png``.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")  # headless backend — output goes to a PNG file
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib not installed. pip install matplotlib")
        return
    # --- benchmark selection ---
    bench_keys = list(BENCHMARKS.keys())
    print("\nBenchmarks:")
    for i, k in enumerate(bench_keys):
        b = BENCHMARKS[k]
        # NOTE(review): label and desc are printed with no separator between
        # them; other menus in this file use " — " — confirm intended output.
        print(f" [{i + 1}] {b['label']}{b['desc']}")
    print("Select benchmark [1]:", end=" ", flush=True)
    try:
        b_choice = input().strip() or "1"
    except (EOFError, KeyboardInterrupt):
        print()
        return
    if not (b_choice.isdigit() and 1 <= int(b_choice) <= len(bench_keys)):
        print("Invalid selection.")
        return
    bench_key = bench_keys[int(b_choice) - 1]
    bench = BENCHMARKS[bench_key]
    print(f"Benchmark: {bench['label']}")
    # --- model selection: merge local checkpoints and the curated HF list ---
    root = Path(__file__).resolve().parent
    runs_dir = root / "runs"
    all_models = discover_models(runs_dir)
    model_entries: List[dict] = []
    for m in all_models:
        model_entries.append({"label": f"[LOCAL] {m['name']}/{m['checkpoint']}", "type": "local", "meta": m})
    for hf_name, hf_id in HUGGINGFACE_MODELS.items():
        model_entries.append({"label": f"[HF] {hf_name}", "type": "hf", "hf_id": hf_id, "hf_name": hf_name})
    if not model_entries:
        print("No models found.")
        return
    print("\nAvailable models:")
    for i, e in enumerate(model_entries):
        print(f" [{i + 1}] {e['label']}")
    print(" [a] All models")
    print("Select models (comma-separated or 'a'):", end=" ", flush=True)
    try:
        raw = input().strip()
    except (EOFError, KeyboardInterrupt):
        print()
        return
    if raw.lower() == "a":
        selected = list(range(len(model_entries)))
    else:
        # Parse comma-separated 1-based indices; invalid tokens are ignored.
        selected = []
        for tok in raw.split(","):
            tok = tok.strip()
            if tok.isdigit() and 1 <= int(tok) <= len(model_entries):
                selected.append(int(tok) - 1)
    if not selected:
        print("No valid selection.")
        return
    # --- evaluation loop ---
    all_results: List[dict] = []
    # x-axis labels come from the first model that completes; all models run
    # the same benchmark items so the labels should match.
    shared_x_labels: Optional[List[str]] = None
    for idx in selected:
        entry = model_entries[idx]
        print(f"\n{'='*60}\nLoading {entry['label']}...")
        try:
            if entry["type"] == "local":
                m = entry["meta"]
                bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"])
            else:
                bundle = load_huggingface_model(entry["hf_id"], root / ".hf_cache")
        except Exception as e:
            # A model that fails to load is skipped, not fatal to the run.
            print(f" Failed: {e}")
            continue
        model = bundle["model"]
        tokenizer = bundle["tokenizer"]
        device = str(bundle["device"])
        model.eval()
        if bench_key == "blimp":
            x_labels, y_vals = _run_blimp(model, tokenizer, device)
        elif bench_key == "wikitext2":
            x_labels, y_vals = _run_wikitext2(model, tokenizer, device)
        else:
            x_labels, y_vals = _run_arc_easy(model, tokenizer, device)
        if shared_x_labels is None:
            shared_x_labels = x_labels
        # Per-model summary: mean over non-NaN per-item values.
        valid = [v for v in y_vals if not math.isnan(v)]
        summary = sum(valid) / len(valid) if valid else float("nan")
        all_results.append({"label": entry["label"], "y": y_vals, "summary": summary})
    if not all_results or shared_x_labels is None:
        print("No results to plot.")
        return
    # --- plotting ---
    metric = bench["metric"]
    # Sort best-first: ascending for perplexity (lower is better),
    # descending for accuracy (higher is better).
    paired = sorted(zip([r["summary"] for r in all_results], [r["label"] for r in all_results]),
                    reverse=(metric != "perplexity"))
    summaries, model_labels = zip(*paired) if paired else ([], [])
    n = len(summaries)
    # Red-to-green colormap spread across the sorted bars.
    colors = [plt.cm.RdYlGn(i / max(n - 1, 1)) for i in range(n)]
    fig, ax = plt.subplots(figsize=(max(6, n * 1.4), 6))
    bars = ax.bar(range(n), summaries, color=colors, edgecolor="black")
    for bar, val in zip(bars, summaries):
        # Value label just above each bar.
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
                f"{val:.3f}", ha="center", va="bottom", fontsize=9, fontweight="bold")
    ylabel = "Mean Perplexity (↓ better)" if metric == "perplexity" else "Mean Accuracy (↑ better)"
    ax.set_ylabel(ylabel)
    ax.set_title(f"{bench['label']} Benchmark — Model Comparison")
    ax.set_xticks(range(n))
    ax.set_xticklabels(model_labels, rotation=20, ha="right", fontsize=9)
    if metric == "accuracy":
        ax.set_ylim(0, 1.05)
    ax.grid(True, axis="y", alpha=0.3)
    plt.tight_layout()
    out_path = root / f"benchmark_{bench_key}.png"
    plt.savefig(str(out_path), dpi=150)
    print(f"\nChart saved to {out_path}")
    # Best-effort: open the chart in the default viewer on Linux desktops.
    try:
        import subprocess
        subprocess.Popen(["xdg-open", str(out_path)])
    except Exception:
        pass
# ---------------------------------------------------------------------------
# Interactive CLI
# ---------------------------------------------------------------------------
def _pick_series(detected: str) -> str:
    """Ask the user which model series to use; return the capitalized name.

    The series matching *detected* (case-insensitive) is offered as the
    default. Ctrl-D / Ctrl-C exits the program.
    """
    options = list(MODEL_SERIES.keys())
    wanted = detected.lower()
    default_idx = next(
        (pos + 1 for pos, name in enumerate(options) if name == wanted), 1
    )
    # With only one series there is nothing to choose.
    if len(options) == 1:
        return options[0].capitalize()
    print("Series:")
    for pos, name in enumerate(options):
        suffix = " (detected)" if name == wanted else ""
        print(f" [{pos + 1}] {name.capitalize()}{suffix}")
    while True:
        try:
            raw = input(f"Select series [{default_idx}]: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            sys.exit(0)
        raw = raw or str(default_idx)
        if raw.isdigit() and 1 <= int(raw) <= len(options):
            return options[int(raw) - 1].capitalize()
        print(f"Enter a number 1-{len(options)}")
def pick_model(runs_dir: Path) -> tuple[dict, str]:
    """Discover local checkpoints under *runs_dir*, let the user pick one,
    and return ``(loaded bundle, checkpoint name)``.

    Exits with status 1 when no models are found; a single discovered model
    is loaded without prompting.
    """
    models = discover_models(runs_dir)
    if not models:
        print(f"No models found in {runs_dir}")
        print("Expected layout: runs/<name>/model.pt (or pretrain.pt) + tokenizer.json")
        sys.exit(1)

    def _load(meta: dict) -> tuple[dict, str]:
        # Shared load path for both the auto-pick and menu-pick branches.
        print(f"Loading {meta['name']}/{meta['checkpoint']} ({meta['series']})...")
        bundle = load_local_model(meta["model_path"], meta["tokenizer_path"], meta["series"])
        return bundle, meta["checkpoint"]

    if len(models) == 1:
        return _load(models[0])
    print("Available models:")
    for pos, meta in enumerate(models):
        print(f" [{pos + 1}] {meta['name']}/{meta['checkpoint']} ({meta['series']})")
    while True:
        try:
            raw = input("Select model [1]: ").strip() or "1"
        except (EOFError, KeyboardInterrupt):
            print()
            sys.exit(0)
        if raw.isdigit() and 1 <= int(raw) <= len(models):
            return _load(models[int(raw) - 1])
        print(f"Enter a number 1-{len(models)}")
# ---------------------------------------------------------------------------
# Generation mode configs
# ---------------------------------------------------------------------------
# Generation-mode presets passed to generate(). Field meanings:
#   sft_mode: "chat" applies the chat prompt format; False = raw continuation
#   temperature / top_k / top_p / min_p: sampling controls
#   no_repeat_ngram_size / repetition_penalty / loop_penalty: repetition controls
#   logit_soft_cap: cap applied to logits before sampling
#   max_new_tokens / context_window: generation length limits
MODES = {
    # Low-temperature chat preset: tight sampling, strong anti-repetition.
    "chat-coherent": {
        "label": "Chat — Coherent",
        "desc": "structured, consistent, strong repetition control",
        "sft_mode": "chat",
        "temperature": 0.35,
        "top_k": 20,
        "top_p": 0.88,
        "min_p": 0.10,
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 1.22,
        "logit_soft_cap": 20.0,
        "loop_penalty": 20.0,
        "max_new_tokens": 4096,
        "context_window": 2048,
    },
    # Higher-temperature chat preset: wider sampling, lighter penalties.
    "chat-variants": {
        "label": "Chat — Variants",
        "desc": "creative, diverse, more surprising outputs",
        "sft_mode": "chat",
        "temperature": 0.65,
        "top_k": 60,
        "top_p": 0.92,
        "min_p": 0.05,
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 1.12,
        "logit_soft_cap": 20.0,
        "loop_penalty": 14.0,
        "max_new_tokens": 4096,
        "context_window": 2048,
    },
    # Base-model continuation preset: conservative sampling.
    "pretrain-coherent": {
        "label": "Pretrain — Coherent",
        "desc": "grounded continuation, low temperature, tight sampling",
        "sft_mode": False,
        "temperature": 0.3,
        "top_k": 20,
        "top_p": 0.85,
        "min_p": 0.10,
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 1.2,
        "logit_soft_cap": 20.0,
        "loop_penalty": 20.0,
        "max_new_tokens": 4096,
        "context_window": 2048,
    },
    # Base-model continuation preset: exploratory sampling.
    "pretrain-variants": {
        "label": "Pretrain — Variants",
        "desc": "free-form continuation, higher temperature, more exploration",
        "sft_mode": False,
        "temperature": 0.7,
        "top_k": 60,
        "top_p": 0.93,
        "min_p": 0.04,
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 1.12,
        "logit_soft_cap": 20.0,
        "loop_penalty": 12.0,
        "max_new_tokens": 4096,
        "context_window": 2048,
    },
}
# Ordered mode keys; pick_mode() filters these by checkpoint type
# ("pretrain" in the key selects base-model presets).
_MODE_LIST = list(MODES.keys())
def pick_mode(is_pretrain: bool) -> dict:
    """Prompt the user to choose a generation mode. Returns a config dict.

    Args:
        is_pretrain: True when the loaded checkpoint is a base (pretrain)
            model; only modes matching the checkpoint type are offered.

    Returns:
        The selected ``MODES`` entry.
    """
    # Filter to relevant modes based on checkpoint type
    candidates = [k for k in _MODE_LIST if ("pretrain" in k) == is_pretrain]
    print("\nGeneration mode:")
    for i, key in enumerate(candidates):
        cfg = MODES[key]
        # Fix: separate label and description with " — " (previously the two
        # were concatenated, printing e.g. "Chat — Coherentstructured, ...").
        print(f" [{i + 1}] {cfg['label']} — {cfg['desc']}")
    while True:
        try:
            choice = input("Select mode [1]: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            sys.exit(0)
        if not choice:
            choice = "1"
        if choice.isdigit() and 1 <= int(choice) <= len(candidates):
            key = candidates[int(choice) - 1]
            cfg = MODES[key]
            print(f"Mode: {cfg['label']}")
            return cfg
        print(f"Enter a number 1-{len(candidates)}")
def _run_loop(bundle: dict, cfg: dict) -> None:
    """Interactive generation REPL for a loaded model.

    Args:
        bundle: Dict with "model", "tokenizer", and "device" entries, as
            returned by the model loaders.
        cfg: A ``MODES`` entry supplying sampling parameters and sft_mode.

    Recognized commands: /quit /exit /q /help /mode; anything else is sent
    to ``generate`` as a prompt. Ctrl-D / Ctrl-C leaves the loop.
    """
    model = bundle["model"]
    tokenizer = bundle["tokenizer"]
    device = bundle["device"]
    sft = cfg["sft_mode"]
    prompt_label = "You" if sft else "Prompt"
    # Show the active sampling configuration once at startup.
    print(f"\nModel ready on {device}. Type your message, or /quit to exit.")
    print(f" temp={cfg['temperature']} top_k={cfg['top_k']} top_p={cfg['top_p']}")
    print(f" min_p={cfg['min_p']} ng={cfg['no_repeat_ngram_size']} rp={cfg['repetition_penalty']}")
    print(f" cap={cfg['logit_soft_cap']} loop_penalty={cfg['loop_penalty']}\n")
    while True:
        try:
            prompt = input(f"{prompt_label}: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if not prompt:
            continue
        if prompt in ("/quit", "/exit", "/q"):
            break
        if prompt == "/help":
            print("Commands: /quit /exit /q /help /mode")
            if sft:
                print("Anything else is sent as a chat prompt.")
            else:
                print("Anything else is sent as a raw continuation prompt.")
            continue
        if prompt == "/mode":
            # Fix: separate label and description with " — " (previously the
            # two were concatenated with no separator).
            print(f"Current: {cfg['label']} — {cfg['desc']}")
            continue
        print("AI: ", end="", flush=True)
        generate(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=cfg["max_new_tokens"],
            temperature=cfg["temperature"],
            top_k=cfg["top_k"],
            top_p=cfg["top_p"],
            min_p=cfg["min_p"],
            no_repeat_ngram_size=cfg["no_repeat_ngram_size"],
            repetition_penalty=cfg["repetition_penalty"],
            logit_soft_cap=cfg["logit_soft_cap"],
            loop_penalty=cfg["loop_penalty"],
            device=str(device),
            sft_mode=cfg["sft_mode"],
            stream=True,
            context_window=cfg["context_window"],
        )
# ---------------------------------------------------------------------------
# Dynamic collection discovery
# ---------------------------------------------------------------------------
# HuggingFace collection holding the TMLM-Haiku releases (informational).
_COLLECTION_SLUG = "CompactAI-O/tmlm-haiku-series"
# Hub author and search term used by fetch_collection() to discover repos.
_AUTHOR = "CompactAI-O"
_SEARCH = "TMLM-Haiku"
# Static repo list used when the Hub API is unreachable (newest first).
_FALLBACK_COLLECTION = [
    {"version": "TMLM-Haiku-2.3", "hf_id": "CompactAI-O/TMLM-Haiku-2.3"},
    {"version": "TMLM-Haiku-2", "hf_id": "CompactAI-O/TMLM-Haiku-2"},
    {"version": "TMLM-Haiku-1.3", "hf_id": "CompactAI-O/TMLM-Haiku-1.3"},
    {"version": "TMLM-Haiku-1", "hf_id": "CompactAI-O/TMLM-Haiku-1"},
    {"version": "Glint-1", "hf_id": "CompactAI-O/Glint-1"},
]
# Repos always included even though their names do not match _SEARCH.
_EXTRA_REPOS = ["CompactAI-O/Glint-1"]
def _probe_repo(hf_id: str) -> dict | None:
    """Return entry dict for one repo, or None if no usable checkpoints found."""
    from huggingface_hub import list_repo_files
    try:
        available = set(list_repo_files(hf_id))
    except Exception:
        return None
    # Checkpoints may live at the repo root or under a models/ or model/ dir.
    ckpt_dir: str | None = None
    for name in ("models", "model"):
        if any(path.startswith(f"{name}/") for path in available):
            ckpt_dir = name
            break
    prefix = f"{ckpt_dir}/" if ckpt_dir else ""
    # All .pt files within the detected checkpoint directory, sorted by name.
    pt_names = sorted(
        path[len(prefix):]
        for path in available
        if path.startswith(prefix) and path.endswith(".pt")
    )
    # Known checkpoint filenames mapped to (display label, is_pretrain);
    # unknown files fall back to their stem, pretrain inferred from the name.
    known = {
        "model.pt": ("Chat (SFT)", False),
        "model_rep.pt": ("Chat (anti-repetition)", False),
        "pretrain.pt": ("Pretrain (base)", True),
    }
    checkpoints = []
    for name in pt_names:
        label, is_pretrain = known.get(name, (name.removesuffix(".pt"), "pretrain" in name))
        checkpoints.append((label, name, is_pretrain))
    if not checkpoints:
        return None
    return {
        "version": hf_id.split("/")[-1],
        "hf_id": hf_id,
        "subdir": ckpt_dir,
        "checkpoints": checkpoints,
        "desc": "",
    }
def fetch_collection() -> list[dict]:
    """Query HF for all CompactAI-O TMLM-Haiku models, newest first."""
    from huggingface_hub import HfApi
    print("Checking HuggingFace collection for available models...")
    try:
        hub = HfApi()
        infos = list(hub.list_models(author=_AUTHOR, search=_SEARCH, sort="lastModified"))
        infos.sort(key=lambda m: getattr(m, "lastModified", ""), reverse=True)
    except Exception as exc:
        # Offline / API failure: substitute lightweight stand-ins carrying
        # only the repo id, taken from the static fallback list.
        print(f" Could not reach HuggingFace ({exc}); using fallback list.")
        infos = [type("M", (), {"id": item["hf_id"]})() for item in _FALLBACK_COLLECTION]
    entries: list[dict] = []
    seen_ids: set = set()

    def _try_add(repo_id: str) -> None:
        # Probe the repo; record it only when usable checkpoints were found.
        found = _probe_repo(repo_id)
        if found:
            entries.append(found)
            seen_ids.add(repo_id)

    for info in infos:
        if _SEARCH.lower() in info.id.lower():
            _try_add(info.id)
    # Always include extra repos (e.g. Glint-1) not caught by TMLM-Haiku search
    for repo_id in _EXTRA_REPOS:
        if repo_id not in seen_ids:
            _try_add(repo_id)
    if not entries:
        print(" No models found; using fallback list.")
        for item in _FALLBACK_COLLECTION:
            found = _probe_repo(item["hf_id"])
            if found:
                entries.append(found)
    return entries
# ---------------------------------------------------------------------------
# Download helper
# ---------------------------------------------------------------------------
def _download_version(entry: dict, cache_dir: Path) -> Path:
    """Download full repo snapshot; return the directory containing model files."""
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print("huggingface_hub not installed. Run: pip install huggingface_hub")
        sys.exit(1)
    repo = entry["hf_id"]
    print(f"Fetching {repo} ...")
    try:
        snapshot = Path(snapshot_download(repo_id=repo, cache_dir=str(cache_dir)))
    except Exception as exc:
        print(f"Download failed: {exc}")
        sys.exit(1)
    sub = entry.get("subdir")
    candidate = (snapshot / sub) if sub else snapshot
    # Fall back to the snapshot root when the expected subdirectory is absent.
    return candidate if candidate.exists() else snapshot
# ---------------------------------------------------------------------------
# Selection prompts
# ---------------------------------------------------------------------------
def _prompt_int(prompt: str, lo: int, hi: int, default: int = 1) -> int:
    """Prompt until the user enters an integer in [lo, hi].

    Args:
        prompt: Text shown before the ``[default]`` hint.
        lo: Smallest accepted value (inclusive).
        hi: Largest accepted value (inclusive).
        default: Value returned when the user just presses Enter.

    Exits the program on Ctrl-D / Ctrl-C.
    """
    while True:
        try:
            raw = input(f"{prompt} [{default}]: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            sys.exit(0)
        if not raw:
            return default
        if raw.isdigit() and lo <= int(raw) <= hi:
            return int(raw)
        # Fix: range endpoints were concatenated ("Enter a number 15." for
        # lo=1, hi=5); separate them with "-" like the other prompts here.
        print(f" Enter a number {lo}-{hi}.")
def pick_version(collection: list[dict]) -> dict:
    """List the available model versions and return the chosen entry."""
    print("\nTMLM-Haiku series (CompactAI-O)\n")
    for pos, item in enumerate(collection):
        tail = f" — {item['desc']}" if item["desc"] else ""
        print(f" [{pos + 1}] {item['version']}{tail}")
    chosen = _prompt_int("Select version", 1, len(collection))
    return collection[chosen - 1]
def pick_checkpoint(entry: dict) -> tuple[str, bool]:
    """Return (filename, is_pretrain)."""
    options = entry["checkpoints"]
    # A single checkpoint is used as-is without prompting.
    if len(options) == 1:
        label, fname, is_pretrain = options[0]
        print(f" Using: {label} ({fname})")
        return fname, is_pretrain
    print(f"\nCheckpoints for {entry['version']}:")
    for pos, (label, fname, _unused) in enumerate(options):
        print(f" [{pos + 1}] {label} ({fname})")
    chosen = _prompt_int("Select checkpoint", 1, len(options))
    label, fname, is_pretrain = options[chosen - 1]
    return fname, is_pretrain
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    """CLI entry point: parse flags, pick and download a model, then chat,
    compare all models, or run a benchmark.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--compare", "-c", action="store_true")
    parser.add_argument("--prompt", "-p", type=str, default="Hello")
    # --pretrain / --sft force the checkpoint type; mutually exclusive.
    mode_group = parser.add_mutually_exclusive_group()
    mode_group.add_argument("--pretrain", action="store_true")
    mode_group.add_argument("--sft", action="store_true")
    # parse_known_args: unrecognized flags are tolerated rather than fatal.
    args, _ = parser.parse_known_args()
    print("=" * 56)
    print(" CompactAI-O Interactive Chat")
    print(" Models: huggingface.co/CompactAI-O")
    print("=" * 56)
    if args.compare:
        # --compare: prompt loop over every known model, skipping the
        # single-model selection flow entirely.
        prefetch_huggingface_models()
        cfg = pick_mode(is_pretrain=args.pretrain)
        prompt_label = "You" if cfg["sft_mode"] else "Prompt"
        while True:
            print(f"{prompt_label}:", end=" ", flush=True)
            prompt = sys.stdin.readline().strip()
            if not prompt or prompt in ("/quit", "/exit", "/q"):
                break
            compare_all_models(prompt, cfg)
        return
    collection = fetch_collection()
    if not collection:
        print("No models found. Check your internet connection.")
        sys.exit(1)
    entry = pick_version(collection)
    fname, is_pretrain = pick_checkpoint(entry)
    # Explicit CLI flags override the checkpoint-derived type.
    if args.pretrain:
        is_pretrain = True
    elif args.sft:
        is_pretrain = False
    root = Path(__file__).resolve().parent
    cache_dir = root / "cache" / "huggingface"
    cache_dir.mkdir(parents=True, exist_ok=True)
    model_dir = _download_version(entry, cache_dir)
    model_path = model_dir / fname
    tokenizer_path = model_dir / "tokenizer.json"
    if not model_path.exists():
        print(f"File not found: {model_path}")
        sys.exit(1)
    if not tokenizer_path.exists():
        print(f"Tokenizer not found: {tokenizer_path}")
        sys.exit(1)
    print(f"Loading {entry['version']} / {fname} ...")
    bundle = load_local_model(model_path, tokenizer_path, "Haiku")
    # Use checkpoint-embedded sft_mode/phase if available (unless the user
    # forced a type via --pretrain/--sft above).
    sft_mode_flag = bundle.get("sft_mode")
    phase_flag = bundle.get("phase")
    if sft_mode_flag is not None and not args.pretrain and not args.sft:
        is_pretrain = not sft_mode_flag
    elif phase_flag is not None and not args.pretrain and not args.sft:
        is_pretrain = phase_flag == "pretrain"
    print("\nChoose action:")
    print(" [1] Chat with this model")
    print(" [2] Compare ALL models (local + HuggingFace)")
    print(" [3] Run Benchmark (BLiMP / WikiText-2 / ARC-Easy)")
    print("Select [1]:", end=" ", flush=True)
    choice = sys.stdin.readline().strip() or "1"
    if choice == "1":
        cfg = pick_mode(is_pretrain)
        _run_loop(bundle, cfg)
    elif choice == "2":
        # Same compare loop as --compare, reached from the action menu.
        print("\nDownloading/preparing HuggingFace models...")
        prefetch_huggingface_models()
        cfg = pick_mode(is_pretrain)
        prompt_label = "You" if cfg["sft_mode"] else "Prompt"
        while True:
            print(f"{prompt_label}:", end=" ", flush=True)
            prompt = sys.stdin.readline().strip()
            if not prompt or prompt in ("/quit", "/exit", "/q"):
                break
            compare_all_models(prompt, cfg)
    elif choice == "3":
        run_benchmark_mode()
    else:
        print("Enter 1, 2, or 3")
# Script entry point: run the interactive CLI when executed directly.
if __name__ == "__main__":
    main()