| """FastAPI service for MobileCLIP2-S2 embeddings (PyTorch).""" | |
| import time | |
| from io import BytesIO | |
| from typing import List | |
| import numpy as np | |
| import torch | |
| from fastapi import FastAPI, File, UploadFile, HTTPException, status | |
| from fastapi.responses import JSONResponse | |
| from PIL import Image | |
| from pydantic import BaseModel, Field | |
| from open_clip import create_model_and_transforms, get_tokenizer | |
| from mobileclip.modules.common.mobileone import reparameterize_model | |
| # --- Configuration --- | |
| MAX_IMAGE_SIZE_MB = 10 | |
| MAX_BATCH_SIZE = 10 | |
| ALLOWED_EXTENSIONS = {"jpg", "jpeg", "png", "webp"} | |
| MODEL_NAME = "MobileCLIP2-S2" | |
| PRETRAINED = "dfndr2b" | |
# --- Pydantic Models ---
class EmbeddingResponse(BaseModel):
    """Single embedding response."""
    embedding: List[float] = Field(..., min_length=512, max_length=512)
    model: str
    inference_time_ms: float


class BatchEmbeddingResponse(BaseModel):
    """Batch embedding response."""
    embeddings: List[List[float]]
    count: int
    total_time_ms: float
    model: str


class TextEmbeddingRequest(BaseModel):
    """Text embedding request."""
    text: str = Field(..., min_length=1, max_length=1000)


class TextEmbeddingResponse(BaseModel):
    """Text embedding response."""
    embedding: List[float] = Field(..., min_length=512, max_length=512)
    model: str
    inference_time_ms: float
    text: str


class HealthResponse(BaseModel):
    """Health check response."""
    status: str
    model: str
    device: str
    backend: str


class InfoResponse(BaseModel):
    """Model info response."""
    model: str
    embedding_dim: int
    backend: str
    max_image_size_mb: int
    max_batch_size: int
    image_size: int
# --- Global Model Loading ---
app = FastAPI(
    title="MobileCLIP2-S2 Embedder",
    description="PyTorch-based image embedding service",
    version="2.0.0"
)

# Model state, populated on startup
model = None
preprocess = None
device = None
IMAGE_SIZE = 256
EMBEDDING_DIM = 512
def load_model():
    """Load MobileCLIP model using PyTorch."""
    global model, preprocess, device

    # Determine device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading MobileCLIP model: {MODEL_NAME}...")
    print(f"Device: {device}")

    # Load model and preprocessing transform
    model, _, preprocess = create_model_and_transforms(
        MODEL_NAME,
        pretrained=PRETRAINED
    )
    model = model.to(device)
    model.eval()

    # Reparameterize model for inference (required for MobileCLIP)
    print("Reparameterizing model for inference...")
    model = reparameterize_model(model)

    print(f"✓ Model loaded: {MODEL_NAME}")
    print(f"✓ Pretrained: {PRETRAINED}")
    print(f"✓ Device: {device}")
    print(f"✓ Embedding dim: {EMBEDDING_DIM}")
@app.on_event("startup")
async def startup_event():
    """Initialize model on startup."""
    load_model()
# --- Preprocessing ---
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """
    Preprocess image for MobileCLIP2-S2.

    Args:
        image: PIL Image

    Returns:
        Preprocessed torch tensor
    """
    image = image.convert("RGB")
    img_tensor = preprocess(image).unsqueeze(0)  # Add batch dimension
    return img_tensor.to(device)


def normalize_embedding(embedding: torch.Tensor) -> torch.Tensor:
    """L2 normalize embedding."""
    return torch.nn.functional.normalize(embedding, p=2, dim=-1)
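
# Since every embedding returned by this service is L2-normalized, cosine
# similarity between any two of them reduces to a plain dot product. A
# minimal illustrative helper (an assumed addition; not called by the
# endpoints below):
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine similarity of two L2-normalized embeddings via dot product."""
    return float(np.dot(a, b))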
# --- Validation ---
def validate_image_file(file: UploadFile) -> None:
    """Validate uploaded file."""
    # Check extension
    if file.filename:
        ext = file.filename.split(".")[-1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid file type. Allowed: {ALLOWED_EXTENSIONS}"
            )

    # Check size (if available)
    if hasattr(file, "size") and file.size:
        max_bytes = MAX_IMAGE_SIZE_MB * 1024 * 1024
        if file.size > max_bytes:
            raise HTTPException(
                status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                detail=f"File too large. Max: {MAX_IMAGE_SIZE_MB}MB"
            )


async def load_image_from_upload(file: UploadFile) -> Image.Image:
    """Load PIL Image from upload."""
    try:
        contents = await file.read()
        image = Image.open(BytesIO(contents))
        return image
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid image file: {str(e)}"
        )
# --- Endpoints ---
# NOTE: the route paths used below are assumed, representative choices.

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint."""
    return HealthResponse(
        status="healthy",
        model=MODEL_NAME,
        device=str(device),
        backend="pytorch"
    )


@app.get("/info", response_model=InfoResponse)
async def get_info():
    """Get model information."""
    return InfoResponse(
        model=MODEL_NAME,
        embedding_dim=EMBEDDING_DIM,
        backend="pytorch",
        max_image_size_mb=MAX_IMAGE_SIZE_MB,
        max_batch_size=MAX_BATCH_SIZE,
        image_size=IMAGE_SIZE
    )
@app.post("/embed", response_model=EmbeddingResponse)
async def generate_embedding(file: UploadFile = File(...)):
    """
    Generate embedding for a single image.

    Args:
        file: Image file (JPEG, PNG, WebP)

    Returns:
        512-dimensional embedding vector

    Raises:
        400: Invalid file format
        413: File too large
        500: Inference error
    """
    start_time = time.time()

    # Validate
    validate_image_file(file)

    try:
        # Load and preprocess
        image = await load_image_from_upload(file)
        img_tensor = preprocess_image(image)

        # Run inference
        with torch.no_grad():
            embedding = model.encode_image(img_tensor)
            embedding = normalize_embedding(embedding)

        # Convert to numpy and then to list
        embedding = embedding.cpu().numpy()[0]

        # Calculate time
        inference_time = (time.time() - start_time) * 1000

        return EmbeddingResponse(
            embedding=embedding.tolist(),
            model=MODEL_NAME,
            inference_time_ms=round(inference_time, 2)
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Inference failed: {str(e)}"
        )
@app.post("/embed/batch", response_model=BatchEmbeddingResponse)
async def generate_batch_embeddings(files: List[UploadFile] = File(...)):
    """
    Generate embeddings for multiple images.

    Args:
        files: List of image files (max 10)

    Returns:
        List of 512-dimensional embeddings

    Raises:
        400: Invalid files or too many files
        500: Inference error
    """
    start_time = time.time()

    # Validate batch size
    if len(files) > MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Too many files. Max batch size: {MAX_BATCH_SIZE}"
        )

    try:
        # Validate and preprocess each image
        img_tensors = []
        for file in files:
            validate_image_file(file)
            image = await load_image_from_upload(file)
            img_tensor = preprocess_image(image)
            img_tensors.append(img_tensor)

        # Batch inference on a single stacked tensor
        batch_tensor = torch.cat(img_tensors, dim=0)
        with torch.no_grad():
            batch_embeddings = model.encode_image(batch_tensor)
            batch_embeddings = normalize_embedding(batch_embeddings)

        # Convert to list
        batch_embeddings = batch_embeddings.cpu().numpy()
        embeddings = [emb.tolist() for emb in batch_embeddings]

        # Calculate time
        total_time = (time.time() - start_time) * 1000

        return BatchEmbeddingResponse(
            embeddings=embeddings,
            count=len(embeddings),
            total_time_ms=round(total_time, 2),
            model=MODEL_NAME
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Batch inference failed: {str(e)}"
        )
@app.post("/embed/text", response_model=TextEmbeddingResponse)
async def generate_text_embedding(request: TextEmbeddingRequest):
    """
    Generate embedding for a text query.

    Args:
        request: Text to embed

    Returns:
        512-dimensional embedding for the text

    Raises:
        500: Inference error
    """
    start_time = time.time()

    try:
        # Tokenize text (the tokenizer could also be cached at startup
        # to avoid re-creating it on every request)
        tokenizer = get_tokenizer(MODEL_NAME)
        text_tokens = tokenizer([request.text])
        text_tokens = text_tokens.to(device)

        # Run inference
        with torch.no_grad():
            text_embedding = model.encode_text(text_tokens)
            text_embedding = normalize_embedding(text_embedding)

        # Convert to numpy and then to list
        embedding = text_embedding.cpu().numpy()[0]

        # Calculate time
        inference_time = (time.time() - start_time) * 1000

        return TextEmbeddingResponse(
            embedding=embedding.tolist(),
            model=MODEL_NAME,
            inference_time_ms=round(inference_time, 2),
            text=request.text
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Text inference failed: {str(e)}"
        )
# --- Main ---
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,  # HF Spaces default port
        log_level="info"
    )
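
# Example requests, assuming the (representative) route paths above and a
# server listening on port 7860; commands are illustrative:
#
#   curl http://localhost:7860/health
#   curl http://localhost:7860/info
#   curl -X POST -F "file=@image.jpg" http://localhost:7860/embed
#   curl -X POST -F "files=@a.jpg" -F "files=@b.jpg" http://localhost:7860/embed/batch
#   curl -X POST -H "Content-Type: application/json" \
#        -d '{"text": "a photo of a dog"}' http://localhost:7860/embed/text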