"""Dynamic CLIP model loading and inference utilities.""" from typing import Optional import asyncio import torch from PIL import Image from open_clip import create_model_and_transforms, get_tokenizer # Try to import MobileCLIP reparameterize - it's optional try: from mobileclip.modules.common.mobileone import reparameterize_model MOBILECLIP_AVAILABLE = True except ImportError: MOBILECLIP_AVAILABLE = False reparameterize_model = None # --- Configuration --- MODEL_NAME = "MobileCLIP2-S2" PRETRAINED = "dfndr2b" IMAGE_SIZE = 256 EMBEDDING_DIM = 512 MAX_IMAGE_SIZE_MB = 10 MAX_BATCH_SIZE = 10 ALLOWED_EXTENSIONS = {"jpg", "jpeg", "png", "webp"} # --- Global Model State --- model = None preprocess = None tokenizer = None device = None # Create model lock - will be properly initialized in get_model_lock try: model_lock = asyncio.Lock() except RuntimeError: # If called outside of event loop, create a placeholder model_lock = None # Track loading status and current model metadata loading_status = { "is_loading": False, "progress": "idle", "error": None } # Store current model metadata (separate from constants) model_metadata = { "model_name": "MobileCLIP2-S2", "pretrained": "dfndr2b", "embedding_dim": 512, "image_size": 256, "device": "cpu" } def get_model_lock(): """Get or create the model lock.""" global model_lock if model_lock is None: model_lock = asyncio.Lock() return model_lock def is_mobileclip_model(model_name: str) -> bool: """Check if model name is a MobileCLIP variant.""" return model_name.lower().startswith("mobileclip") def load_model(model_name: str = "MobileCLIP2-S2", pretrained: str = "dfndr2b") -> dict: """ Load any OpenCLIP model dynamically. Args: model_name: Model name (e.g., 'MobileCLIP2-S2', 'ViT-B-32') pretrained: Pretrained weights (e.g., 'dfndr2b', 'openai') Returns: dict with status, model_name, embedding_dim, message """ global model, preprocess, tokenizer, device, loading_status, model_metadata # Update loading status loading_status["is_loading"] = True loading_status["progress"] = f"Loading {model_name} ({pretrained})..." loading_status["error"] = None # Keep reference to old model for rollback old_model = model old_preprocess = preprocess old_tokenizer = tokenizer old_metadata = model_metadata.copy() try: # Determine device if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Loading model: {model_name} (pretrained: {pretrained})") print(f"Device: {device}") # Load model and preprocessing transform loading_status["progress"] = "Creating model and transforms..." new_model, _, new_preprocess = create_model_and_transforms( model_name, pretrained=pretrained ) new_model = new_model.to(device) new_model.eval() # Reparameterize if it's a MobileCLIP model if is_mobileclip_model(model_name): if not MOBILECLIP_AVAILABLE: raise RuntimeError( "MobileCLIP model requested but mobileclip package not available. " "Install it with: pip install git+https://github.com/apple/ml-mobileclip.git" ) loading_status["progress"] = "Reparameterizing MobileCLIP model..." print("Reparameterizing MobileCLIP model for inference...") new_model = reparameterize_model(new_model) # Load tokenizer loading_status["progress"] = "Loading tokenizer..." 

def is_mobileclip_model(model_name: str) -> bool:
    """Check whether a model name is a MobileCLIP variant."""
    return model_name.lower().startswith("mobileclip")


def load_model(model_name: str = MODEL_NAME, pretrained: str = PRETRAINED) -> dict:
    """
    Load any OpenCLIP model dynamically.

    Args:
        model_name: Model name (e.g., 'MobileCLIP2-S2', 'ViT-B-32')
        pretrained: Pretrained weights tag (e.g., 'dfndr2b', 'openai')

    Returns:
        dict with status, model_name, pretrained, embedding_dim, image_size,
        device, and message
    """
    global model, preprocess, tokenizer, device, loading_status, model_metadata

    # Update loading status
    loading_status["is_loading"] = True
    loading_status["progress"] = f"Loading {model_name} ({pretrained})..."
    loading_status["error"] = None

    # Keep references to the current model for rollback
    old_model = model
    old_preprocess = preprocess
    old_tokenizer = tokenizer
    old_metadata = model_metadata.copy()

    try:
        # Determine device
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print(f"Loading model: {model_name} (pretrained: {pretrained})")
        print(f"Device: {device}")

        # Load model and preprocessing transform
        loading_status["progress"] = "Creating model and transforms..."
        new_model, _, new_preprocess = create_model_and_transforms(
            model_name, pretrained=pretrained
        )
        new_model = new_model.to(device)
        new_model.eval()

        # Reparameterize if it's a MobileCLIP model
        if is_mobileclip_model(model_name):
            if not MOBILECLIP_AVAILABLE:
                raise RuntimeError(
                    "MobileCLIP model requested but the mobileclip package is "
                    "not available. Install it with: "
                    "pip install git+https://github.com/apple/ml-mobileclip.git"
                )
            loading_status["progress"] = "Reparameterizing MobileCLIP model..."
            print("Reparameterizing MobileCLIP model for inference...")
            new_model = reparameterize_model(new_model)

        # Load tokenizer
        loading_status["progress"] = "Loading tokenizer..."
        new_tokenizer = get_tokenizer(model_name)

        # Get the embedding dimension from the model. text_projection may be a
        # Parameter (which has a .shape) or a Linear module (which does not),
        # so guard the lookup and fall through to visual.output_dim if needed.
        if hasattr(new_model, "text_projection") and hasattr(
            new_model.text_projection, "shape"
        ):
            new_embedding_dim = new_model.text_projection.shape[1]
        elif hasattr(new_model, "visual") and hasattr(new_model.visual, "output_dim"):
            new_embedding_dim = new_model.visual.output_dim
        else:
            new_embedding_dim = EMBEDDING_DIM  # Default fallback

        # Get the input image size from the preprocess transform
        if hasattr(new_preprocess.transforms[0], "size"):
            size = new_preprocess.transforms[0].size
            new_image_size = size if isinstance(size, int) else size[0]
        else:
            new_image_size = 224  # Default

        # Update global state first: assigning the new model drops the last
        # global reference to the old weights, so empty_cache() below can
        # actually reclaim their CUDA memory.
        model = new_model
        preprocess = new_preprocess
        tokenizer = new_tokenizer

        if old_model is not None:
            loading_status["progress"] = "Clearing old model from memory..."
            old_model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Update model metadata
        model_metadata["model_name"] = model_name
        model_metadata["pretrained"] = pretrained
        model_metadata["embedding_dim"] = new_embedding_dim
        model_metadata["image_size"] = new_image_size
        model_metadata["device"] = str(device)

        loading_status["is_loading"] = False
        loading_status["progress"] = "complete"
        loading_status["error"] = None

        print(f"✓ Model loaded: {model_metadata['model_name']}")
        print(f"✓ Pretrained: {model_metadata['pretrained']}")
        print(f"✓ Device: {model_metadata['device']}")
        print(f"✓ Embedding dim: {model_metadata['embedding_dim']}")
        print(f"✓ Image size: {model_metadata['image_size']}")

        return {
            "status": "success",
            "model_name": model_metadata["model_name"],
            "pretrained": model_metadata["pretrained"],
            "embedding_dim": model_metadata["embedding_dim"],
            "image_size": model_metadata["image_size"],
            "device": model_metadata["device"],
            "message": f"Successfully loaded {model_metadata['model_name']}",
        }

    except Exception as e:
        # Roll back to the previous model on error
        print(f"✗ Failed to load model: {e}")
        print("Rolling back to previous model...")
        model = old_model
        preprocess = old_preprocess
        tokenizer = old_tokenizer
        model_metadata.update(old_metadata)

        loading_status["is_loading"] = False
        loading_status["progress"] = "error"
        loading_status["error"] = str(e)

        if old_model is None:
            # The first load failed - there is nothing to fall back to
            raise

        return {
            "status": "error",
            "model_name": model_metadata["model_name"],
            "pretrained": model_metadata["pretrained"],
            "embedding_dim": model_metadata["embedding_dim"],
            "image_size": model_metadata["image_size"],
            "device": model_metadata["device"],
            "message": f"Failed to load {model_name}: {e}. Kept previous model.",
            "error": str(e),
        }


def get_model_status() -> dict:
    """Get the current model loading status."""
    return {
        "is_loading": loading_status["is_loading"],
        "progress": loading_status["progress"],
        "error": loading_status["error"],
        "current_model": model_metadata["model_name"],
        "pretrained": model_metadata["pretrained"],
        "embedding_dim": model_metadata["embedding_dim"],
        "image_size": model_metadata["image_size"],
        "device": model_metadata["device"],
    }


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """
    Preprocess an image for the CLIP model.

    Args:
        image: PIL Image

    Returns:
        Preprocessed torch tensor (with batch dimension) on the model's device
    """
    if preprocess is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")
    image = image.convert("RGB")
    img_tensor = preprocess(image).unsqueeze(0)  # Add batch dimension
    return img_tensor.to(device)


def normalize_embedding(embedding: torch.Tensor) -> torch.Tensor:
    """L2-normalize an embedding along its last dimension."""
    return torch.nn.functional.normalize(embedding, p=2, dim=-1)
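

# --- Usage sketch ---
# A minimal smoke test (illustrative, not part of the module API) showing how
# the utilities above compose for zero-shot inference. It assumes the
# pretrained weights can be downloaded (or are cached locally) and that a
# sample image exists at the hypothetical path "sample.jpg".
if __name__ == "__main__":
    load_model()  # Default MobileCLIP2-S2 / dfndr2b checkpoint

    img_tensor = preprocess_image(Image.open("sample.jpg"))
    text_tokens = tokenizer(["a photo of a cat", "a photo of a dog"]).to(device)

    with torch.no_grad():
        image_emb = normalize_embedding(model.encode_image(img_tensor))
        text_emb = normalize_embedding(model.encode_text(text_tokens))

    # With L2-normalized embeddings, cosine similarity is just a dot product
    similarity = image_emb @ text_emb.T
    print(f"Similarities: {similarity.squeeze(0).tolist()}")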