import os, glob
import gradio as gr
from PIL import Image
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
from archs import create_model, resume_model
# -------- Detect folders & images (assets/<folder>) --------
IMG_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
def list_subfolders(base="assets"):
    """Return a sorted list of immediate subfolders inside base."""
    if not os.path.isdir(base):
        return []
    subs = [d for d in sorted(os.listdir(base)) if os.path.isdir(os.path.join(base, d))]
    return subs
def list_images(folder):
    """Return full paths of images inside assets/<folder>."""
    paths = sorted(glob.glob(os.path.join("assets", folder, "*")))
    return [p for p in paths if p.lower().endswith(IMG_EXTS)]
# -------- Folder/Gallery interactions --------
def update_gallery(folder):
    """Given a folder name, return the gallery items (list of image paths) and store the same list in state."""
    files = list_images(folder)
    return gr.update(value=files, visible=True), files  # (gallery_update, state_list)
def load_from_gallery(evt: gr.SelectData, current_files):
    """On gallery click, load the clicked image path into the input image."""
    idx = evt.index
    if not current_files or idx is None or idx >= len(current_files):
        return gr.update()
    path = current_files[idx]
    return Image.open(path)
# -----------------------------
# Model
# -----------------------------
PATH_MODEL = './DeMoE.pt'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_opt = {
    'name': 'DeMoE',
    'img_channels': 3,
    'width': 32,
    'middle_blk_num': 2,
    'enc_blk_nums': [2, 2, 2, 2],
    'dec_blk_nums': [2, 2, 2, 2],
    'num_experts': 5,
    'k_used': 1
}
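# Reading of the settings above (hedged): a 32-channel base width, four encoder and four
# decoder stages with 2 blocks each, a 2-block bottleneck, and a 5-expert MoE decoder with
# k_used=1, i.e. presumably a single expert routed per input.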
pil_to_tensor = transforms.ToTensor()
tensor_to_pil = transforms.ToPILImage()
# Create and load model weights
model = create_model(model_opt, device)
_ = torch.load(PATH_MODEL, map_location=device, weights_only=False)  # weights_only=False keeps compatibility with checkpoints saved as full (non weights-only) objects
model = resume_model(model, PATH_MODEL, device)
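# Note (assumption about resume_model): if it does not already switch the network to eval
# mode, calling model.eval() here would disable dropout/BatchNorm updates during the
# no_grad forward passes below.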
def pad_tensor(tensor, multiple=16):
    """Pad tensor so that H and W are multiples of `multiple` (default 16)."""
    _, _, H, W = tensor.shape
    pad_h = (multiple - H % multiple) % multiple
    pad_w = (multiple - W % multiple) % multiple
    tensor = F.pad(tensor, (0, pad_w, 0, pad_h), value=0)
    return tensor
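# Padding arithmetic sketch (hypothetical sizes): for H=721, W=1279 and multiple=16,
# pad_h = (16 - 721 % 16) % 16 = 15 and pad_w = (16 - 1279 % 16) % 16 = 1, so the model
# sees a 736x1280 tensor; process_img below crops the output back to the original H, W.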
# -----------------------------
# UI / Inference
# -----------------------------
title = 'DeMoE 🌪️'
description = ''' >**Abstract**: Image deblurring, removing blurring artifacts from images, is a fundamental task in computational photography and low-level computer vision. Existing approaches focus on specialized solutions tailored to particular blur types, thus, these solutions lack generalization. This limitation in current methods implies requiring multiple models to cover several blur types, which is not practical in many real scenarios. In this paper, we introduce the first all-in-one deblurring method capable of efficiently restoring images affected by diverse blur degradations, including global motion, local motion, blur in low-light conditions, and defocus blur. We propose a mixture-of-experts (MoE) decoding module, which dynamically routes image features based on the recognized blur degradation, enabling precise and efficient restoration in an end-to-end manner. Our unified approach not only achieves performance comparable to dedicated task-specific models, but also shows promising generalization to unseen blur scenarios, particularly when leveraging appropriate expert selection.
[Daniel Feijoo](https://github.com/danifei), Paula Garrido-Mellado, Jaesung Rim, Álvaro García, Marcos V. Conde
[Fundación Cidaut](https://cidaut.ai/)
Code is available on [GitHub](https://github.com/cidautai/DeMoE). More details are in the [arXiv paper](https://arxiv.org/pdf/2508.06228).
> **Disclaimer:** please remember this is not a product, so you may notice some limitations.
**This demo expects an image with some low-light degradations.**
<br>
'''
# Visible tasks in the UI
TASK_LABELS = ["Deblur", "Low-light", "movement", "defocus", "all"]
# Map pretty label -> internal task code used by the model
LABEL_TO_TASK = {
    "Deblur": "global",       # change to what your model expects for general deblurring
    "Low-light": "lowlight",
    "movement": "local",      # if your model supports local motion blur
    "defocus": "defocus",     # if your model supports defocus blur
    "all": "all",             # if your model supports all types at once
}
css = """
.image-frame img, .image-container img {
    width: auto;
    height: auto;
    max-width: none;
}
"""
# Example lists per folder under ./assets
exts = IMG_EXTS
def list_basenames(folder):
    """Return [[basename, task_label], ...] for gr.Examples using examples_dir."""
    paths = sorted(glob.glob(f"assets/{folder}/*"))
    basenames = [os.path.basename(p) for p in paths if p.lower().endswith(exts)]
    # Default task per folder (tweak as you like)
    default_task = "Low-light" if folder == "lowlight" else "Deblur"
    return [[name, default_task] for name in basenames]
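# Example of the expected structure (hypothetical filenames):
#   list_basenames("lowlight") -> [["0001.png", "Low-light"], ["0002.png", "Low-light"], ...]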
examples_agentir = list_basenames("AgentIR")
examples_allweather = list_basenames("allweather")
examples_amac = list_basenames("amac_examples")
examples_deblur = list_basenames("deblur")
examples_gestures = list_basenames("gestures")
examples_lowlight = list_basenames("lowlight")
examples_monolith = list_basenames("monolith")
examples_superres = list_basenames("superres")
def process_img(image, task_label='auto'):
    """Main inference: converts PIL -> tensor, pads, runs the model with the selected task, clamps, crops, returns PIL."""
    task = LABEL_TO_TASK.get(task_label, 'lowlight')  # default to lowlight if something unexpected arrives
    tensor = pil_to_tensor(image).unsqueeze(0).to(device)
    _, _, H, W = tensor.shape
    tensor = pad_tensor(tensor)
    with torch.no_grad():
        output = model(tensor, task)
    output = torch.clamp(output, 0., 1.)
    output = output[:, :, :H, :W].squeeze(0)
    return tensor_to_pil(output)
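# Minimal sketch of calling the inference function outside the Gradio UI
# (hypothetical path; assumes the checkpoint above has already been loaded):
#   img = Image.open("assets/lowlight/example.png").convert("RGB")
#   restored = process_img(img, task_label="Low-light")
#   restored.save("restored.png")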
# -----------------------------
# Gradio Blocks layout
# -----------------------------
with gr.Blocks(css=css, title=title) as demo:
    gr.Markdown(f"# {title}\n\n{description}")
    with gr.Row():
        # Input image and the task selector (Radio)
        inp_img = gr.Image(type='pil', label='input')
        # Output image and action button
        out_img = gr.Image(type='pil', label='output')
    task_selector = gr.Radio(
        choices=TASK_LABELS,
        value="Low-light",
        label="Blur type to correct"
    )
    btn = gr.Button("Restore", variant="primary")
    # Connect the button to the inference function
    btn.click(
        fn=process_img,
        inputs=[inp_img, task_selector],
        outputs=[out_img]
    )
    # Examples grouped by folder (each item loads image + task automatically)
    gr.Markdown("## Examples (assets)")
    with gr.Row():
        # List folders found in ./assets
        folders = list_subfolders("assets")
        folder_radio = gr.Radio(choices=folders, label="Folders in assets", interactive=True)
    gallery = gr.Gallery(
        label="Images in the selected folder",
        visible=False,
        allow_preview=True,
        columns=6,
        height=320,
    )
    # State holds the current file list shown in the gallery (to resolve clicks)
    current_files_state = gr.State([])
    # When changing folder -> update gallery and state
    folder_radio.change(
        fn=update_gallery,
        inputs=folder_radio,
        outputs=[gallery, current_files_state]
    )
    # When clicking a thumbnail -> load it into the input image
    gallery.select(
        fn=load_from_gallery,
        inputs=[current_files_state],
        outputs=inp_img
    )
if __name__ == '__main__':
    # Explicit host/port and no SSR are friendly to Spaces
    demo.launch(show_error=True, server_name="0.0.0.0", server_port=7864, ssr_mode=False)