"""Dynamic CLIP model loading and inference utilities."""
from typing import Optional
import asyncio
import torch
from PIL import Image
from open_clip import create_model_and_transforms, get_tokenizer
# Try to import MobileCLIP reparameterize - it's optional
try:
from mobileclip.modules.common.mobileone import reparameterize_model
MOBILECLIP_AVAILABLE = True
except ImportError:
MOBILECLIP_AVAILABLE = False
reparameterize_model = None
# --- Configuration ---
MODEL_NAME = "MobileCLIP2-S2"
PRETRAINED = "dfndr2b"
IMAGE_SIZE = 256
EMBEDDING_DIM = 512
MAX_IMAGE_SIZE_MB = 10
MAX_BATCH_SIZE = 10
ALLOWED_EXTENSIONS = {"jpg", "jpeg", "png", "webp"}
# --- Global Model State ---
model = None
preprocess = None
tokenizer = None
device = None
# Lock serializing model swaps; created lazily in get_model_lock() so it
# binds to the running event loop rather than whichever loop (if any)
# exists at import time
model_lock: Optional[asyncio.Lock] = None
# Track loading status and current model metadata
loading_status = {
"is_loading": False,
"progress": "idle",
"error": None
}
# Current model metadata; initialized from the startup constants and
# updated on each successful load_model() call
model_metadata = {
    "model_name": MODEL_NAME,
    "pretrained": PRETRAINED,
    "embedding_dim": EMBEDDING_DIM,
    "image_size": IMAGE_SIZE,
    "device": "cpu"
}
def get_model_lock():
"""Get or create the model lock."""
global model_lock
if model_lock is None:
model_lock = asyncio.Lock()
return model_lock
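# Intended usage (a sketch; serializing swaps from async handlers is an
# assumption about the caller, not enforced here):
#   async with get_model_lock():
#       result = load_model("ViT-B-32", pretrained="openai")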
def is_mobileclip_model(model_name: str) -> bool:
"""Check if model name is a MobileCLIP variant."""
return model_name.lower().startswith("mobileclip")
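# e.g. is_mobileclip_model("MobileCLIP2-S2") -> True (case-insensitive);
#      is_mobileclip_model("ViT-B-32") -> False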
def load_model(model_name: str = MODEL_NAME, pretrained: str = PRETRAINED) -> dict:
"""
Load any OpenCLIP model dynamically.
Args:
model_name: Model name (e.g., 'MobileCLIP2-S2', 'ViT-B-32')
pretrained: Pretrained weights (e.g., 'dfndr2b', 'openai')
Returns:
dict with status, model_name, embedding_dim, message
"""
global model, preprocess, tokenizer, device, loading_status, model_metadata
# Update loading status
loading_status["is_loading"] = True
loading_status["progress"] = f"Loading {model_name} ({pretrained})..."
loading_status["error"] = None
# Keep reference to old model for rollback
old_model = model
old_preprocess = preprocess
old_tokenizer = tokenizer
old_metadata = model_metadata.copy()
try:
# Determine device
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Loading model: {model_name} (pretrained: {pretrained})")
print(f"Device: {device}")
# Load model and preprocessing transform
loading_status["progress"] = "Creating model and transforms..."
new_model, _, new_preprocess = create_model_and_transforms(
model_name,
pretrained=pretrained
)
new_model = new_model.to(device)
new_model.eval()
# Reparameterize if it's a MobileCLIP model
if is_mobileclip_model(model_name):
if not MOBILECLIP_AVAILABLE:
raise RuntimeError(
"MobileCLIP model requested but mobileclip package not available. "
"Install it with: pip install git+https://github.com/apple/ml-mobileclip.git"
)
loading_status["progress"] = "Reparameterizing MobileCLIP model..."
print("Reparameterizing MobileCLIP model for inference...")
new_model = reparameterize_model(new_model)
# Load tokenizer
loading_status["progress"] = "Loading tokenizer..."
new_tokenizer = get_tokenizer(model_name)
        # Get embedding dimension from the model; depending on the config,
        # text_projection is a Parameter (shape [width, dim]) or an nn.Linear
        proj = getattr(new_model, 'text_projection', None)
        if isinstance(proj, torch.nn.Linear):
            new_embedding_dim = proj.out_features
        elif proj is not None:
            new_embedding_dim = proj.shape[1]
        elif hasattr(new_model, 'visual') and hasattr(new_model.visual, 'output_dim'):
            new_embedding_dim = new_model.visual.output_dim
        else:
            # Fall back to the startup default
            new_embedding_dim = EMBEDDING_DIM
# Get image size from preprocess transform
if hasattr(new_preprocess.transforms[0], 'size'):
size = new_preprocess.transforms[0].size
new_image_size = size if isinstance(size, int) else size[0]
else:
new_image_size = 224 # Default
        # Update global state first, then drop the last reference to the old
        # model: until the globals are reassigned they still hold the old
        # weights, so empty_cache() could not reclaim them
        model = new_model
        preprocess = new_preprocess
        tokenizer = new_tokenizer
        if old_model is not None:
            loading_status["progress"] = "Clearing old model from memory..."
            old_model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
# Update model metadata
model_metadata["model_name"] = model_name
model_metadata["pretrained"] = pretrained
model_metadata["embedding_dim"] = new_embedding_dim
model_metadata["image_size"] = new_image_size
model_metadata["device"] = str(device)
loading_status["is_loading"] = False
loading_status["progress"] = "complete"
loading_status["error"] = None
print(f"βœ“ Model loaded: {model_metadata['model_name']}")
print(f"βœ“ Pretrained: {model_metadata['pretrained']}")
print(f"βœ“ Device: {model_metadata['device']}")
print(f"βœ“ Embedding dim: {model_metadata['embedding_dim']}")
print(f"βœ“ Image size: {model_metadata['image_size']}")
return {
"status": "success",
"model_name": model_metadata["model_name"],
"pretrained": model_metadata["pretrained"],
"embedding_dim": model_metadata["embedding_dim"],
"image_size": model_metadata["image_size"],
"device": model_metadata["device"],
"message": f"Successfully loaded {model_metadata['model_name']}"
}
except Exception as e:
# Rollback on error
print(f"βœ— Failed to load model: {e}")
print("Rolling back to previous model...")
model = old_model
preprocess = old_preprocess
tokenizer = old_tokenizer
model_metadata.update(old_metadata)
loading_status["is_loading"] = False
loading_status["progress"] = "error"
loading_status["error"] = str(e)
if old_model is None:
# First load failed - re-raise
raise
return {
"status": "error",
"model_name": model_metadata["model_name"],
"pretrained": model_metadata["pretrained"],
"embedding_dim": model_metadata["embedding_dim"],
"image_size": model_metadata["image_size"],
"device": model_metadata["device"],
"message": f"Failed to load {model_name}: {str(e)}. Kept previous model.",
"error": str(e)
}
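# Example (hypothetical pairing; any open_clip model/weights combo works):
# swap to a standard OpenCLIP model at runtime and inspect the result dict:
#   result = load_model("ViT-B-32", pretrained="openai")
#   print(result["status"], result["embedding_dim"], result["device"])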
def get_model_status() -> dict:
"""Get current model loading status."""
return {
"is_loading": loading_status["is_loading"],
"progress": loading_status["progress"],
"error": loading_status["error"],
"current_model": model_metadata["model_name"],
"pretrained": model_metadata["pretrained"],
"embedding_dim": model_metadata["embedding_dim"],
"image_size": model_metadata["image_size"],
"device": model_metadata["device"]
}
def preprocess_image(image: Image.Image) -> torch.Tensor:
"""
Preprocess image for CLIP model.
Args:
image: PIL Image
Returns:
Preprocessed torch tensor
"""
if preprocess is None:
raise RuntimeError("Model not loaded. Call load_model() first.")
image = image.convert("RGB")
img_tensor = preprocess(image).unsqueeze(0) # Add batch dimension
return img_tensor.to(device)
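# Batch sketch (assumption: callers enforce MAX_BATCH_SIZE upstream): each
# call yields a [1, 3, H, W] tensor on `device`, so a batch is a simple concat:
#   batch = torch.cat([preprocess_image(img) for img in images])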
def normalize_embedding(embedding: torch.Tensor) -> torch.Tensor:
"""L2 normalize embedding."""
return torch.nn.functional.normalize(embedding, p=2, dim=-1)
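# Minimal end-to-end sketch (not part of the service API; "cat.jpg" and the
# caption are placeholders): load the default model, embed one image and one
# caption, and compare them with cosine similarity.
if __name__ == "__main__":
    load_model()  # blocking; downloads weights on first use
    img = Image.open("cat.jpg")
    with torch.no_grad():
        image_features = model.encode_image(preprocess_image(img))
        text_tokens = tokenizer(["a photo of a cat"]).to(device)
        text_features = model.encode_text(text_tokens)
    image_features = normalize_embedding(image_features)
    text_features = normalize_embedding(text_features)
    # For L2-normalized vectors, cosine similarity is just a dot product
    print(f"Image-text similarity: {(image_features @ text_features.T).item():.4f}")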