| """Dynamic CLIP model loading and inference utilities.""" | |
| from typing import Optional | |
| import asyncio | |
| import torch | |
| from PIL import Image | |
| from open_clip import create_model_and_transforms, get_tokenizer | |
| # Try to import MobileCLIP reparameterize - it's optional | |
| try: | |
| from mobileclip.modules.common.mobileone import reparameterize_model | |
| MOBILECLIP_AVAILABLE = True | |
| except ImportError: | |
| MOBILECLIP_AVAILABLE = False | |
| reparameterize_model = None | |
| # --- Configuration --- | |
| MODEL_NAME = "MobileCLIP2-S2" | |
| PRETRAINED = "dfndr2b" | |
| IMAGE_SIZE = 256 | |
| EMBEDDING_DIM = 512 | |
| MAX_IMAGE_SIZE_MB = 10 | |
| MAX_BATCH_SIZE = 10 | |
| ALLOWED_EXTENSIONS = {"jpg", "jpeg", "png", "webp"} | |

# --- Global Model State ---
model = None
preprocess = None
tokenizer = None
device = None

# Create model lock - will be properly initialized in get_model_lock
try:
    model_lock = asyncio.Lock()
except RuntimeError:
    # If called outside of an event loop, create a placeholder
    model_lock = None

# Track loading status and current model metadata
loading_status = {
    "is_loading": False,
    "progress": "idle",
    "error": None,
}

# Store current model metadata (separate from the constants above)
model_metadata = {
    "model_name": "MobileCLIP2-S2",
    "pretrained": "dfndr2b",
    "embedding_dim": 512,
    "image_size": 256,
    "device": "cpu",
}


def get_model_lock():
    """Get or create the model lock."""
    global model_lock
    if model_lock is None:
        model_lock = asyncio.Lock()
    return model_lock
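

# Illustrative sketch (not part of the original API): one way an async caller,
# e.g. a web request handler, might use the lock to serialize concurrent model
# reloads. The helper name and the executor usage are assumptions.
async def reload_model_safely(model_name: str, pretrained: str) -> dict:
    async with get_model_lock():
        # load_model() is blocking, so run it off the event loop thread
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, load_model, model_name, pretrained)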


def is_mobileclip_model(model_name: str) -> bool:
    """Check if model name is a MobileCLIP variant."""
    return model_name.lower().startswith("mobileclip")


def load_model(model_name: str = "MobileCLIP2-S2", pretrained: str = "dfndr2b") -> dict:
    """
    Load any OpenCLIP model dynamically.

    Args:
        model_name: Model name (e.g., 'MobileCLIP2-S2', 'ViT-B-32')
        pretrained: Pretrained weights tag (e.g., 'dfndr2b', 'openai')

    Returns:
        dict with status, model_name, pretrained, embedding_dim, image_size,
        device, and message
    """
    global model, preprocess, tokenizer, device, loading_status, model_metadata

    # Update loading status
    loading_status["is_loading"] = True
    loading_status["progress"] = f"Loading {model_name} ({pretrained})..."
    loading_status["error"] = None

    # Keep references to the old model for rollback
    old_model = model
    old_preprocess = preprocess
    old_tokenizer = tokenizer
    old_metadata = model_metadata.copy()

    try:
        # Determine device
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print(f"Loading model: {model_name} (pretrained: {pretrained})")
        print(f"Device: {device}")

        # Load model and preprocessing transform
        loading_status["progress"] = "Creating model and transforms..."
        new_model, _, new_preprocess = create_model_and_transforms(
            model_name,
            pretrained=pretrained,
        )
        new_model = new_model.to(device)
        new_model.eval()

        # Reparameterize if it's a MobileCLIP model
        if is_mobileclip_model(model_name):
            if not MOBILECLIP_AVAILABLE:
                raise RuntimeError(
                    "MobileCLIP model requested but mobileclip package not available. "
                    "Install it with: pip install git+https://github.com/apple/ml-mobileclip.git"
                )
            loading_status["progress"] = "Reparameterizing MobileCLIP model..."
            print("Reparameterizing MobileCLIP model for inference...")
            new_model = reparameterize_model(new_model)

        # Load tokenizer
        loading_status["progress"] = "Loading tokenizer..."
        new_tokenizer = get_tokenizer(model_name)

        # Get embedding dimension from the model. text_projection may be an
        # nn.Parameter of shape [width, embed_dim] or an nn.Linear, depending
        # on the architecture.
        if hasattr(new_model, 'text_projection') and new_model.text_projection is not None:
            text_proj = new_model.text_projection
            if isinstance(text_proj, torch.nn.Linear):
                new_embedding_dim = text_proj.out_features
            else:
                new_embedding_dim = text_proj.shape[1]
        elif hasattr(new_model, 'visual') and hasattr(new_model.visual, 'output_dim'):
            new_embedding_dim = new_model.visual.output_dim
        else:
            # Default fallback
            new_embedding_dim = 512

        # Get image size from the preprocess transform
        if hasattr(new_preprocess.transforms[0], 'size'):
            size = new_preprocess.transforms[0].size
            new_image_size = size if isinstance(size, int) else size[0]
        else:
            new_image_size = 224  # Default

        # Update global state
        model = new_model
        preprocess = new_preprocess
        tokenizer = new_tokenizer

        # Clear the old model from memory. This happens *after* rebinding the
        # globals so the old weights are actually released before the CUDA
        # cache is emptied (deleting old_model first would leave `model` still
        # holding a reference).
        if old_model is not None:
            loading_status["progress"] = "Clearing old model from memory..."
            del old_model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Update model metadata
        model_metadata["model_name"] = model_name
        model_metadata["pretrained"] = pretrained
        model_metadata["embedding_dim"] = new_embedding_dim
        model_metadata["image_size"] = new_image_size
        model_metadata["device"] = str(device)

        loading_status["is_loading"] = False
        loading_status["progress"] = "complete"
        loading_status["error"] = None
| print(f"β Model loaded: {model_metadata['model_name']}") | |
| print(f"β Pretrained: {model_metadata['pretrained']}") | |
| print(f"β Device: {model_metadata['device']}") | |
| print(f"β Embedding dim: {model_metadata['embedding_dim']}") | |
| print(f"β Image size: {model_metadata['image_size']}") | |
| return { | |
| "status": "success", | |
| "model_name": model_metadata["model_name"], | |
| "pretrained": model_metadata["pretrained"], | |
| "embedding_dim": model_metadata["embedding_dim"], | |
| "image_size": model_metadata["image_size"], | |
| "device": model_metadata["device"], | |
| "message": f"Successfully loaded {model_metadata['model_name']}" | |
| } | |

    except Exception as e:
        # Rollback on error
        print(f"✗ Failed to load model: {e}")
        print("Rolling back to previous model...")
        model = old_model
        preprocess = old_preprocess
        tokenizer = old_tokenizer
        model_metadata.update(old_metadata)

        loading_status["is_loading"] = False
        loading_status["progress"] = "error"
        loading_status["error"] = str(e)

        if old_model is None:
            # First load failed - re-raise
            raise

        return {
            "status": "error",
            "model_name": model_metadata["model_name"],
            "pretrained": model_metadata["pretrained"],
            "embedding_dim": model_metadata["embedding_dim"],
            "image_size": model_metadata["image_size"],
            "device": model_metadata["device"],
            "message": f"Failed to load {model_name}: {str(e)}. Kept previous model.",
            "error": str(e),
        }


def get_model_status() -> dict:
    """Get current model loading status."""
    return {
        "is_loading": loading_status["is_loading"],
        "progress": loading_status["progress"],
        "error": loading_status["error"],
        "current_model": model_metadata["model_name"],
        "pretrained": model_metadata["pretrained"],
        "embedding_dim": model_metadata["embedding_dim"],
        "image_size": model_metadata["image_size"],
        "device": model_metadata["device"],
    }


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """
    Preprocess an image for the CLIP model.

    Args:
        image: PIL Image

    Returns:
        Preprocessed torch tensor
    """
    if preprocess is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")
    image = image.convert("RGB")
    img_tensor = preprocess(image).unsqueeze(0)  # Add batch dimension
    return img_tensor.to(device)


def normalize_embedding(embedding: torch.Tensor) -> torch.Tensor:
    """L2-normalize an embedding."""
    return torch.nn.functional.normalize(embedding, p=2, dim=-1)
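

# --- Usage sketch ---
# Illustrative only: a minimal end-to-end example of loading a model and
# producing normalized image and text embeddings. The 'ViT-B-32'/'openai'
# pair, the blank test image, and the text prompt are assumptions chosen for
# the sketch; encode_image/encode_text are the standard open_clip model
# methods.
if __name__ == "__main__":
    info = load_model("ViT-B-32", "openai")
    print(info["message"])

    # Encode a dummy image
    dummy = Image.new("RGB", (info["image_size"], info["image_size"]), "white")
    with torch.no_grad():
        img_emb = normalize_embedding(model.encode_image(preprocess_image(dummy)))

    # Encode a text prompt
    tokens = tokenizer(["a photo of a cat"]).to(device)
    with torch.no_grad():
        txt_emb = normalize_embedding(model.encode_text(tokens))

    # Cosine similarity between the two normalized embeddings
    similarity = (img_emb @ txt_emb.T).item()
    print(f"Image/text cosine similarity: {similarity:.4f}")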