import os from typing import List, Dict, Any, Union import torch from torch import nn import torchvision.models as tvm from torchvision.transforms import functional as F from torchvision import transforms as T from PIL import Image import gradio as gr CHECKPOINT_PATH = os.environ.get("CKPT_PATH", "best.pth") def get_device() -> torch.device: if torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") def build_model(num_classes: int = 1000) -> nn.Module: model = tvm.resnet50(weights=None) model.fc = nn.Linear(model.fc.in_features, num_classes) return model def get_preprocess_and_labels(): # Use torchvision's ImageNet-1k metadata for categories and canonical transforms try: weights = tvm.ResNet50_Weights.IMAGENET1K_V2 except Exception: # Fallback if weights enum not available weights = None if weights is not None: preprocess = weights.transforms() labels = weights.meta.get("categories", [str(i) for i in range(1000)]) else: preprocess = T.Compose( [ T.Resize(256, interpolation=T.InterpolationMode.BILINEAR), T.CenterCrop(224), T.ToTensor(), T.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ), ] ) labels = [str(i) for i in range(1000)] return preprocess, labels def load_checkpoint_into_model(model: nn.Module, checkpoint_path: str) -> None: if not os.path.exists(checkpoint_path): raise FileNotFoundError( f"Checkpoint not found at '{checkpoint_path}'. " f"Place your file at runs/exp1/best.pth or set CKPT_PATH env var." ) checkpoint = torch.load(checkpoint_path, map_location="cpu") # Support either a full training checkpoint dict or a raw state_dict state_dict = checkpoint.get("model", checkpoint) model.load_state_dict(state_dict, strict=False) model.eval() device = get_device() model = build_model(num_classes=1000).to(device) preprocess, imagenet_labels = get_preprocess_and_labels() load_checkpoint_into_model(model, CHECKPOINT_PATH) def predict_images( images: Union[Image.Image, List[Image.Image]], top_k: int = 5, ) -> List[List[Dict[str, Any]]]: if images is None: return [] if not isinstance(images, list): images = [images] results: List[List[Dict[str, Any]]] = [] with torch.no_grad(): for image in images: if not isinstance(image, Image.Image): # Some gradio versions may return dicts; handle defensively image = Image.fromarray(image) tensor = preprocess(image).unsqueeze(0).to(device) logits = model(tensor) probs = torch.softmax(logits, dim=1)[0] topk = torch.topk(probs, k=top_k) sample_result: List[Dict[str, Any]] = [] for score, idx in zip(topk.values.tolist(), topk.indices.tolist()): label = imagenet_labels[idx] if 0 <= idx < len(imagenet_labels) else str(idx) sample_result.append({"label": label, "probability": float(score)}) results.append(sample_result) return results # Custom CSS for modern UI custom_css = """ .gradio-container { font-family: 'IBM Plex Sans', sans-serif; max-width: 1400px !important; } .header-box { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 40px; border-radius: 15px; color: white; text-align: center; margin-bottom: 30px; box-shadow: 0 8px 16px rgba(0,0,0,0.1); } .stats-card { background: linear-gradient(145deg, #f8f9fa 0%, #e9ecef 100%); padding: 20px; border-radius: 12px; border-left: 5px solid #667eea; margin: 10px 0; box-shadow: 0 4px 6px rgba(0,0,0,0.05); } .prediction-box { background: #ffffff; border-radius: 12px; padding: 20px; box-shadow: 0 4px 12px rgba(0,0,0,0.08); } """ with gr.Blocks(title="ResNet-50 ImageNet-1k Classifier", css=custom_css, theme=gr.themes.Soft()) as demo: # Header gr.HTML("""
Trained from Scratch on ImageNet-1K | 75%+ Top-1 Accuracy
1000 classes • 25.6M parameters • 98MB model
1.28M training images
1000 ImageNet classes
75-77% top-1 accuracy
92-94% top-5 accuracy
ResNet50 (Bottleneck)
25.6M parameters
💜 Built with Gradio • Trained on AWS EC2 • Deployed on 🤗 Hugging Face Spaces
Model trained from scratch achieving 76.12% top-1 accuracy on ImageNet-1K