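"""Gradio demo: lane detection with several UNet variants.

Loads the selected checkpoint, segments an uploaded road image, and returns a
green overlay plus the raw segmentation mask.
"""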

import os
import sys

# Make the repository root (one level up) importable so models/ and utils/ resolve.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import glob

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageOps

from models.unet import UNet
from models.unet_depthwise import UNetDepthwise
from models.unet_depthwise_small import UNetDepthwiseSmall
from models.unet_depthwise_nano import UNetDepthwiseNano
from utils.checkpoint import load_checkpoint
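
# Best checkpoint for each UNet variant.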
MODEL_PATHS = {
    "unet": "checkpoints/unet_best.pt",
    "unet_depthwise": "checkpoints/unet_depthwise_best.pt",
    "unet_depthwise_small": "checkpoints/unet_depthwise_small_best.pt",
    "unet_depthwise_nano": "checkpoints/unet_depthwise_nano_best.pt",
}

# Inference resolution: inputs are resized to IMG_SIZE x IMG_SIZE before the forward pass.
IMG_SIZE = 256
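
# Constructors for each variant; all share the same 3-channel-in / 1-channel-out interface.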
MODEL_CLASSES = {
    "unet": UNet,
    "unet_depthwise": UNetDepthwise,
    "unet_depthwise_small": UNetDepthwiseSmall,
    "unet_depthwise_nano": UNetDepthwiseNano,
}

# Cache loaded models so each checkpoint is read from disk only once,
# rather than on every Gradio request.
_model_cache = {}

def get_model(model_type):
    if model_type not in _model_cache:
        model_class = MODEL_CLASSES[model_type]
        model = model_class(in_channels=3, out_channels=1)
        load_checkpoint(MODEL_PATHS[model_type], torch.device("cpu"), model)
        model.eval()
        _model_cache[model_type] = model
    return _model_cache[model_type]

def infer_gradio(model, image):
    # Preprocessing (same as test.py): resize to the inference resolution.
    image = image.convert("RGB")  # Gradio normally delivers RGB; convert defensively.
    orig_size = image.size
    transform = T.Compose([
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.ToTensor(),
    ])
    input_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = model(input_tensor)
    output = output.squeeze().cpu().numpy()
    # Clip in case the model emits values slightly outside [0, 1], then scale to 8-bit.
    output = np.uint8(np.clip(output, 0.0, 1.0) * 255)
    # Resize the mask back to the original image size.
    mask_img = Image.fromarray(output)
    mask_img = mask_img.resize(orig_size, resample=Image.BILINEAR)
    # Invert the mask and use it as the alpha channel of a green overlay (80% max opacity).
    inverted_mask = ImageOps.invert(mask_img.convert("L"))
    color_mask = Image.new("RGBA", orig_size, color=(0, 255, 0, 0))
    alpha = inverted_mask.point(lambda p: int(p * 0.8))
    color_mask.putalpha(alpha)
    image_rgba = image.convert("RGBA")
    overlay_img = Image.alpha_composite(image_rgba, color_mask).convert("RGB")
    # Return both the overlay and the raw mask.
    return overlay_img, mask_img

def lane_detection(image, model_type):
    model = get_model(model_type)
    return infer_gradio(model, image)
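
# Build the Gradio UI: image upload + model selector in, overlay + mask out.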
example_images = sorted(glob.glob("images/*.jpg"))  # sorted for a stable example order

demo = gr.Interface(
    fn=lane_detection,
    inputs=[
        gr.Image(type="pil"),
        gr.Radio(
            ["unet", "unet_depthwise", "unet_depthwise_small", "unet_depthwise_nano"],
            value="unet",  # default so submitting without a selection cannot crash the lookup
            label="Model Type",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Lane Detection Result (Overlay)"),
        gr.Image(type="pil", label="Mask Output"),
    ],
    title="Lane Detection using UNet",
    description="Upload a road image and select a model to see lane detection results.",
    examples=[[img, "unet"] for img in example_images],
)

if __name__ == "__main__":
    demo.launch()