"""FastAPI service for MobileCLIP2-S2 embeddings (PyTorch)."""
import time
from io import BytesIO
from typing import List
import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException, status
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field
from open_clip import create_model_and_transforms, get_tokenizer
from mobileclip.modules.common.mobileone import reparameterize_model
# --- Configuration ---
MAX_IMAGE_SIZE_MB = 10
MAX_BATCH_SIZE = 10
ALLOWED_EXTENSIONS = {"jpg", "jpeg", "png", "webp"}
MODEL_NAME = "MobileCLIP2-S2"
PRETRAINED = "dfndr2b"
# --- Pydantic Models ---
class EmbeddingResponse(BaseModel):
    """Single embedding response."""
    embedding: List[float] = Field(..., min_length=512, max_length=512)
    model: str
    inference_time_ms: float
class BatchEmbeddingResponse(BaseModel):
    """Batch embedding response."""
    embeddings: List[List[float]]
    count: int
    total_time_ms: float
    model: str
class TextEmbeddingRequest(BaseModel):
    """Text embedding request."""
    text: str = Field(..., min_length=1, max_length=1000)
class TextEmbeddingResponse(BaseModel):
    """Text embedding response."""
    embedding: List[float] = Field(..., min_length=512, max_length=512)
    model: str
    inference_time_ms: float
    text: str
class HealthResponse(BaseModel):
    """Health check response."""
    status: str
    model: str
    device: str
    backend: str
class InfoResponse(BaseModel):
    """Model info response."""
    model: str
    embedding_dim: int
    backend: str
    max_image_size_mb: int
    max_batch_size: int
    image_size: int
# --- Global Model Loading ---
app = FastAPI(
    title="MobileCLIP2-S2 Embedder",
    description="PyTorch-based image embedding service",
    version="2.0.0"
)
# Load model on startup
model = None
preprocess = None
device = None
IMAGE_SIZE = 256
EMBEDDING_DIM = 512
def load_model():
    """Load MobileCLIP model using PyTorch."""
    global model, preprocess, device
    # Determine device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading MobileCLIP model: {MODEL_NAME}...")
    print(f"Device: {device}")
    # Load model and preprocessing transform
    model, _, preprocess = create_model_and_transforms(
        MODEL_NAME,
        pretrained=PRETRAINED
    )
    model = model.to(device)
    model.eval()
    # Reparameterize model for inference (required for MobileCLIP)
    print("Reparameterizing model for inference...")
    model = reparameterize_model(model)
    print(f"✓ Model loaded: {MODEL_NAME}")
    print(f"✓ Pretrained: {PRETRAINED}")
    print(f"✓ Device: {device}")
    print(f"✓ Embedding dim: {EMBEDDING_DIM}")
@app.on_event("startup")
async def startup_event():
    """Initialize model on startup."""
    load_model()
# --- Preprocessing ---
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """
    Preprocess image for MobileCLIP2-S2.
    Args:
        image: PIL Image
    Returns:
        Preprocessed torch tensor
    """
    image = image.convert("RGB")
    img_tensor = preprocess(image).unsqueeze(0)  # Add batch dimension
    return img_tensor.to(device)
def normalize_embedding(embedding: torch.Tensor) -> torch.Tensor:
    """L2 normalize embedding."""
    return torch.nn.functional.normalize(embedding, p=2, dim=-1)
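# Since the embeddings are L2-normalized, cosine similarity between any two of
# them reduces to a plain dot product. A minimal sketch (the variable names are
# illustrative, not part of this service):
#
#   image_emb = normalize_embedding(model.encode_image(img_tensor))[0]
#   text_emb = normalize_embedding(model.encode_text(text_tokens))[0]
#   similarity = float(torch.dot(image_emb, text_emb))  # in [-1, 1]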
# --- Validation ---
def validate_image_file(file: UploadFile) -> None:
    """Validate uploaded file."""
    # Check extension
    if file.filename:
        ext = file.filename.split(".")[-1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid file type. Allowed: {ALLOWED_EXTENSIONS}"
            )
    # Check size (if available)
    if hasattr(file, "size") and file.size:
        max_bytes = MAX_IMAGE_SIZE_MB * 1024 * 1024
        if file.size > max_bytes:
            raise HTTPException(
                status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                detail=f"File too large. Max: {MAX_IMAGE_SIZE_MB}MB"
            )
async def load_image_from_upload(file: UploadFile) -> Image.Image:
    """Load PIL Image from upload."""
    try:
        contents = await file.read()
        image = Image.open(BytesIO(contents))
        return image
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid image file: {str(e)}"
        )
# --- Endpoints ---
@app.get("/", response_model=HealthResponse)
async def health_check():
    """Health check endpoint."""
    return HealthResponse(
        status="healthy",
        model=MODEL_NAME,
        device=str(device),
        backend="pytorch"
    )
@app.get("/info", response_model=InfoResponse)
async def get_info():
    """Get model information."""
    return InfoResponse(
        model=MODEL_NAME,
        embedding_dim=EMBEDDING_DIM,
        backend="pytorch",
        max_image_size_mb=MAX_IMAGE_SIZE_MB,
        max_batch_size=MAX_BATCH_SIZE,
        image_size=IMAGE_SIZE
    )
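# Example requests for the read-only endpoints (assuming the service is reachable
# on the port configured under __main__ below; the response values shown are
# illustrative and follow the response models above):
#
#   curl http://localhost:7860/       # -> {"status": "healthy", "model": "MobileCLIP2-S2", ...}
#   curl http://localhost:7860/info   # -> {"model": "MobileCLIP2-S2", "embedding_dim": 512, ...}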
@app.post("/embed", response_model=EmbeddingResponse)
async def generate_embedding(file: UploadFile = File(...)):
    """
    Generate embedding for single image.
    Args:
        file: Image file (JPEG, PNG, WebP)
    Returns:
        512-dimensional embedding vector
    Raises:
        400: Invalid file format
        413: File too large
        500: Inference error
    """
    start_time = time.time()
    # Validate
    validate_image_file(file)
    try:
        # Load and preprocess
        image = await load_image_from_upload(file)
        img_tensor = preprocess_image(image)
        # Run inference
        with torch.no_grad():
            embedding = model.encode_image(img_tensor)
            embedding = normalize_embedding(embedding)
        # Convert to numpy and then to list
        embedding = embedding.cpu().numpy()[0]
        # Calculate time
        inference_time = (time.time() - start_time) * 1000
        return EmbeddingResponse(
            embedding=embedding.tolist(),
            model=MODEL_NAME,
            inference_time_ms=round(inference_time, 2)
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Inference failed: {str(e)}"
        )
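# Example request for /embed (multipart form upload; "photo.jpg" is a placeholder path):
#
#   curl -X POST http://localhost:7860/embed -F "file=@photo.jpg"
#
# The response body contains the 512-element "embedding" list plus timing metadata.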
@app.post("/embed/batch", response_model=BatchEmbeddingResponse)
async def generate_batch_embeddings(files: List[UploadFile] = File(...)):
    """
    Generate embeddings for multiple images.
    Args:
        files: List of image files (max 10)
    Returns:
        List of 512-dimensional embeddings
    Raises:
        400: Invalid files or too many files
        500: Inference error
    """
    start_time = time.time()
    # Validate batch size
    if len(files) > MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Too many files. Max batch size: {MAX_BATCH_SIZE}"
        )
    embeddings = []
    try:
        # Process each image
        img_tensors = []
        for file in files:
            validate_image_file(file)
            image = await load_image_from_upload(file)
            img_tensor = preprocess_image(image)
            img_tensors.append(img_tensor)
        # Batch inference
        batch_tensor = torch.cat(img_tensors, dim=0)
        with torch.no_grad():
            batch_embeddings = model.encode_image(batch_tensor)
            batch_embeddings = normalize_embedding(batch_embeddings)
        # Convert to list
        batch_embeddings = batch_embeddings.cpu().numpy()
        embeddings = [emb.tolist() for emb in batch_embeddings]
        # Calculate time
        total_time = (time.time() - start_time) * 1000
        return BatchEmbeddingResponse(
            embeddings=embeddings,
            count=len(embeddings),
            total_time_ms=round(total_time, 2),
            model=MODEL_NAME
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Batch inference failed: {str(e)}"
        )
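# Example request for /embed/batch (repeat the "files" form field, up to MAX_BATCH_SIZE
# files per request; the file names are placeholders):
#
#   curl -X POST http://localhost:7860/embed/batch \
#        -F "files=@a.jpg" -F "files=@b.png"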
@app.post("/embed/text", response_model=TextEmbeddingResponse)
async def generate_text_embedding(request: TextEmbeddingRequest):
    """
    Generate embedding for text query.
    Args:
        request: Text to embed
    Returns:
        512-dimensional embedding for the text
    Raises:
        500: Inference error
    """
    start_time = time.time()
    try:
        # Tokenize text
        tokenizer = get_tokenizer(MODEL_NAME)
        text_tokens = tokenizer([request.text])
        text_tokens = text_tokens.to(device)
        # Run inference
        with torch.no_grad():
            text_embedding = model.encode_text(text_tokens)
            text_embedding = normalize_embedding(text_embedding)
        # Convert to numpy and then to list
        embedding = text_embedding.cpu().numpy()[0]
        # Calculate time
        inference_time = (time.time() - start_time) * 1000
        return TextEmbeddingResponse(
            embedding=embedding.tolist(),
            model=MODEL_NAME,
            inference_time_ms=round(inference_time, 2),
            text=request.text
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Text inference failed: {str(e)}"
        )
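# Example request for /embed/text (JSON body matching TextEmbeddingRequest):
#
#   curl -X POST http://localhost:7860/embed/text \
#        -H "Content-Type: application/json" \
#        -d '{"text": "a photo of a cat"}'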
# --- Main ---
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,  # HF Spaces default port
        log_level="info"
    )
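# --- Example client (for reference only; not executed by this service) ---
# A minimal sketch of scoring an image against a text query over HTTP, assuming
# the service runs at http://localhost:7860 and that the `requests` package is
# installed client-side. File paths and query text are placeholders.
#
#   import requests
#   import numpy as np
#
#   BASE = "http://localhost:7860"
#
#   def embed_image(path: str) -> np.ndarray:
#       with open(path, "rb") as f:
#           r = requests.post(f"{BASE}/embed", files={"file": f})
#       r.raise_for_status()
#       return np.array(r.json()["embedding"])
#
#   def embed_text(text: str) -> np.ndarray:
#       r = requests.post(f"{BASE}/embed/text", json={"text": text})
#       r.raise_for_status()
#       return np.array(r.json()["embedding"])
#
#   # Both vectors are L2-normalized, so the dot product is the cosine similarity.
#   score = float(embed_image("photo.jpg") @ embed_text("a photo of a cat"))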