"""FastAPI service for dynamic CLIP embeddings.""" import time from io import BytesIO from typing import List, Optional import numpy as np import torch from fastapi import FastAPI, File, UploadFile, HTTPException, status from fastapi.responses import JSONResponse from PIL import Image from pydantic import BaseModel, Field import model as model_module from model import ( MODEL_NAME, IMAGE_SIZE, EMBEDDING_DIM, MAX_IMAGE_SIZE_MB, MAX_BATCH_SIZE, ALLOWED_EXTENSIONS, load_model, preprocess_image, normalize_embedding, get_model_lock, get_model_status ) # --- Pydantic Models --- class EmbeddingResponse(BaseModel): """Single embedding response.""" embedding: List[float] model: str inference_time_ms: float class BatchEmbeddingResponse(BaseModel): """Batch embedding response.""" embeddings: List[List[float]] count: int total_time_ms: float model: str class TextEmbeddingRequest(BaseModel): """Text embedding request.""" text: str = Field(..., min_length=1, max_length=1000) class TextEmbeddingResponse(BaseModel): """Text embedding response.""" embedding: List[float] model: str inference_time_ms: float text: str class HealthResponse(BaseModel): """Health check response.""" status: str model: str device: str backend: str class InfoResponse(BaseModel): """Model info response.""" model: str embedding_dim: int backend: str max_image_size_mb: int max_batch_size: int image_size: int class ModelUpdateRequest(BaseModel): """Model update request.""" model_name: str = Field(..., min_length=1, description="OpenCLIP model name (e.g., 'ViT-B-32', 'MobileCLIP2-S2')") pretrained: str = Field(..., min_length=1, description="Pretrained weights (e.g., 'openai', 'dfndr2b')") class ModelUpdateResponse(BaseModel): """Model update response.""" status: str model_name: str pretrained: str embedding_dim: int image_size: int device: str message: str error: Optional[str] = None class ModelStatusResponse(BaseModel): """Model status response.""" is_loading: bool progress: str error: Optional[str] = None current_model: str pretrained: str embedding_dim: int image_size: int device: str # --- FastAPI App --- app = FastAPI( title="Dynamic CLIP Embedder", description="PyTorch-based image embedding service with dynamic model loading", version="3.0.0" ) @app.on_event("startup") async def startup_event(): """Initialize model on startup.""" load_model() # --- Model Management Endpoints --- @app.post("/model/update", response_model=ModelUpdateResponse) async def update_model(request: ModelUpdateRequest): """ Update the active model. This endpoint allows switching between different OpenCLIP models at runtime. The operation is thread-safe and will reject inference requests during the swap. 
# --- Model Management Endpoints ---

@app.post("/model/update", response_model=ModelUpdateResponse)
async def update_model(request: ModelUpdateRequest):
    """
    Update the active model.

    This endpoint allows switching between different OpenCLIP models at runtime.
    The operation is lock-protected: inference requests wait on the model lock
    while the swap runs, and concurrent update requests are rejected.

    Args:
        request: Model name and pretrained weights

    Returns:
        Status of the model update operation

    Raises:
        503: Model is currently being updated
        500: Model update failed

    Examples:
        - MobileCLIP2-S2 with dfndr2b weights
        - ViT-B-32 with openai weights
        - ViT-L-14 with laion2b_s32b_b82k weights
    """
    # Check if already loading
    status_obj = get_model_status()
    if status_obj["is_loading"]:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model update already in progress"
        )

    # Acquire lock to prevent inference during model swap
    async with get_model_lock():
        try:
            result = load_model(request.model_name, request.pretrained)
            return ModelUpdateResponse(
                status=result["status"],
                model_name=result["model_name"],
                pretrained=result["pretrained"],
                embedding_dim=result["embedding_dim"],
                image_size=result["image_size"],
                device=result["device"],
                message=result["message"],
                error=result.get("error")
            )
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to update model: {str(e)}"
            )


@app.get("/model/status", response_model=ModelStatusResponse)
async def get_model_loading_status():
    """
    Get current model loading status.

    Returns:
        Current model status including loading progress
    """
    try:
        # Use a distinct name so the fastapi `status` module is not shadowed
        status_obj = get_model_status()
        return status_obj
    except Exception as e:
        import traceback
        print(f"Error in /model/status endpoint: {e}")
        traceback.print_exc()
        return {
            "is_loading": False,
            "progress": "error",
            "error": f"Failed to get model status: {str(e)}",
            "current_model": "unknown",
            "pretrained": "unknown",
            "embedding_dim": 0,
            "image_size": 0,
            "device": "unknown"
        }

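# Example: switching models at runtime (illustrative sketch only; assumes the
# service is reachable at http://localhost:7860, the port used in __main__,
# and that the `requests` package is installed on the client side):
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:7860/model/update",
#         json={"model_name": "ViT-B-32", "pretrained": "openai"},
#     )
#     print(resp.json()["status"], resp.json()["embedding_dim"])
#
#     # Poll loading progress
#     print(requests.get("http://localhost:7860/model/status").json())
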
# --- Validation ---

def validate_image_file(file: UploadFile) -> None:
    """Validate uploaded file."""
    # Check extension
    if file.filename:
        ext = file.filename.split(".")[-1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid file type. Allowed: {ALLOWED_EXTENSIONS}"
            )

    # Check size (if available)
    if hasattr(file, "size") and file.size:
        max_bytes = MAX_IMAGE_SIZE_MB * 1024 * 1024
        if file.size > max_bytes:
            raise HTTPException(
                status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                detail=f"File too large. Max: {MAX_IMAGE_SIZE_MB}MB"
            )


async def load_image_from_upload(file: UploadFile) -> Image.Image:
    """Load PIL Image from upload."""
    try:
        contents = await file.read()
        image = Image.open(BytesIO(contents))
        return image
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid image file: {str(e)}"
        )


# --- Endpoints ---

@app.get("/", response_model=HealthResponse)
async def health_check():
    """Health check endpoint."""
    return HealthResponse(
        status="healthy",
        model=MODEL_NAME,
        device=str(model_module.device),
        backend="pytorch"
    )


@app.get("/info", response_model=InfoResponse)
async def get_info():
    """Get model information."""
    # Note: these values are bound at import time and may not reflect a model
    # loaded later via /model/update.
    return InfoResponse(
        model=MODEL_NAME,
        embedding_dim=EMBEDDING_DIM,
        backend="pytorch",
        max_image_size_mb=MAX_IMAGE_SIZE_MB,
        max_batch_size=MAX_BATCH_SIZE,
        image_size=IMAGE_SIZE
    )


@app.post("/embed", response_model=EmbeddingResponse)
async def generate_embedding(file: UploadFile = File(...)):
    """
    Generate embedding for single image.

    Args:
        file: Image file (JPEG, PNG, WebP)

    Returns:
        N-dimensional embedding vector (dimension depends on model)

    Raises:
        400: Invalid file format
        413: File too large
        500: Inference error
        503: Model is being updated
    """
    start_time = time.time()

    # Validate
    validate_image_file(file)

    try:
        # Acquire lock to ensure model isn't being swapped
        async with get_model_lock():
            if model_module.model is None:
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="Model not loaded"
                )

            # Load and preprocess
            image = await load_image_from_upload(file)
            img_tensor = preprocess_image(image)

            # Run inference
            with torch.no_grad():
                embedding = model_module.model.encode_image(img_tensor)
                embedding = normalize_embedding(embedding)

            # Convert to numpy and then to list
            embedding = embedding.cpu().numpy()[0]

        # Calculate time
        inference_time = (time.time() - start_time) * 1000

        return EmbeddingResponse(
            embedding=embedding.tolist(),
            model=model_module.model_metadata["model_name"],
            inference_time_ms=round(inference_time, 2)
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Inference failed: {str(e)}"
        )


@app.post("/embed/batch", response_model=BatchEmbeddingResponse)
async def generate_batch_embeddings(files: List[UploadFile] = File(...)):
    """
    Generate embeddings for multiple images.

    Args:
        files: List of image files (max 10)

    Returns:
        List of N-dimensional embeddings (dimension depends on model)

    Raises:
        400: Invalid files or too many files
        500: Inference error
        503: Model is being updated
    """
    start_time = time.time()

    # Validate batch size
    if len(files) > MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Too many files. Max batch size: {MAX_BATCH_SIZE}"
        )

    embeddings = []

    try:
        # Acquire lock to ensure model isn't being swapped
        async with get_model_lock():
            if model_module.model is None:
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="Model not loaded"
                )

            # Process each image
            img_tensors = []
            for file in files:
                validate_image_file(file)
                image = await load_image_from_upload(file)
                img_tensor = preprocess_image(image)
                img_tensors.append(img_tensor)

            # Batch inference
            batch_tensor = torch.cat(img_tensors, dim=0)
            with torch.no_grad():
                batch_embeddings = model_module.model.encode_image(batch_tensor)
                batch_embeddings = normalize_embedding(batch_embeddings)

            # Convert to list
            batch_embeddings = batch_embeddings.cpu().numpy()
            embeddings = [emb.tolist() for emb in batch_embeddings]

        # Calculate time
        total_time = (time.time() - start_time) * 1000

        return BatchEmbeddingResponse(
            embeddings=embeddings,
            count=len(embeddings),
            total_time_ms=round(total_time, 2),
            model=model_module.model_metadata["model_name"]
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Batch inference failed: {str(e)}"
        )

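# Example: requesting image embeddings (illustrative sketch only; assumes a
# local deployment on port 7860 and the `requests` package on the client side;
# the file names are placeholders):
#
#     import requests
#
#     # Single image: the multipart field name must be "file"
#     with open("photo.jpg", "rb") as fh:
#         resp = requests.post("http://localhost:7860/embed", files={"file": fh})
#     vector = resp.json()["embedding"]
#
#     # Batch: repeat the "files" field, up to MAX_BATCH_SIZE images
#     with open("a.jpg", "rb") as fa, open("b.jpg", "rb") as fb:
#         resp = requests.post(
#             "http://localhost:7860/embed/batch",
#             files=[("files", fa), ("files", fb)],
#         )
#     vectors = resp.json()["embeddings"]
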
@app.post("/embed/text", response_model=TextEmbeddingResponse)
async def generate_text_embedding(request: TextEmbeddingRequest):
    """
    Generate embedding for text query.

    Args:
        request: Text to embed

    Returns:
        N-dimensional embedding for the text (dimension depends on model)

    Raises:
        500: Inference error
        503: Model is being updated
    """
    start_time = time.time()

    try:
        # Acquire lock to ensure model isn't being swapped
        async with get_model_lock():
            if model_module.model is None or model_module.tokenizer is None:
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="Model not loaded"
                )

            # Tokenize text
            text_tokens = model_module.tokenizer([request.text])
            text_tokens = text_tokens.to(model_module.device)

            # Run inference
            with torch.no_grad():
                text_embedding = model_module.model.encode_text(text_tokens)
                text_embedding = normalize_embedding(text_embedding)

            # Convert to numpy and then to list
            embedding = text_embedding.cpu().numpy()[0]

        # Calculate time
        inference_time = (time.time() - start_time) * 1000

        return TextEmbeddingResponse(
            embedding=embedding.tolist(),
            model=model_module.model_metadata["model_name"],
            inference_time_ms=round(inference_time, 2),
            text=request.text
        )

    except HTTPException:
        # Let the 503 above propagate instead of being wrapped as a 500
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Text inference failed: {str(e)}"
        )


# --- Main ---

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,  # HF Spaces default port
        log_level="info"
    )

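# Example: text embedding and image/text similarity (illustrative sketch only;
# assumes a local deployment on port 7860 and that normalize_embedding
# L2-normalizes its input, as the name suggests, in which case cosine
# similarity reduces to a plain dot product):
#
#     import requests
#
#     text_vec = requests.post(
#         "http://localhost:7860/embed/text",
#         json={"text": "a photo of a dog"},
#     ).json()["embedding"]
#
#     with open("photo.jpg", "rb") as fh:
#         img_vec = requests.post(
#             "http://localhost:7860/embed", files={"file": fh}
#         ).json()["embedding"]
#
#     similarity = sum(t * i for t, i in zip(text_vec, img_vec))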