"""Bidirectional (non-causal) Qwen-3 backbone and an LM-head wrapper for MNTP."""

from __future__ import annotations

from typing import Any, Optional

import torch
from torch import nn

from transformers.cache_utils import Cache  # kept for potential future use
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3Attention,
    Qwen3DecoderLayer,
    Qwen3ForCausalLM,
    Qwen3MLP,
    Qwen3Model,
    Qwen3PreTrainedModel,
    Qwen3RMSNorm,
)
from transformers.utils import logging

try:
    from peft import PeftModel
except ImportError:
    PeftModel = Any  # soft dependency; only used for annotations in this case
    _PEFT_AVAILABLE = False
else:
    _PEFT_AVAILABLE = True

logger = logging.get_logger(__name__)


# ---------------------------------------------------------------------------
# 1) Bidirectional attention: disable causal masking & sliding window
# ---------------------------------------------------------------------------
class ModifiedQwen3Attention(Qwen3Attention):
    """Full-context self-attention (no causal mask)."""

    def __init__(self, *args, **kwargs):
        """Build the stock attention module, then switch off its causal flags."""
        super().__init__(*args, **kwargs)
        self.is_causal = False
        self.sliding_window = None


# ---------------------------------------------------------------------------
# 2) Decoder layer using the bidirectional attention module
# ---------------------------------------------------------------------------
class ModifiedQwen3DecoderLayer(Qwen3DecoderLayer):
    """Decoder layer with full-context attention."""

    def __init__(self, config: PretrainedConfig, layer_idx: int):
        """Build the stock layer and swap in ``ModifiedQwen3Attention``.

        Args:
            config: Model configuration forwarded to the parent layer and to
                the replacement attention module.
            layer_idx: Index of this layer within the stack.
        """
        super().__init__(config, layer_idx)
        self.self_attn = ModifiedQwen3Attention(config=config, layer_idx=layer_idx)
        # Must match the key used in the mask dict built by Qwen3BiModel.forward.
        self.attention_type = "full_attention"
        self.sliding_window = None


# ---------------------------------------------------------------------------
# 3) Backbone: Qwen-3 with bidirectional self-attention
# ---------------------------------------------------------------------------
class Qwen3BiModel(Qwen3Model):
    """Qwen-3 backbone whose self-attention is bidirectional."""

    _no_split_modules = ["ModifiedQwen3DecoderLayer"]

    def __init__(self, config: PretrainedConfig):
        """Build the stock backbone, then replace every decoder layer.

        Args:
            config: Model configuration; ``config.num_hidden_layers`` decides
                how many ``ModifiedQwen3DecoderLayer`` instances are created.
        """
        super().__init__(config)
        # NOTE(review): the layers are replaced after the parent constructor
        # has returned. If the parent already ran its weight-initialisation
        # pass, these fresh layers are only covered when a wrapper (e.g.
        # Qwen3BiForMNTP) calls post_init() or when weights are loaded from a
        # checkpoint -- confirm for from-scratch standalone use.
        self.layers = nn.ModuleList(
            [ModifiedQwen3DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        self.has_sliding_layers = False

    @staticmethod
    def _build_pad_bias(pad_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Turn a keep-mask into an additive attention bias.

        Args:
            pad_mask: ``[B, L]`` mask, truthy where a token should be attended
                to and falsy on padding.
            dtype: Floating-point dtype of the returned bias.

        Returns:
            ``[B, 1, 1, L]`` tensor holding ``0`` on kept positions and the
            most negative finite value of ``dtype`` (``torch.finfo(dtype).min``,
            not a literal ``-inf``) on padding.
        """
        most_negative = torch.finfo(dtype).min
        bias = (~pad_mask.bool()).to(dtype) * most_negative
        return bias[:, None, None, :]

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Run the backbone with a padding-only (non-causal) attention bias.

        Args:
            input_ids: Token ids, used to size the default mask when
                ``attention_mask`` is omitted.
            attention_mask: Optional ``[B, L]`` keep-mask (truthy = attend).
                Defaults to attending to every position.
            **kwargs: Forwarded unchanged to ``Qwen3Model.forward``. When
                ``inputs_embeds`` is present it is used to size the default
                mask (if ``input_ids`` is missing) and to pick the bias dtype.

        Returns:
            Whatever ``Qwen3Model.forward`` returns.

        Raises:
            ValueError: If none of ``attention_mask``, ``input_ids`` and
                ``inputs_embeds`` is provided.
        """
        inputs_embeds = kwargs.get("inputs_embeds")

        # Default to keep-all if no mask is provided
        if attention_mask is None:
            if input_ids is not None:
                attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
            elif inputs_embeds is not None:
                # Embeddings-only call: first two dims give the mask shape.
                attention_mask = torch.ones(
                    inputs_embeds.shape[:2],
                    dtype=torch.bool,
                    device=inputs_embeds.device,
                )
            else:
                raise ValueError(
                    "Either attention_mask, input_ids or inputs_embeds must be provided."
                )

        # Follow the dtype of the tensor that actually enters the layers.
        if inputs_embeds is not None:
            bias_dtype = inputs_embeds.dtype
        else:
            bias_dtype = self.embed_tokens.weight.dtype
        pad_bias = self._build_pad_bias(attention_mask, bias_dtype)

        # Dict mask tells parent to skip causal-mask generation
        attn_mask_dict = {"full_attention": pad_bias}

        return super().forward(
            input_ids=input_ids,
            attention_mask=attn_mask_dict,
            **kwargs,
        )


# ---------------------------------------------------------------------------
# 4) Task head: MNTP (masked next-token) — no generation API
# ---------------------------------------------------------------------------
class Qwen3BiForMNTP(Qwen3ForCausalLM):
    """Bidirectional Qwen-3 with LM head for masked-token objectives."""

    def __init__(self, config: PretrainedConfig):
        """Wire a ``Qwen3BiModel`` backbone to a bias-free linear LM head.

        Args:
            config: Model configuration; ``hidden_size`` and ``vocab_size``
                size the LM head.
        """
        # Bypass parent __init__ to wire a custom backbone
        Qwen3PreTrainedModel.__init__(self, config)
        self.model = Qwen3BiModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def generate(self, *args, **kwargs):  # type: ignore[override]
        """Disabled: bidirectional backbone is not autoregressive.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "generate() is disabled: this backbone is bidirectional and not autoregressive."
        )

    # -------- PEFT helpers --------
    def get_model_for_peft(self):
        """Return the backbone module that a PEFT wrapper should be applied to."""
        return self.model

    def set_model_for_peft(self, model: PeftModel):  # type: ignore[override]
        """Replace the backbone with ``model`` (typically the PEFT-wrapped backbone)."""
        self.model = model

    def save_peft_model(self, path: str):
        """Save the PEFT-wrapped backbone via its ``save_pretrained``.

        Args:
            path: Destination passed to ``self.model.save_pretrained``.

        Raises:
            ValueError: If the backbone is not a ``PeftModel`` (including the
                case where ``peft`` is not installed).
        """
        # Without peft, PeftModel is typing.Any, which isinstance() rejects
        # with TypeError; check availability first so callers get ValueError.
        if not _PEFT_AVAILABLE or not isinstance(self.model, PeftModel):
            raise ValueError("Backbone is not a PEFT model; nothing to save.")
        self.model.save_pretrained(path)