dmpetrov committed on
Commit
8ecf001
·
1 Parent(s): 8280fcc

added app.py

Browse files
Files changed (1) hide show
  1. app.py +816 -4
app.py CHANGED
@@ -1,7 +1,819 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ """
2
+ ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts
3
+ =======================================================================
4
+
5
+ A Gradio web interface for the ShapeWords paper, allowing users to generate
6
+ images guided by 3D shape information.
7
+
8
+ Author: Melinos Averkiou
9
+ Date: 11 March 2025
10
+ Version: 1.0
11
+
12
+ Paper: "ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts"
13
+ arXiv: https://arxiv.org/abs/2412.02912
14
+ Project Page: https://lodurality.github.io/shapewords/
15
+
16
+ Citation:
17
+ @misc{petrov2024shapewords,
18
+ title={ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts},
19
+ author={Dmitry Petrov and Pradyumn Goyal and Divyansh Shivashok and Yuanming Tao and Melinos Averkiou and Evangelos Kalogerakis},
20
+ year={2024},
21
+ eprint={2412.02912},
22
+ archivePrefix={arXiv},
23
+ primaryClass={cs.CV},
24
+ url={https://arxiv.org/abs/2412.02912},
25
+ }
26
+
27
+ License: MIT License
28
+
29
+ Usage:
30
+ python app.py [--share]
31
+
32
+ This demo allows users to:
33
+ 1. Select a 3D object category
34
+ 2. Choose a specific 3D shape using a slider
35
+ 3. Enter a text prompt
36
+ 4. Generate images guided by the selected 3D shape
37
+
38
+ The code is structured as a class and is compatible with Hugging Face ZeroGPU deployment.
39
+ """
40
+
41
+ import os
42
+ import sys
43
+ import numpy as np
44
+ import torch
45
  import gradio as gr
46
+ from PIL import Image, ImageFont, ImageDraw
47
+ from diffusers.utils import load_image
48
+ from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
49
+ import open_clip
50
+ import gdown
51
+ import argparse
52
+ import random
53
+ import spaces
54
+
55
class ShapeWordsDemo:
    # Map from human-readable category names (shown in the UI dropdown) to
    # ShapeNet synset IDs (used to name embedding files on disk).
    # Insertion order is preserved from the original table.
    NAME2CAT = {
        "chair": "03001627", "table": "04379243",
        "jar": "03593526", "skateboard": "04225987",
        "car": "02958343", "bottle": "02876657",
        "tower": "04460130", "bookshelf": "02871439",
        "camera": "02942699", "airplane": "02691156",
        "laptop": "03642806", "basket": "02801938",
        "sofa": "04256520", "knife": "03624134",
        "can": "02946921", "rifle": "04090263",
        "train": "04468005", "pillow": "03938244",
        "lamp": "03636649", "trash bin": "02747177",
        "mailbox": "03710193", "watercraft": "04530566",
        "motorbike": "03790512", "dishwasher": "03207941",
        "bench": "02828884", "pistol": "03948459",
        "rocket": "04099429", "loudspeaker": "03691459",
        "file cabinet": "03337140", "bag": "02773838",
        "cabinet": "02933112", "bed": "02818832",
        "birdhouse": "02843684", "display": "03211117",
        "piano": "03928116", "earphone": "03261776",
        "telephone": "04401088", "stove": "04330267",
        "microphone": "03759954", "bus": "02924116",
        "mug": "03797390", "remote": "04074963",
        "bathtub": "02808440", "bowl": "02880940",
        "keyboard": "03085013", "guitar": "03467517",
        "washer": "04554684", "bicycle": "02834778",
        "faucet": "03325088", "printer": "04004475",
        "cap": "02954340", "phone": "02992529",
        "clock": "03046257", "helmet": "03513137",
        "microwave": "03761084", "plant": "03991062"
    }
+
75
+ def __init__(self):
76
+ # Initialize class attributes
77
+ self.pipeline = None
78
+ self.shape2clip_model = None
79
+ self.text_encoder = None
80
+ self.tokenizer = None
81
+ self.category_embeddings = {}
82
+ self.category_counts = {}
83
+ self.available_categories = []
84
+ self.shape_thumbnail_cache = {} # Cache for shape thumbnails
85
+ self.CAT2NAME = {v: k for k, v in self.NAME2CAT.items()}
86
+
87
+ # Initialize all models and data
88
+ self.initialize_models()
89
+
90
+ def draw_text(self, img, text, color=(10, 10, 10), size=80, location=(200, 30)):
91
+ img = img.copy()
92
+ draw = ImageDraw.Draw(img)
93
+
94
+ try:
95
+ font = ImageFont.truetype("Arial", size=size)
96
+ except IOError:
97
+ font = ImageFont.load_default()
98
+
99
+ bbox = draw.textbbox(location, text, font=font)
100
+ draw.rectangle(bbox, fill="white")
101
+ draw.text(location, text, color, font=font)
102
+
103
+ return img
104
+
105
+ def get_ulip_image(self, guidance_shape_id, angle='036'):
106
+ shape_id_ulip = guidance_shape_id.replace('_', '-')
107
+ ulip_template = 'https://storage.googleapis.com/sfr-ulip-code-release-research/shapenet-55/only_rgb_depth_images/{}_r_{}_depth0001.png'
108
+ ulip_path = ulip_template.format(shape_id_ulip, angle)
109
+
110
+ try:
111
+ ulip_image = load_image(ulip_path).resize((512, 512))
112
+ return ulip_image
113
+ except Exception as e:
114
+ print(f"Error loading image: {e}")
115
+ return Image.new('RGB', (512, 512), color='gray')
116
+
117
+ def get_ulip_thumbnail(self, guidance_shape_id, angle='036', size=(150, 150)):
118
+ """Get a thumbnail version of the ULIP image for use in the gallery"""
119
+ image = self.get_ulip_image(guidance_shape_id, angle)
120
+ return image.resize(size)
121
+
122
+ def initialize_models(self):
123
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
124
+ print(f"Using device: {device}")
125
+
126
+ # Download Shape2CLIP code if it doesn't exist
127
+ if not os.path.exists("shapewords_paper_code"):
128
+ os.system("git clone https://github.com/lodurality/shapewords_paper_code.git")
129
+
130
+ # Import Shape2CLIP model
131
+ sys.path.append("./shapewords_paper_code")
132
+ from shapewords_paper_code.geometry_guidance_models import Shape2CLIP
133
+
134
+ # Initialize the pipeline
135
+ self.pipeline = StableDiffusionPipeline.from_pretrained(
136
+ "stabilityai/stable-diffusion-2-1-base",
137
+ torch_dtype=torch.float16 if device.type == "cuda" else torch.float32
138
+ )
139
+
140
+ self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
141
+ self.pipeline.scheduler.config,
142
+ algorithm_type="sde-dpmsolver++"
143
+ )
144
+
145
+ # Load CLIP model
146
+ clip_model, _, preprocess = open_clip.create_model_and_transforms(
147
+ 'ViT-H-14',
148
+ pretrained='laion2b_s32b_b79k'
149
+ )
150
+
151
+ # Move models to device if not using ZeroGPU
152
+ if device.type == "cuda":
153
+ self.pipeline = self.pipeline.to(device)
154
+ self.pipeline.enable_model_cpu_offload()
155
+
156
+ clip_tokenizer = open_clip.get_tokenizer('ViT-H-14')
157
+ self.text_encoder = self.pipeline.text_encoder
158
+ self.tokenizer = self.pipeline.tokenizer
159
+
160
+ # Look for Shape2CLIP checkpoint in multiple locations
161
+ checkpoint_paths = [
162
+ "projection_model-0920192.pth",
163
+ "embeddings/projection_model-0920192.pth"
164
+ ]
165
+
166
+ checkpoint_found = False
167
+ checkpoint_path = None
168
+ for path in checkpoint_paths:
169
+ if os.path.exists(path):
170
+ checkpoint_path = path
171
+ print(f"Found Shape2CLIP checkpoint at: {checkpoint_path}")
172
+ checkpoint_found = True
173
+ break
174
+
175
+ # Download Shape2CLIP checkpoint if not found
176
+ if not checkpoint_found:
177
+ checkpoint_path = "projection_model-0920192.pth"
178
+ print("Downloading Shape2CLIP model checkpoint...")
179
+ gdown.download("1nvEXnwMpNkRts6rxVqMZt8i9FZ40KjP7", checkpoint_path, quiet=False)
180
+ print("Download complete")
181
+
182
+ # Initialize Shape2CLIP model
183
+ self.shape2clip_model = Shape2CLIP(depth=6, drop_path_rate=0.1, pb_dim=384)
184
+ self.shape2clip_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
185
+ if device.type == "cuda":
186
+ self.shape2clip_model = self.shape2clip_model.to(device)
187
+ self.shape2clip_model.eval()
188
+
189
+ # Scan for available embeddings
190
+ self.scan_available_embeddings()
191
+
192
+ def scan_available_embeddings(self):
193
+ self.available_categories = []
194
+ self.category_counts = {}
195
+
196
+ for category, cat_id in self.NAME2CAT.items():
197
+ possible_filenames = [
198
+ f"pointbert_shapenet_{cat_id}.npz",
199
+ f"{cat_id}_pb_embs.npz",
200
+ f"embeddings/pointbert_shapenet_{cat_id}.npz",
201
+ f"embeddings/{cat_id}_pb_embs.npz"
202
+ ]
203
+
204
+ found_file = None
205
+ for filename in possible_filenames:
206
+ if os.path.exists(filename):
207
+ found_file = filename
208
+ break
209
+
210
+ if found_file:
211
+ try:
212
+ pb_data = np.load(found_file)
213
+ if 'ids' in pb_data:
214
+ count = len(pb_data['ids'])
215
+ else:
216
+ # Try to infer the correct keys
217
+ keys = list(pb_data.keys())
218
+ if len(keys) >= 1:
219
+ count = len(pb_data[keys[0]])
220
+ else:
221
+ count = 0
222
+
223
+ if count > 0:
224
+ self.available_categories.append(category)
225
+ self.category_counts[category] = count
226
+ print(f"Found {count} embeddings for category '{category}'")
227
+ except Exception as e:
228
+ print(f"Error loading embeddings for {category}: {e}")
229
+
230
+ if not self.available_categories:
231
+ self.available_categories = ["chair"] # Fallback
232
+ self.category_counts["chair"] = 50 # Default value
233
+
234
+ # Sort categories alphabetically
235
+ self.available_categories.sort()
236
+
237
+ print(f"Found {len(self.available_categories)} categories with embeddings")
238
+ print(f"Available categories: {', '.join(self.available_categories)}")
239
+
240
+ def load_category_embeddings(self, category):
241
+ if category in self.category_embeddings:
242
+ return self.category_embeddings[category]
243
+
244
+ if category not in self.NAME2CAT:
245
+ return None, []
246
+
247
+ cat_id = self.NAME2CAT[category]
248
+
249
+ # Check for different possible embedding filenames and locations
250
+ possible_filenames = [
251
+ f"pointbert_shapenet_{cat_id}.npz",
252
+ f"{cat_id}_pb_embs.npz",
253
+ f"embeddings/pointbert_shapenet_{cat_id}.npz",
254
+ f"embeddings/{cat_id}_pb_embs.npz",
255
+ ]
256
+
257
+ # Find the first existing file
258
+ pb_emb_filename = None
259
+ for filename in possible_filenames:
260
+ if os.path.exists(filename):
261
+ pb_emb_filename = filename
262
+ print(f"Found embeddings file: {pb_emb_filename}")
263
+ break
264
+
265
+ if pb_emb_filename is None:
266
+ print(f"No embeddings found for {category}")
267
+ return None, []
268
+
269
+ # Load embeddings
270
+ try:
271
+ print(f"Loading embeddings from {pb_emb_filename}...")
272
+ pb_data = np.load(pb_emb_filename)
273
+
274
+ # Check for different key names in the NPZ file
275
+ if 'ids' in pb_data and 'embs' in pb_data:
276
+ pb_dict = dict(zip(pb_data['ids'], pb_data['embs']))
277
+ else:
278
+ # Try to infer the correct keys
279
+ keys = list(pb_data.keys())
280
+ if len(keys) >= 2:
281
+ # Assume first key is for IDs and second is for embeddings
282
+ pb_dict = dict(zip(pb_data[keys[0]], pb_data[keys[1]]))
283
+ else:
284
+ print("Unexpected embedding file format")
285
+ return None, []
286
+
287
+ all_ids = sorted(list(pb_dict.keys()))
288
+ print(f"Loaded {len(all_ids)} shape embeddings for {category}")
289
+
290
+ # Cache the results
291
+ self.category_embeddings[category] = (pb_dict, all_ids)
292
+ return pb_dict, all_ids
293
+ except Exception as e:
294
+ print(f"Error loading embeddings: {e}")
295
+ print(f"Exception details: {str(e)}")
296
+ return None, []
297
+
298
+ def get_shape_preview(self, category, shape_idx, size=(300, 300)):
299
+ """Get a preview image for a specific shape"""
300
+ if shape_idx is None or shape_idx < 0:
301
+ return None
302
+
303
+ pb_dict, all_ids = self.load_category_embeddings(category)
304
+ if pb_dict is None or not all_ids or shape_idx >= len(all_ids):
305
+ return None
306
+
307
+ shape_id = all_ids[shape_idx]
308
+
309
+ try:
310
+ # Get the shape image at the requested size
311
+ preview_image = self.get_ulip_image(shape_id)
312
+ preview_image = preview_image.resize(size)
313
+ preview_with_text = self.draw_text(preview_image, f"Shape #{shape_idx}", size=30, location=(10, 10))
314
+ return preview_with_text
315
+ except Exception as e:
316
+ print(f"Error loading preview for {shape_id}: {e}")
317
+ # Create an empty error image
318
+ empty_img = Image.new('RGB', size, color='gray')
319
+ error_text = f"Error loading Shape #{shape_idx}"
320
+ return self.draw_text(empty_img, error_text, size=30, location=(10, 10))
321
+
322
+ def on_slider_change(self, shape_idx, category):
323
+ """Update the preview when the slider changes"""
324
+ max_idx = self.category_counts.get(category, 0) - 1
325
+
326
+ # Get preview image
327
+ preview_image = self.get_shape_preview(category, shape_idx)
328
+
329
+ # Update counter text
330
+ counter_text = f"Shape {shape_idx} of {max_idx}"
331
+
332
+ return preview_image, counter_text, shape_idx
333
+
334
+ def prev_shape(self, current_idx, category):
335
+ """Go to previous shape"""
336
+ max_idx = self.category_counts.get(category, 0) - 1
337
+ new_idx = max(0, current_idx - 1)
338
+
339
+ # Get preview image
340
+ preview_image = self.get_shape_preview(category, new_idx)
341
+
342
+ # Update counter text
343
+ counter_text = f"Shape {new_idx} of {max_idx}"
344
+
345
+ return new_idx, preview_image, counter_text
346
+
347
+ def next_shape(self, current_idx, category):
348
+ """Go to next shape"""
349
+ max_idx = self.category_counts.get(category, 0) - 1
350
+ new_idx = min(max_idx, current_idx + 1)
351
+
352
+ # Get preview image
353
+ preview_image = self.get_shape_preview(category, new_idx)
354
+
355
+ # Update counter text
356
+ counter_text = f"Shape {new_idx} of {max_idx}"
357
+
358
+ return new_idx, preview_image, counter_text
359
+
360
+ def jump_to_start(self, category):
361
+ """Jump to the first shape"""
362
+ max_idx = self.category_counts.get(category, 0) - 1
363
+ new_idx = 0
364
+
365
+ # Get preview image
366
+ preview_image = self.get_shape_preview(category, new_idx)
367
+
368
+ # Update counter text
369
+ counter_text = f"Shape {new_idx} of {max_idx}"
370
+
371
+ return new_idx, preview_image, counter_text
372
+
373
+ def jump_to_end(self, category):
374
+ """Jump to the last shape"""
375
+ max_idx = self.category_counts.get(category, 0) - 1
376
+ new_idx = max_idx
377
+
378
+ # Get preview image
379
+ preview_image = self.get_shape_preview(category, new_idx)
380
+
381
+ # Update counter text
382
+ counter_text = f"Shape {new_idx} of {max_idx}"
383
+
384
+ return new_idx, preview_image, counter_text
385
+
386
+ def random_shape(self, category):
387
+ """Select a random shape from the category"""
388
+ max_idx = self.category_counts.get(category, 0) - 1
389
+ if max_idx <= 0:
390
+ return 0, self.get_shape_preview(category, 0), f"Shape 0 of 0"
391
+
392
+ # Generate random index
393
+ random_idx = random.randint(0, max_idx)
394
+
395
+ # Get preview image
396
+ preview_image = self.get_shape_preview(category, random_idx)
397
+
398
+ # Update counter text
399
+ counter_text = f"Shape {random_idx} of {max_idx}"
400
+
401
+ return random_idx, preview_image, counter_text
402
+
403
+ def on_category_change(self, category):
404
+ """Update the slider and preview when the category changes"""
405
+ # Reset to the first shape
406
+ current_idx = 0
407
+ max_idx = self.category_counts.get(category, 0) - 1
408
+
409
+ # Get preview image
410
+ preview_image = self.get_shape_preview(category, current_idx)
411
+
412
+ # Update counter text
413
+ counter_text = f"Shape {current_idx} of {max_idx}"
414
+
415
+ # Need to update the slider range
416
+ new_slider = gr.Slider(
417
+ minimum=0,
418
+ maximum=max_idx,
419
+ step=1,
420
+ value=current_idx,
421
+ label="Shape Index"
422
+ )
423
+
424
+ return new_slider, current_idx, preview_image, counter_text
425
+
426
+ def get_guidance(self, test_prompt, category_name, guidance_emb):
427
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
428
+ prompt_tokens = torch.LongTensor(self.tokenizer.encode(test_prompt, padding='max_length')).to(device)
429
+
430
+ with torch.no_grad():
431
+ out = self.text_encoder(prompt_tokens.unsqueeze(0), output_attentions=True)
432
+ prompt_emb = out.last_hidden_state.detach().clone()
433
+
434
+ if len(guidance_emb.shape) == 1:
435
+ guidance_emb = torch.FloatTensor(guidance_emb).unsqueeze(0).unsqueeze(0)
436
+ else:
437
+ guidance_emb = torch.FloatTensor(guidance_emb).unsqueeze(0)
438
+ guidance_emb = guidance_emb.to(device)
439
+
440
+ eos_inds = torch.where(prompt_tokens.unsqueeze(0) == 49407)[1]
441
+ obj_word = category_name
442
+ obj_word_token = self.tokenizer.encode(obj_word)[-2]
443
+ chair_inds = torch.where(prompt_tokens.unsqueeze(0) == obj_word_token)[1]
444
+
445
+ eos_strength = 0.8
446
+ obj_strength = 1.0
447
+
448
+ self.shape2clip_model.eval()
449
+ with torch.no_grad():
450
+ guided_prompt_emb_cond = self.shape2clip_model(prompt_emb.float(), guidance_emb[:,:,:].float()).half()
451
+ guided_prompt_emb = guided_prompt_emb_cond.clone()
452
+
453
+ guided_prompt_emb[:,:1] = 0
454
+ guided_prompt_emb[:,:chair_inds] = 0
455
+ guided_prompt_emb[:,chair_inds] *= obj_strength
456
+ guided_prompt_emb[:,eos_inds+1:] = 0
457
+ guided_prompt_emb[:,eos_inds] *= eos_strength
458
+ guided_prompt_emb[:,chair_inds+1:eos_inds:] = 0
459
+ fin_guidance = guided_prompt_emb
460
+
461
+ return fin_guidance, prompt_emb
462
+
463
+ # For ZeroGPU compatibility, uncomment this decorator when using ZeroGPU
464
+ @spaces.GPU(duration=120)
465
+ def generate_images(self, prompt, category, selected_shape_idx, guidance_strength, seed):
466
+ # Clear status text immediately
467
+ status = ""
468
+
469
+ # Check if the category is in the prompt
470
+ if category not in prompt:
471
+ # Add the category to the prompt
472
+ prompt = f"{prompt} {category}"
473
+ status = f"<div style='padding: 10px; background-color: #f0f7ff; border-left: 5px solid #3498db; margin-bottom: 10px;'>Note: Added '{category}' to your prompt since it was missing.</div>"
474
+
475
+ # Verify that the prompt doesn't contain other conflicting categories
476
+ for other_category in self.available_categories:
477
+ if other_category != category:
478
+ # Check with word boundaries to avoid partial matches
479
+ # e.g., "dishwasher" shouldn't match "washer"
480
+ if f" {other_category} " in f" {prompt} " or prompt == other_category:
481
+ return [], f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Your prompt contains '{other_category}' but you selected '{category}'. Please use matching category in prompt and selection.</div>"
482
+
483
+ # Load category embeddings if not already loaded
484
+ pb_dict, all_ids = self.load_category_embeddings(category)
485
+ if pb_dict is None or not all_ids:
486
+ return [], f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Failed to load embeddings for {category}</div>"
487
+
488
+ # Ensure shape index is valid
489
+ if selected_shape_idx is None or selected_shape_idx < 0:
490
+ selected_shape_idx = 0
491
+
492
+ max_idx = len(all_ids) - 1
493
+ selected_shape_idx = max(0, min(selected_shape_idx, max_idx))
494
+ guidance_shape_id = all_ids[selected_shape_idx]
495
+
496
+ # Set device and generator
497
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
498
+ generator = torch.Generator(device=device).manual_seed(seed)
499
+
500
+ results = []
501
+
502
+ # Add status message for generation
503
+ updating_status = f"<div style='padding: 10px; background-color: #e8f5e9; border-left: 5px solid #4caf50; margin-bottom: 10px;'>Generating images using Shape #{selected_shape_idx}...</div>"
504
+
505
+ try:
506
+ # For ZeroGPU, move models to GPU if not already there
507
+ if hasattr(spaces, 'GPU'):
508
+ self.pipeline = self.pipeline.to(device)
509
+ self.shape2clip_model = self.shape2clip_model.to(device)
510
+
511
+ # Generate base image (without guidance)
512
+ with torch.no_grad():
513
+ base_images = self.pipeline(
514
+ prompt=prompt,
515
+ num_inference_steps=50,
516
+ num_images_per_prompt=1,
517
+ generator=generator,
518
+ guidance_scale=7.5
519
+ ).images
520
+
521
+ base_image = base_images[0]
522
+ base_image = self.draw_text(base_image, "Unguided result")
523
+ results.append(base_image)
524
+ except Exception as e:
525
+ print(f"Error generating base image: {e}")
526
+ status = f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating base image: {str(e)}</div>"
527
+ return results, status
528
+
529
+ try:
530
+ # Get shape guidance image
531
+ ulip_image = self.get_ulip_image(guidance_shape_id)
532
+ ulip_image = self.draw_text(ulip_image, "Guidance shape")
533
+ results.append(ulip_image)
534
+ except Exception as e:
535
+ print(f"Error getting guidance shape: {e}")
536
+ status = f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error getting guidance shape: {str(e)}</div>"
537
+ return results, status
538
+
539
+ try:
540
+ # Get shape guidance embedding
541
+ pb_emb = pb_dict[guidance_shape_id]
542
+ out_guidance, prompt_emb = self.get_guidance(prompt, category, pb_emb)
543
+ except Exception as e:
544
+ print(f"Error generating guidance: {e}")
545
+ status = f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating guidance: {str(e)}</div>"
546
+ return results, status
547
+
548
+ try:
549
+ # Generate guided image
550
+ generator = torch.Generator(device=device).manual_seed(seed)
551
+ with torch.no_grad():
552
+ guided_images = self.pipeline(
553
+ prompt_embeds=prompt_emb + guidance_strength * out_guidance,
554
+ num_inference_steps=50,
555
+ num_images_per_prompt=1,
556
+ generator=generator,
557
+ guidance_scale=7.5
558
+ ).images
559
+
560
+ guided_image = guided_images[0]
561
+ guided_image = self.draw_text(guided_image, f"Guided result (λ={guidance_strength:.1f})")
562
+ results.append(guided_image)
563
+
564
+ # Success status
565
+ status = f"<div style='padding: 10px; background-color: #e8f5e9; border-left: 5px solid #4caf50; margin-bottom: 10px;'>✓ Successfully generated images using Shape #{selected_shape_idx} from category '{category}'.</div>"
566
+
567
+ # For ZeroGPU, optionally move models back to CPU to free resources
568
+ if hasattr(spaces, 'GPU'):
569
+ self.pipeline = self.pipeline.to('cpu')
570
+ self.shape2clip_model = self.shape2clip_model.to('cpu')
571
+ torch.cuda.empty_cache()
572
+
573
+ except Exception as e:
574
+ print(f"Error generating guided image: {e}")
575
+ status = f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating guided image: {str(e)}</div>"
576
+
577
+ return results, status
578
+
579
+ def update_prompt_for_category(self, old_prompt, new_category):
580
+ # Remove all existing categories from the prompt
581
+ cleaned_prompt = old_prompt
582
+ for cat in self.available_categories:
583
+ # Skip the current category
584
+ if cat == new_category:
585
+ continue
586
+
587
+ # Replace the category with a space, being careful about word boundaries
588
+ cleaned_prompt = cleaned_prompt.replace(f" {cat} ", " ")
589
+ cleaned_prompt = cleaned_prompt.replace(f" {cat}", "")
590
+ cleaned_prompt = cleaned_prompt.replace(f"{cat} ", "")
591
+ # Only do exact match for the whole prompt
592
+ if cleaned_prompt == cat:
593
+ cleaned_prompt = ""
594
+
595
+ # Add the new category if it's not already in the cleaned prompt
596
+ cleaned_prompt = cleaned_prompt.strip()
597
+ if new_category not in cleaned_prompt:
598
+ if cleaned_prompt:
599
+ return f"{cleaned_prompt} {new_category}"
600
+ else:
601
+ return new_category
602
+ else:
603
+ return cleaned_prompt
604
+
605
+ def on_demo_load(self):
606
+ """Function to ensure initial image is loaded when demo starts"""
607
+ default_category = "chair" if "chair" in self.available_categories else self.available_categories[0]
608
+ initial_img = self.get_shape_preview(default_category, 0)
609
+ return initial_img
610
+
611
+ def create_ui(self):
612
+ # Ensure chair is in available categories, otherwise use the first available
613
+ default_category = "chair" if "chair" in self.available_categories else self.available_categories[0]
614
+
615
+ with gr.Blocks(title="ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts") as demo:
616
+ gr.Markdown("""
617
+ # ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts
618
+
619
+ ShapeWords incorporates target 3D shape information with text prompts to guide image synthesis.
620
+
621
+ - **Website**: [ShapeWords Project Page](https://lodurality.github.io/shapewords/)
622
+ - **Paper**: [ArXiv](https://arxiv.org/abs/2412.02912)
623
+ - **Publication**: Accepted to CVPR 2025
624
+ """)
625
+
626
+ with gr.Row():
627
+ with gr.Column(scale=1):
628
+ prompt = gr.Textbox(
629
+ label="Prompt",
630
+ placeholder="an aquarelle drawing of a chair",
631
+ value=f"an aquarelle drawing of a {default_category}"
632
+ )
633
+
634
+ category = gr.Dropdown(
635
+ label="Object Category",
636
+ choices=self.available_categories,
637
+ value=default_category
638
+ )
639
+
640
+ # Hidden field to store selected shape index
641
+ selected_shape_idx = gr.Number(
642
+ value=0,
643
+ visible=False
644
+ )
645
+
646
+ # Create a slider for shape selection with preview
647
+ with gr.Row():
648
+ with gr.Column(scale=1):
649
+ # Slider for shape selection
650
+ shape_slider = gr.Slider(
651
+ minimum=0,
652
+ maximum=self.category_counts.get(default_category, 0) - 1,
653
+ step=1,
654
+ value=0,
655
+ label="Shape Index",
656
+ interactive=True
657
+ )
658
+
659
+ # Display shape index counter
660
+ shape_counter = gr.Markdown(f"Shape 0 of {self.category_counts.get(default_category, 0) - 1}")
661
+
662
+ # Quick navigation buttons
663
+ with gr.Row():
664
+ jump_start_btn = gr.Button("⏮️ First", size="sm")
665
+ random_btn = gr.Button("🎲 Random", size="sm", variant="secondary")
666
+ jump_end_btn = gr.Button("Last ⏭️", size="sm")
667
+
668
+ with gr.Row():
669
+ prev_shape_btn = gr.Button("◀️ Previous", size="sm")
670
+ next_shape_btn = gr.Button("Next ▶️", size="sm")
671
+
672
+ with gr.Column(scale=1):
673
+ # Preview image for the current shape
674
+ current_shape_image = gr.Image(
675
+ label="Selected Shape",
676
+ height=300,
677
+ width=300
678
+ )
679
+
680
+ guidance_strength = gr.Slider(
681
+ minimum=0.0, maximum=1.0, step=0.1, value=0.9,
682
+ label="Guidance Strength (λ)"
683
+ )
684
+
685
+ seed = gr.Slider(
686
+ minimum=0, maximum=10000, step=1, value=42,
687
+ label="Random Seed"
688
+ )
689
+
690
+ run_button = gr.Button("Generate Images", variant="primary")
691
+
692
+ info = gr.Markdown("""
693
+ **Note**: Higher guidance strength (λ) means stronger adherence to the 3D shape.
694
+ Start with λ=0.9 for a good balance between shape and prompt adherence.
695
+ """)
696
+
697
+ status_text = gr.HTML("")
698
+
699
+ with gr.Column(scale=2):
700
+ gallery = gr.Gallery(
701
+ label="Results",
702
+ show_label=True,
703
+ elem_id="results_gallery",
704
+ columns=3,
705
+ height="auto"
706
+ )
707
+
708
+ # Make sure the initial image is loaded when the demo starts
709
+ demo.load(
710
+ fn=self.on_demo_load,
711
+ inputs=None,
712
+ outputs=[current_shape_image]
713
+ )
714
+
715
+ # Connect slider to update preview
716
+ shape_slider.change(
717
+ fn=self.on_slider_change,
718
+ inputs=[shape_slider, category],
719
+ outputs=[current_shape_image, shape_counter, selected_shape_idx]
720
+ )
721
+
722
+ # Previous shape button
723
+ prev_shape_btn.click(
724
+ fn=self.prev_shape,
725
+ inputs=[selected_shape_idx, category],
726
+ outputs=[shape_slider, current_shape_image, shape_counter]
727
+ )
728
+
729
+ # Next shape button
730
+ next_shape_btn.click(
731
+ fn=self.next_shape,
732
+ inputs=[selected_shape_idx, category],
733
+ outputs=[shape_slider, current_shape_image, shape_counter]
734
+ )
735
+
736
+ # Jump to start button
737
+ jump_start_btn.click(
738
+ fn=self.jump_to_start,
739
+ inputs=[category],
740
+ outputs=[shape_slider, current_shape_image, shape_counter]
741
+ )
742
+
743
+ # Jump to end button
744
+ jump_end_btn.click(
745
+ fn=self.jump_to_end,
746
+ inputs=[category],
747
+ outputs=[shape_slider, current_shape_image, shape_counter]
748
+ )
749
+
750
+ # Random shape button
751
+ random_btn.click(
752
+ fn=self.random_shape,
753
+ inputs=[category],
754
+ outputs=[shape_slider, current_shape_image, shape_counter]
755
+ )
756
+
757
+ # Update the UI when category changes
758
+ category.change(
759
+ fn=self.on_category_change,
760
+ inputs=[category],
761
+ outputs=[shape_slider, selected_shape_idx, current_shape_image, shape_counter]
762
+ )
763
+
764
+ # Automatically update prompt when category changes
765
+ category.change(
766
+ fn=self.update_prompt_for_category,
767
+ inputs=[prompt, category],
768
+ outputs=[prompt]
769
+ )
770
+
771
+ # Clear status text before generating new images
772
+ run_button.click(
773
+ fn=lambda: None, # Empty function to clear the status
774
+ inputs=None,
775
+ outputs=[status_text]
776
+ )
777
+
778
+ # Generate images when button is clicked
779
+ run_button.click(
780
+ fn=self.generate_images,
781
+ inputs=[prompt, category, selected_shape_idx, guidance_strength, seed],
782
+ outputs=[gallery, status_text]
783
+ )
784
+
785
+ gr.Markdown("""
786
+ ## Credits
787
+
788
+ This demo is based on the ShapeWords paper by Petrov et al. (2024) accepted to CVPR 2025.
789
+
790
+ If you use this in your work, please cite:
791
+ ```
792
+ @misc{petrov2024shapewords,
793
+ title={ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts},
794
+ author={Dmitry Petrov and Pradyumn Goyal and Divyansh Shivashok and Yuanming Tao and Melinos Averkiou and Evangelos Kalogerakis},
795
+ year={2024},
796
+ eprint={2412.02912},
797
+ archivePrefix={arXiv},
798
+ primaryClass={cs.CV},
799
+ url={https://arxiv.org/abs/2412.02912},
800
+ }
801
+ ```
802
+ """)
803
+
804
+ return demo
805
+
806
 
807
+ # Main function and entry point
808
+ def main():
809
+ parser = argparse.ArgumentParser(description="ShapeWords Gradio Demo")
810
+ parser.add_argument('--share', action='store_true', help='Create a public link')
811
+ args = parser.parse_args()
812
+
813
+ # Create the demo app and UI
814
+ app = ShapeWordsDemo()
815
+ demo = app.create_ui()
816
+ demo.launch(share=args.share)
817
 
818
+ if __name__ == "__main__":
819
+ main()