import functools
import os
import sys
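
# Make the repository root importable so the `models` and `utils`
# packages resolve regardless of the working directory.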
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import gradio as gr
import glob
import torch
import numpy as np
from PIL import Image, ImageOps
import torchvision.transforms as T
from models.unet import UNet
from models.unet_depthwise import UNetDepthwise
from models.unet_depthwise_small import UNetDepthwiseSmall
from models.unet_depthwise_nano import UNetDepthwiseNano
from utils.checkpoint import load_checkpoint

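# Best-checkpoint file for each model variant (paths relative to the repo root).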
MODEL_PATHS = {
    "unet": "checkpoints/unet_best.pt",
    "unet_depthwise": "checkpoints/unet_depthwise_best.pt",
    "unet_depthwise_small": "checkpoints/unet_depthwise_small_best.pt",
    "unet_depthwise_nano": "checkpoints/unet_depthwise_nano_best.pt"
}

IMG_SIZE = 256  # inference resolution; inputs are resized to IMG_SIZE x IMG_SIZE

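# Registry mapping model names to their architecture classes.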
MODEL_CLASSES = {
    "unet": UNet,
    "unet_depthwise": UNetDepthwise,
    "unet_depthwise_small": UNetDepthwiseSmall,
    "unet_depthwise_nano": UNetDepthwiseNano
}

@functools.lru_cache(maxsize=None)
def get_model(model_type):
    """Build the requested architecture and load its best checkpoint on CPU.

    Cached so each checkpoint is read from disk only once per session.
    """
    model_class = MODEL_CLASSES[model_type]
    model = model_class(in_channels=3, out_channels=1)
    checkpoint_path = MODEL_PATHS[model_type]
    load_checkpoint(checkpoint_path, torch.device("cpu"), model)
    model.eval()
    return model

def infer_gradio(model, image):
    """Run one image through the model; return (overlay, raw mask) as PIL images."""
    # Preprocessing (same as test.py): resize to the inference resolution
    # and convert to a float tensor in [0, 1].
    orig_size = image.size
    transform = T.Compose([
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.ToTensor(),
    ])
    input_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = model(input_tensor)
        output = output.squeeze().cpu().numpy()
        # Assumes the network emits probabilities in [0, 1] (as in test.py);
        # clip guards against slight numerical overshoot before the uint8 cast.
        output = np.uint8(np.clip(output, 0, 1) * 255)
    # Resize mask back to original image size
    mask_img = Image.fromarray(output)
    mask_img = mask_img.resize(orig_size, resample=Image.BILINEAR)
    # Build a green overlay whose alpha channel is the inverted mask,
    # scaled to at most ~80% opacity.
    inverted_mask = ImageOps.invert(mask_img.convert("L"))
    color_mask = Image.new("RGBA", orig_size, color=(0, 255, 0, 0))
    alpha = inverted_mask.point(lambda p: int(p * 0.8))
    color_mask.putalpha(alpha)
    image_rgba = image.convert("RGBA")
    overlay_img = Image.alpha_composite(image_rgba, color_mask).convert("RGB")
    # Return both overlay and mask
    return overlay_img, mask_img

def lane_detection(image, model_type):
    """Gradio entry point: fetch the selected model and run inference."""
    model = get_model(model_type)
    return infer_gradio(model, image)

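# Example inputs shipped with the demo; sorted for a stable gallery order.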
example_images = sorted(glob.glob("images/*.jpg"))

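# Wire everything into a Gradio Interface: one image plus a model picker in,
# the overlay and the raw mask out.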
demo = gr.Interface(
    fn=lane_detection,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        # Default to "unet" so inference doesn't crash if the user never
        # touches the radio (Gradio passes None otherwise).
        gr.Radio(list(MODEL_PATHS), label="Model Type", value="unet"),
    ],
    outputs=[
        gr.Image(type="pil", label="Lane Detection Result (Overlay)"),
        gr.Image(type="pil", label="Mask Output")
    ],
    title="Lane Detection using UNet",
    description="Upload a road image and select a model to see lane detection results.",
    examples=[[img, "unet"] for img in example_images]
)

if __name__ == "__main__":
    demo.launch()
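    # Tip: demo.launch(share=True) creates a temporary public URL
    # (a standard Gradio option) if you want to share the demo.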