Nick commited on
Commit
6994d11
·
1 Parent(s): 5b36e55

Add mask output to Gradio app; now returns both overlay and mask images

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -59,7 +59,8 @@ def infer_gradio(model, image):
59
  color_mask.putalpha(alpha)
60
  image_rgba = image.convert("RGBA")
61
  overlay_img = Image.alpha_composite(image_rgba, color_mask).convert("RGB")
62
- return overlay_img
 
63
 
64
  def lane_detection(image, model_type):
65
  model = get_model(model_type)
@@ -68,7 +69,10 @@ def lane_detection(image, model_type):
68
  demo = gr.Interface(
69
  fn=lane_detection,
70
  inputs=[gr.Image(type="pil"), gr.Radio(["unet", "unet_depthwise", "unet_depthwise_small", "unet_depthwise_nano"], label="Model Type")],
71
- outputs=gr.Image(type="pil", label="Lane Detection Result"),
 
 
 
72
  title="Lane Detection UNet",
73
  description="Upload a road image and select a model to see lane detection results."
74
  )
 
59
  color_mask.putalpha(alpha)
60
  image_rgba = image.convert("RGBA")
61
  overlay_img = Image.alpha_composite(image_rgba, color_mask).convert("RGB")
62
+ # Return both overlay and mask
63
+ return overlay_img, mask_img
64
 
65
  def lane_detection(image, model_type):
66
  model = get_model(model_type)
 
69
  demo = gr.Interface(
70
  fn=lane_detection,
71
  inputs=[gr.Image(type="pil"), gr.Radio(["unet", "unet_depthwise", "unet_depthwise_small", "unet_depthwise_nano"], label="Model Type")],
72
+ outputs=[
73
+ gr.Image(type="pil", label="Lane Detection Result (Overlay)"),
74
+ gr.Image(type="pil", label="Mask Output")
75
+ ],
76
  title="Lane Detection UNet",
77
  description="Upload a road image and select a model to see lane detection results."
78
  )