from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers import PreTrainedModel, Qwen2_5_VLForConditionalGeneration
from transformers.generation.utils import GenerateOutput
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast

from .configuration_qqmm import QQMMConfig


def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, does nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        causal_mask = torch.full(
            (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
        )
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

    return causal_mask
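

# Illustrative sketch (not part of the model): how the helper above turns a 2D padding mask
# into the inverted 4D additive mask consumed by the attention layers. The function name
# `_demo_build_causal_mask` and the toy shapes are introduced here for illustration only.
def _demo_build_causal_mask():
    batch_size, seq_len = 2, 5
    dtype = torch.float32
    # Second sample has two padded positions on the left (0 = padding, 1 = real token).
    attention_mask_2d = torch.tensor([[1, 1, 1, 1, 1],
                                      [0, 0, 1, 1, 1]])
    mask_4d = _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask_2d,
        sequence_length=seq_len,
        target_length=seq_len,
        dtype=dtype,
        device=attention_mask_2d.device,
        min_dtype=torch.finfo(dtype).min,
        cache_position=torch.arange(seq_len),
        batch_size=batch_size,
    )
    # Shape (batch_size, 1, seq_len, seq_len): allowed positions are 0, while future and padded
    # positions hold the most negative value representable in `dtype`.
    return mask_4d  # mask_4d.shape == torch.Size([2, 1, 5, 5])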


def padcat_sequences(sequences, value=0, pad_side='right'):
    """Pad a list of `(batch, length)` tensors to the longest length along dim 1, then concatenate along dim 0."""
    if all(s is None for s in sequences):
        return None
    max_l = max(s.size(1) for s in sequences)
    sequences_ = []
    for seq in sequences:
        if seq.size(1) != max_l:
            pad_len = max_l - seq.size(1)
            pad_len = (0, pad_len) if pad_side == 'right' else (pad_len, 0)
            seq = F.pad(seq, pad_len, value=value)
        sequences_.append(seq)
    sequences = torch.cat(sequences_)
    return sequences


class QQMMPreTrainedModel(PreTrainedModel):
    config_class = QQMMConfig
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
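

# Illustrative sketch (not part of the model): padcat_sequences aligns variable-length batches
# before stacking them. `_demo_padcat_sequences` is a name introduced here for illustration only.
def _demo_padcat_sequences():
    a = torch.tensor([[1, 2, 3]])        # shape (1, 3)
    b = torch.tensor([[4, 5, 6, 7, 8]])  # shape (1, 5)
    merged = padcat_sequences([a, b], value=0, pad_side='right')
    # `a` is right-padded with zeros to length 5, then the rows are concatenated:
    # tensor([[1, 2, 3, 0, 0],
    #         [4, 5, 6, 7, 8]])
    return merged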


class QQMMForCausalLM(QQMMPreTrainedModel):
    def __init__(self, config, qwen2_5_vl_model=None):
        super().__init__(config)
        if qwen2_5_vl_model is None:
            # Propagate the requested attention implementation (e.g. "sdpa", "flash_attention_2")
            # to the wrapped Qwen2.5-VL config before instantiating the inner model.
            if config._attn_implementation_internal is not None:
                config.model_config._attn_implementation = config._attn_implementation_internal
            model = Qwen2_5_VLForConditionalGeneration(config.model_config)
        else:
            model = qwen2_5_vl_model
        self.qwen2_5_vl_model = model
        self.post_init()

    def make_diy_mask(self, input_ids, attention_mask, embed_token_id, im_start_id, im_end_id):
        """Build a 4D additive attention mask in which every position after the embed token is blocked
        from attending to the second `<|im_start|> ... <|im_end|>` span (the user turn)."""
        if len(attention_mask.shape) == 2:
            # Expand the 2D padding mask into the inverted 4D causal mask first.
            sequence_length = attention_mask.shape[1]
            target_length = attention_mask.shape[1]
            dtype = torch.bfloat16
            device = input_ids.device
            min_dtype = torch.finfo(dtype).min
            cache_position = torch.arange(0, sequence_length, device=attention_mask.device)
            attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=target_length,
                dtype=dtype,
                device=device,
                min_dtype=min_dtype,
                cache_position=cache_position,
                batch_size=attention_mask.shape[0],
            )
        else:
            dtype = torch.bfloat16
            min_dtype = torch.finfo(dtype).min

        # First occurrence of the embed token per sample; 0 (not found) is treated as "beyond the sequence".
        mask = input_ids == embed_token_id
        embed_index = torch.argmax(mask.float(), dim=1)
        embed_index[embed_index == 0] = input_ids.shape[1]
        embed_index = embed_index.view(-1)

        # Second occurrence of <|im_start|>: blank out the first hit so argmax finds the next one.
        mask = input_ids == im_start_id
        im_start_index_tmp = torch.argmax(mask.float(), dim=1).view(-1, 1)
        mask = torch.scatter(mask, dim=1, index=im_start_index_tmp, value=False)
        im_start_index = torch.argmax(mask.float(), dim=1).view(-1)

        # Second occurrence of <|im_end|>, found the same way.
        mask = input_ids == im_end_id
        im_end_index_tmp = torch.argmax(mask.float(), dim=1).view(-1, 1)
        mask = torch.scatter(mask, dim=1, index=im_end_index_tmp, value=False)
        im_end_index = torch.argmax(mask.float(), dim=1).view(-1)

        for b in range(attention_mask.shape[0]):
            # <|im_start|>user\nxxxxx<|im_end|>\n
            attention_mask[b, 0, embed_index[b] + 1:, im_start_index[b]:im_end_index[b] + 2] = min_dtype

        return attention_mask

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        embed_token_id: Optional[int] = None,
        return_emb: Optional[bool] = False,
        cal_loss: Optional[bool] = False,
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
        if pixel_values is not None and pixel_values.shape[0] == 0:
            pixel_values = None
            image_grid_thw = None

        output_attentions = output_attentions if output_attentions is not None else self.qwen2_5_vl_model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.qwen2_5_vl_model.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.qwen2_5_vl_model.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.qwen2_5_vl_model.model.embed_tokens(input_ids)
            if pixel_values is not None:
                # Encode the images and splice the visual features into the placeholder image tokens.
                pixel_values = pixel_values.type(self.qwen2_5_vl_model.visual.dtype)
                image_embeds = self.qwen2_5_vl_model.visual(pixel_values, grid_thw=image_grid_thw)
                n_image_tokens = (input_ids == self.qwen2_5_vl_model.config.image_token_id).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                mask = input_ids == self.qwen2_5_vl_model.config.image_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                image_mask = mask_expanded.to(inputs_embeds.device)
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                # Same splicing for video features and the placeholder video tokens.
                pixel_values_videos = pixel_values_videos.type(self.qwen2_5_vl_model.visual.dtype)
                video_embeds = self.qwen2_5_vl_model.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.qwen2_5_vl_model.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )
                mask = input_ids == self.qwen2_5_vl_model.config.video_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                video_mask = mask_expanded.to(inputs_embeds.device)
                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)
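
        # Note on M-RoPE (added comment): Qwen2.5-VL uses multimodal rotary position ids with a
        # (3, batch, seq_len) layout covering the temporal, height and width axes. The branch below
        # computes them from the image/video grids once per generation (prefill) and caches
        # `rope_deltas`; during incremental decoding, plain arange positions are shifted by the
        # cached deltas instead of recomputing the full index.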
        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.qwen2_5_vl_model.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.qwen2_5_vl_model.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.qwen2_5_vl_model.rope_deltas = rope_deltas
            # then use the previously calculated rope deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.qwen2_5_vl_model.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `delta` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.qwen2_5_vl_model.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]

        if labels is not None:
            # Positions holding the special embed token are excluded from the LM loss.
            mask = labels == embed_token_id
            labels[mask] = -100

        logits = self.qwen2_5_vl_model.lm_head(hidden_states)

        if return_emb:
            assert labels is not None, 'labels must be provided to obtain embed'
            # The embedding is the hidden state of the position right before the first embed token
            # (falling back to the last position when no embed token is present).
            hidden_index = torch.argmax(mask.float(), dim=1)
            hidden_index[hidden_index == 0] = labels.shape[1]
            hidden_states = torch.gather(
                hidden_states,
                dim=1,
                index=(hidden_index - 1).view(hidden_index.shape[0], 1, 1).repeat(1, 1, hidden_states.shape[-1]),
            )
            emb = hidden_states[:, 0, :].contiguous()  # B, C
        else:
            emb = None

        loss = None
        if labels is not None and cal_loss:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            if (shift_labels < 0).all().item():
                loss = 0.0
            else:
                # Flatten the tokens
                loss_fct = CrossEntropyLoss()
                shift_logits = shift_logits.view(-1, self.qwen2_5_vl_model.config.vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        else:
            outputs = Qwen2_5_VLCausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                rope_deltas=self.qwen2_5_vl_model.rope_deltas,
            )
            if emb is not None:
                outputs['emb'] = emb
            return outputs

    @torch.no_grad()
    def generate(self, input_ids, *args, **kwargs) -> Union[GenerateOutput, torch.LongTensor]:
        return self.qwen2_5_vl_model.generate(input_ids, *args, **kwargs)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        super().gradient_checkpointing_enable(gradient_checkpointing_kwargs)
        self.qwen2_5_vl_model.model.enable_input_require_grads()
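
    # The accessors below delegate to the wrapped Qwen2.5-VL model so that generic Hugging Face
    # utilities that rely on these hooks (e.g. `resize_token_embeddings`, weight tying, decoder
    # swapping) operate on the inner model's embedding, LM-head and decoder modules.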
    def get_input_embeddings(self):
        return self.qwen2_5_vl_model.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.qwen2_5_vl_model.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.qwen2_5_vl_model.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.qwen2_5_vl_model.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.qwen2_5_vl_model.model = decoder

    def get_decoder(self):
        return self.qwen2_5_vl_model.model
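

# Illustrative sketch (not part of the model): how `forward` treats labels containing a special
# embed token. Positions holding that token are ignored by the loss (set to -100), and the hidden
# state right before the first embed token is what `return_emb=True` extracts. The function name
# `_demo_embed_token_handling` and the token id 151665 are placeholders introduced here for
# illustration only.
def _demo_embed_token_handling():
    embed_token_id = 151665  # placeholder id for the special embed token
    batch, seq_len, hidden = 2, 6, 4
    labels = torch.tensor([
        [11, 12, embed_token_id, 13, 14, 15],
        [21, 22, 23, embed_token_id, 24, 25],
    ])
    hidden_states = torch.arange(batch * seq_len * hidden, dtype=torch.float32).view(batch, seq_len, hidden)

    # (a) Embed-token positions are excluded from the loss via the -100 ignore index.
    mask = labels == embed_token_id
    labels = labels.masked_fill(mask, -100)

    # (b) The embedding is the hidden state right before the first embed token.
    hidden_index = torch.argmax(mask.float(), dim=1)   # first embed-token position per row
    hidden_index[hidden_index == 0] = labels.shape[1]  # no embed token -> fall back to the last position
    gather_index = (hidden_index - 1).view(batch, 1, 1).repeat(1, 1, hidden)
    emb = torch.gather(hidden_states, dim=1, index=gather_index)[:, 0, :]  # shape (batch, hidden)
    return labels, emb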