danifei committed
Commit 5d01aa8 · Parent: a490245

fixed appearance and added more images

.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
app.py CHANGED
@@ -6,12 +6,12 @@ import torch.nn.functional as F
  import os
  import glob

- from archs import create_model, resume_model
+ from archs import create_model, load_model

  # -------- Detect folders & images (assets/<folder>) --------
  IMG_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".webp")

- def list_subfolders(base="assets"):
+ def list_subfolders(base="examples"):
      """Return a sorted list of immediate subfolders inside base."""
      if not os.path.isdir(base):
          return []
@@ -19,14 +19,15 @@ def list_subfolders(base="assets"):
      return subs

  def list_images(folder):
-     """Return full paths of images inside assets/<folder>."""
-     paths = sorted(glob.glob(os.path.join("assets", folder, "*")))
+     """Return full paths of images inside examples/<folder>."""
+     paths = sorted(glob.glob(os.path.join("examples", folder, "*")))
      return [p for p in paths if p.lower().endswith(IMG_EXTS)]

  # -------- Folder/Gallery interactions --------
  def update_gallery(folder):
      """Given a folder name, return the gallery items (list of image paths) and store the same list in state."""
      files = list_images(folder)
+     print(files)
      return gr.update(value=files, visible=True), files

  def load_from_gallery(evt: gr.SelectData, current_files):
@@ -35,6 +36,7 @@ def load_from_gallery(evt: gr.SelectData, current_files):
      if not current_files or idx is None or idx >= len(current_files):
          return gr.update()
      path = current_files[idx]
+     print(path)
      return Image.open(path)


@@ -59,7 +61,9 @@ tensor_to_pil = transforms.ToPILImage()
  model = create_model(model_opt, device)

  checkpoints = torch.load(PATH_MODEL, map_location=device, weights_only=False)
- model = resume_model(model, PATH_MODEL, device)
+ model = load_model(model, PATH_MODEL, device)
+
+ model.eval()

  def pad_tensor(tensor, multiple = 16):
      '''pad the tensor to be multiple of some number'''
@@ -88,12 +92,14 @@ def process_img(image, task_label = 'auto'):
      task_label = LABEL_TO_TASK.get(task_label, 'auto')
      tensor = pil_to_tensor(image).unsqueeze(0).to(device)
      _, _, H, W = tensor.shape
-
+     print('Using task:', task_label)
      tensor = pad_tensor(tensor)

      with torch.no_grad():
-         output = model(tensor, task_label)
+         output_dict = model(tensor, task_label)

+     output = output_dict['output']
+     # print(output.shape)
      output = torch.clamp(output, 0., 1.)
      output = output[:,:, :H, :W].squeeze(0)
      return tensor_to_pil(output)
@@ -117,17 +123,12 @@ Available code at [github](https://github.com/cidautai/DeMoE). More information
  <br>
  '''

- # examples = [['examples/1POA1811.png'],
- # ['examples/12_blur.png'],
- # ['examples/0031.png'],
- # ['examples/000143.png'],
- # ['examples/blur_4.png']]
-
  css = """
- .image-frame img, .image-container img {
-     width: auto;
-     height: auto;
-     max-width: none;
+ .fitbox img,
+ .fitbox canvas {
+     width: 100% !important;
+     height: 100% !important;
+     object-fit: contain !important;
  }
  """

@@ -150,6 +151,7 @@ examples_synth_global_motion = list_basenames("synth_global_motion")
  examples_local_motion = list_basenames("local_motion")
  examples_defocus = list_basenames("defocus")

+ # print(examples_defocus, examples_global_motion, examples_low_light, examples_synth_global_motion, examples_local_motion)
  # -----------------------------
  # Gradio Blocks layout
  # -----------------------------
@@ -158,12 +160,12 @@ with gr.Blocks(css=css, title=title) as demo:

      with gr.Row():
          # Input image and the task selector (Radio)
-         inp_img = gr.Image(type='pil', label='input')
+         inp_img = gr.Image(type='pil', label='input', height=320)
          # Output image and action button
-         out_img = gr.Image(type='pil', label='output')
+         out_img = gr.Image(type='pil', label='output', height=320)
          task_selector = gr.Radio(
              choices=TASK_LABELS,
-             value="auto",
+             value="Auto",
              label="Blur type"
          )

@@ -181,6 +183,7 @@ with gr.Blocks(css=css, title=title) as demo:
      with gr.Row():
          # List folders found in ./assets
          folders = list_subfolders("examples")
+         print(folders)
          folder_radio = gr.Radio(choices=folders, label="Examples Folders", interactive=True)

          gallery = gr.Gallery(
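
Note on the padding step in process_img: the body of pad_tensor is truncated in the hunk above, which can make the later crop back to the original H and W harder to follow. Below is a minimal sketch of a pad-to-multiple-of-16 helper that is consistent with that crop; the helper name and the reflection padding mode are assumptions for illustration, not taken from the repository.

import torch
import torch.nn.functional as F

def pad_tensor_sketch(tensor: torch.Tensor, multiple: int = 16) -> torch.Tensor:
    # Pad H and W up to the next multiple of `multiple`; process_img crops the
    # network output back to the original H, W afterwards.
    _, _, h, w = tensor.shape
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    return F.pad(tensor, (0, pad_w, 0, pad_h), mode="reflect")  # padding mode assumed
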
archs/__init__.py CHANGED
@@ -25,34 +25,70 @@ def create_model(opt, device):

      return model

- def load_weights(model, model_weights):
-     '''
-     Loads the weights of a pretrained model, picking only the weights that are
-     in the new model.
-     '''
-     new_weights = model.state_dict()
-     new_weights.update({k: v for k, v in model_weights.items() if k in new_weights})
+ # def load_weights(model, model_weights):
+ #     '''
+ #     Loads the weights of a pretrained model, picking only the weights that are
+ #     in the new model.
+ #     '''
+ #     new_weights = model.state_dict()
+ #     new_weights.update({k: v for k, v in model_weights.items() if k in new_weights})

-     model.load_state_dict(new_weights)
+ #     new_weights = {key.replace('module.', ''): value for key, value in new_weights.items()}
+ #     print(new_weights.keys())
+ #     print(model.state_dict().keys())
+ #     model.load_state_dict(new_weights, strict= True)
+
+
+ #     total_checkpoint_keys = len(model_weights)
+ #     total_model_keys = len(new_weights)
+ #     matching_keys = len(set(model_weights.keys()) & set(new_weights.keys()))
+
+ #     print(f"Total keys in checkpoint: {total_checkpoint_keys}")
+ #     print(f"Total keys in model state dict: {total_model_keys}")
+ #     print(f"Number of matching keys: {matching_keys}")
+

+ #     return model
+
+ def strip_prefixes(sd: dict, prefixes=("module.", "model.", "ema.", "net.", "netG.", "generator.")) -> dict:
+     out = {}
+     for k, v in sd.items():
+         nk = k
+         for p in prefixes:
+             if nk.startswith(p):
+                 nk = nk[len(p):]
+                 break
+         out[nk] = v
+     return out
+
+ # ===== drop DDP and local_rank, use a single device =====
+ def load_model(model, path_weights: str, device: torch.device):
+     # always load on CPU, then move to the target device
+     ckpt = torch.load(path_weights, map_location='cpu', weights_only=False)
+     # try the usual checkpoint keys; otherwise use the dict as-is
+     sd = ckpt.get("params") or ckpt.get("model_state_dict") or ckpt.get("state_dict") or ckpt
+     sd = strip_prefixes(sd)
+     missing, unexpected = model.load_state_dict(sd, strict=False)
+     print(f"[DeMoE] load_state: missing={len(missing)}, unexpected={len(unexpected)}")
+     model = model.to(device=device, dtype=torch.float32).eval()
      return model

- def resume_model(model,
-                  path_model,
-                  device):
+ # def resume_model(model,
+ #                  path_model,
+ #                  device):

-     '''
-     Returns the loaded weights of model and optimizer if resume flag is True
-     '''
+ #     '''
+ #     Returns the loaded weights of model and optimizer if resume flag is True
+ #     '''

-     checkpoints = torch.load(path_model, map_location=device, weights_only=False)
-     weights = checkpoints['params']
-     model = load_weights(model, model_weights=weights)
+ #     checkpoints = torch.load(path_model, map_location=device, weights_only=False)
+ #     weights = checkpoints['params']
+ #     model = load_weights(model, model_weights=weights)

-     return model
+ #     return model


- __all__ = ['create_model', 'resume_model', 'load_weights']
+ __all__ = ['create_model', 'load_model']


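One behavioral note on the new strip_prefixes helper: because of the inner break, it removes at most one wrapper prefix per key, so a doubly wrapped key keeps its inner prefix. A quick illustration with invented keys (hypothetical, for clarity only):

from archs import strip_prefixes  # helper introduced in the diff above

# Hypothetical state-dict keys, invented purely for illustration.
sd = {
    "module.encoder.conv.weight": 0,   # "module." is stripped
    "model.decoder.bias": 1,           # "model." is stripped
    "module.model.head.weight": 2,     # only the outer "module." is stripped
    "head.weight": 3,                  # no matching prefix, key left unchanged
}
print(list(strip_prefixes(sd).keys()))
# ['encoder.conv.weight', 'decoder.bias', 'model.head.weight', 'head.weight']
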
archs/__pycache__/DeMoE.cpython-312.pyc ADDED
Binary file (7.15 kB)

archs/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (2.46 kB)

archs/__pycache__/arch_model.cpython-312.pyc ADDED
Binary file (5.72 kB)

archs/__pycache__/arch_util.cpython-312.pyc ADDED
Binary file (5.1 kB)

archs/__pycache__/moeblocks.cpython-312.pyc ADDED
Binary file (3.48 kB)

examples/defocus/1P0A1916.png ADDED

Git LFS Details

  • SHA256: c878a86370802846fbb50ddf8ef6261e605c7a3514e2a06719567853faca2e13
  • Pointer size: 132 Bytes
  • Size of remote file: 7.89 MB
examples/defocus/1P0A2151.png ADDED

Git LFS Details

  • SHA256: 3853d6539151ebad3457c5636a6300b7714253a27e26231cd36e655de3a0e5f2
  • Pointer size: 132 Bytes
  • Size of remote file: 9.16 MB
examples/defocus/1P0A2239.png ADDED

Git LFS Details

  • SHA256: 1a0c6cffcc82fb30dddb1232862bbbf3ec6dd4e24e6d94f75c91f1699a61f28e
  • Pointer size: 132 Bytes
  • Size of remote file: 7.56 MB
examples/global_motion/blur_2.png ADDED

Git LFS Details

  • SHA256: c6ff99209a7c76fb6cc93afe244b0a5f013e91805b90c90230ab0cad4b9573af
  • Pointer size: 131 Bytes
  • Size of remote file: 578 kB
examples/global_motion/blur_4.png CHANGED

Git LFS Details (old)

  • SHA256: 4daac3165f76b91c48f80562196d5c357f849d34a0db1024c264142331c216b3
  • Pointer size: 131 Bytes
  • Size of remote file: 553 kB

Git LFS Details (new)

  • SHA256: ac968c8d9d6a584822bdbee6648414a4f8d39e93bfd4896b7a07f8afcd6d2757
  • Pointer size: 131 Bytes
  • Size of remote file: 482 kB
examples/global_motion/blur_6.png ADDED

Git LFS Details

  • SHA256: 90b9601269912fd7ae03e9d382e7f5053de65f964b4338310b0d5ce65bc89e9e
  • Pointer size: 131 Bytes
  • Size of remote file: 461 kB
examples/global_motion/blur_9.png ADDED

Git LFS Details

  • SHA256: fec8544851e3e14a8b5376ea5b5ee9612d88ebe247dd7db28f7a459ca4c869d0
  • Pointer size: 131 Bytes
  • Size of remote file: 712 kB
examples/local_motion/00_blur.png ADDED

Git LFS Details

  • SHA256: 4e834e939294c289ad2a74d63706a31224b326db08df2b9507ff268c718cab81
  • Pointer size: 132 Bytes
  • Size of remote file: 4.47 MB
examples/local_motion/08_blur.png ADDED

Git LFS Details

  • SHA256: 07983e7ee375506121c9992d7ec74da09d508c7fbab3c2e53841ec5ec2ffc00f
  • Pointer size: 132 Bytes
  • Size of remote file: 4.62 MB
examples/local_motion/09_blur.png ADDED

Git LFS Details

  • SHA256: 929c2490ae821aec766718582e6c721ad3e2d0f2dfc00c1cbb05765e05148cf3
  • Pointer size: 132 Bytes
  • Size of remote file: 4.77 MB
examples/low_light/0062.png ADDED

Git LFS Details

  • SHA256: ac6d88c7419931cf34fc4222c2b6af2b58db2d5fbbd228130a0abb2a23a1ba63
  • Pointer size: 131 Bytes
  • Size of remote file: 802 kB
examples/low_light/0065.png ADDED

Git LFS Details

  • SHA256: b85711d5c5203dbecc1ee0ab86021be276b159e50c44716ab3f234d3b27c0c44
  • Pointer size: 131 Bytes
  • Size of remote file: 518 kB
examples/low_light/0071.png ADDED

Git LFS Details

  • SHA256: f8931315e2108f7a1767b82a5b364770d64dd71a1d3f69e9286f5fc761ba0193
  • Pointer size: 131 Bytes
  • Size of remote file: 818 kB
examples/synth_global_motion/000050.png ADDED

Git LFS Details

  • SHA256: 5df0bf03ebcf840862de3f50a41e5f849075d61ad8584831a05278fd3583b914
  • Pointer size: 131 Bytes
  • Size of remote file: 984 kB
examples/synth_global_motion/000059.png ADDED

Git LFS Details

  • SHA256: 5326e19933aeee10c6671383e0ca33a435f9afa165401e7b7ca93e599c3d823f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
examples/synth_global_motion/004084.png ADDED

Git LFS Details

  • SHA256: c449e2421b17589a7a21e9fc5f57a55e8f0862a32cd1573d4aefa788a9025e38
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB