dmpetrov commited on
Commit
6e699f5
·
1 Parent(s): 7c78d1f

added updated melinos code with point clouds and random prompts

Browse files
Files changed (2) hide show
  1. app.py +434 -259
  2. shapewords_paper_code +1 -1
app.py CHANGED
@@ -31,7 +31,7 @@ Usage:
31
 
32
  This demo allows users to:
33
  1. Select a 3D object category
34
- 2. Choose a specific 3D shape using a slider
35
  3. Enter a text prompt
36
  4. Generate images guided by the selected 3D shape
37
 
@@ -46,32 +46,34 @@ import gradio as gr
46
  from PIL import Image, ImageFont, ImageDraw
47
  from diffusers.utils import load_image
48
  from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
49
- #import open_clip
50
  import gdown
51
  import argparse
52
  import random
53
- import spaces
 
 
 
54
 
55
-
56
- os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
57
 
58
  class ShapeWordsDemo:
59
  # Constants
60
  NAME2CAT = {
61
- "chair": "03001627", "table": "04379243", "jar": "03593526", "skateboard": "04225987",
62
- "car": "02958343", "bottle": "02876657", "tower": "04460130", "bookshelf": "02871439",
63
- "camera": "02942699", "airplane": "02691156", "laptop": "03642806", "basket": "02801938",
64
- "sofa": "04256520", "knife": "03624134", "can": "02946921", "rifle": "04090263",
65
- "train": "04468005", "pillow": "03938244", "lamp": "03636649", "trash bin": "02747177",
66
- "mailbox": "03710193", "watercraft": "04530566", "motorbike": "03790512",
67
- "dishwasher": "03207941", "bench": "02828884", "pistol": "03948459", "rocket": "04099429",
68
- "loudspeaker": "03691459", "file cabinet": "03337140", "bag": "02773838",
69
- "cabinet": "02933112", "bed": "02818832", "birdhouse": "02843684", "display": "03211117",
70
- "piano": "03928116", "earphone": "03261776", "telephone": "04401088", "stove": "04330267",
71
- "microphone": "03759954", "bus": "02924116", "mug": "03797390", "remote": "04074963",
72
- "bathtub": "02808440", "bowl": "02880940", "keyboard": "03085013", "guitar": "03467517",
73
- "washer": "04554684", "bicycle": "02834778", "faucet": "03325088", "printer": "04004475",
74
- "cap": "02954340", "phone": "02992529", "clock": "03046257", "helmet": "03513137",
75
  "microwave": "03761084", "plant": "03991062"
76
  }
77
 
@@ -86,30 +88,30 @@ class ShapeWordsDemo:
86
  self.available_categories = []
87
  self.shape_thumbnail_cache = {} # Cache for shape thumbnails
88
  self.CAT2NAME = {v: k for k, v in self.NAME2CAT.items()}
89
-
90
  # Initialize all models and data
91
  self.initialize_models()
92
 
93
  def draw_text(self, img, text, color=(10, 10, 10), size=80, location=(200, 30)):
94
  img = img.copy()
95
  draw = ImageDraw.Draw(img)
96
-
97
  try:
98
  font = ImageFont.truetype("Arial", size=size)
99
  except IOError:
100
  font = ImageFont.load_default()
101
-
102
  bbox = draw.textbbox(location, text, font=font)
103
  draw.rectangle(bbox, fill="white")
104
  draw.text(location, text, color, font=font)
105
-
106
  return img
107
 
108
  def get_ulip_image(self, guidance_shape_id, angle='036'):
109
  shape_id_ulip = guidance_shape_id.replace('_', '-')
110
  ulip_template = 'https://storage.googleapis.com/sfr-ulip-code-release-research/shapenet-55/only_rgb_depth_images/{}_r_{}_depth0001.png'
111
  ulip_path = ulip_template.format(shape_id_ulip, angle)
112
-
113
  try:
114
  ulip_image = load_image(ulip_path).resize((512, 512))
115
  return ulip_image
@@ -117,56 +119,40 @@ class ShapeWordsDemo:
117
  print(f"Error loading image: {e}")
118
  return Image.new('RGB', (512, 512), color='gray')
119
 
120
- def get_ulip_thumbnail(self, guidance_shape_id, angle='036', size=(150, 150)):
121
- """Get a thumbnail version of the ULIP image for use in the gallery"""
122
- image = self.get_ulip_image(guidance_shape_id, angle)
123
- return image.resize(size)
124
-
125
  def initialize_models(self):
126
- device = torch.device("cpu" if torch.cuda.is_available() else "cpu")
127
- print(f"Using device: {device}")
128
-
 
129
  # Download Shape2CLIP code if it doesn't exist
130
  if not os.path.exists("shapewords_paper_code"):
131
  print("Loading models file")
132
  os.system("git clone https://github.com/lodurality/shapewords_paper_code.git")
133
-
134
  # Import Shape2CLIP model
135
  sys.path.append("./shapewords_paper_code")
136
  from shapewords_paper_code.geometry_guidance_models import Shape2CLIP
137
-
138
  # Initialize the pipeline
139
  self.pipeline = StableDiffusionPipeline.from_pretrained(
140
- "stabilityai/stable-diffusion-2-1-base",
141
  torch_dtype=torch.float16 if device.type == "cuda" else torch.float32
142
  )
143
-
144
  self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
145
- self.pipeline.scheduler.config,
146
  algorithm_type="sde-dpmsolver++"
147
  )
148
-
149
- # Load CLIP model
150
- #clip_model, _, preprocess = open_clip.create_model_and_transforms(
151
- # 'ViT-H-14',
152
- # pretrained='laion2b_s32b_b79k'
153
- #)
154
-
155
- # Move models to device if not using ZeroGPU
156
- if device.type == "cuda":
157
- self.pipeline = self.pipeline.to(device)
158
- #self.pipeline.enable_model_cpu_offload()
159
-
160
- #clip_tokenizer = open_clip.get_tokenizer('ViT-H-14')
161
  self.text_encoder = self.pipeline.text_encoder
162
  self.tokenizer = self.pipeline.tokenizer
163
-
164
  # Look for Shape2CLIP checkpoint in multiple locations
165
  checkpoint_paths = [
166
- "/data/projection_model-0920192.pth",
167
- "/data/embeddings/projection_model-0920192.pth"
168
  ]
169
-
170
  checkpoint_found = False
171
  checkpoint_path = None
172
  for path in checkpoint_paths:
@@ -175,43 +161,40 @@ class ShapeWordsDemo:
175
  print(f"Found Shape2CLIP checkpoint at: {checkpoint_path}")
176
  checkpoint_found = True
177
  break
178
-
179
  # Download Shape2CLIP checkpoint if not found
180
  if not checkpoint_found:
181
  checkpoint_path = "projection_model-0920192.pth"
182
  print("Downloading Shape2CLIP model checkpoint...")
183
- gdown.download("1nvEXnwMpNkRts6rxVqMZt8i9FZ40KjP7", checkpoint_path, quiet=False)
184
  print("Download complete")
185
-
186
  # Initialize Shape2CLIP model
187
  self.shape2clip_model = Shape2CLIP(depth=6, drop_path_rate=0.1, pb_dim=384)
188
  self.shape2clip_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
189
- if device.type == "cuda":
190
- self.shape2clip_model = self.shape2clip_model.to(device)
191
  self.shape2clip_model.eval()
192
-
193
  # Scan for available embeddings
194
  self.scan_available_embeddings()
195
 
196
  def scan_available_embeddings(self):
197
  self.available_categories = []
198
  self.category_counts = {}
199
-
 
200
  for category, cat_id in self.NAME2CAT.items():
201
  possible_filenames = [
202
- f"pointbert_shapenet_{cat_id}.npz",
203
  f"{cat_id}_pb_embs.npz",
204
- f"embeddings/pointbert_shapenet_{cat_id}.npz",
205
- f"embeddings/{cat_id}_pb_embs.npz",
206
- f"/data/shapenet_pointbert_tokens/{cat_id}_pb_embs.npz"
207
  ]
208
-
209
  found_file = None
210
  for filename in possible_filenames:
211
  if os.path.exists(filename):
212
  found_file = filename
213
  break
214
-
215
  if found_file:
216
  try:
217
  pb_data = np.load(found_file)
@@ -224,42 +207,41 @@ class ShapeWordsDemo:
224
  count = len(pb_data[keys[0]])
225
  else:
226
  count = 0
227
-
228
  if count > 0:
229
  self.available_categories.append(category)
230
  self.category_counts[category] = count
231
  print(f"Found {count} embeddings for category '{category}'")
232
  except Exception as e:
233
  print(f"Error loading embeddings for {category}: {e}")
234
-
235
- if not self.available_categories:
236
- self.available_categories = ["chair"] # Fallback
237
- self.category_counts["chair"] = 50 # Default value
238
-
239
  # Sort categories alphabetically
240
  self.available_categories.sort()
241
-
242
  print(f"Found {len(self.available_categories)} categories with embeddings")
243
  print(f"Available categories: {', '.join(self.available_categories)}")
 
 
 
 
 
244
 
245
  def load_category_embeddings(self, category):
246
  if category in self.category_embeddings:
247
  return self.category_embeddings[category]
248
-
249
  if category not in self.NAME2CAT:
250
  return None, []
251
-
252
  cat_id = self.NAME2CAT[category]
253
-
254
  # Check for different possible embedding filenames and locations
255
  possible_filenames = [
256
- f"pointbert_shapenet_{cat_id}.npz",
257
- f"{cat_id}_pb_embs.npz",
258
- f"embeddings/pointbert_shapenet_{cat_id}.npz",
259
  f"embeddings/{cat_id}_pb_embs.npz",
260
- f"/data/shapenet_pointbert_tokens/{cat_id}_pb_embs.npz"
261
  ]
262
-
263
  # Find the first existing file
264
  pb_emb_filename = None
265
  for filename in possible_filenames:
@@ -267,16 +249,16 @@ class ShapeWordsDemo:
267
  pb_emb_filename = filename
268
  print(f"Found embeddings file: {pb_emb_filename}")
269
  break
270
-
271
  if pb_emb_filename is None:
272
  print(f"No embeddings found for {category}")
273
  return None, []
274
-
275
  # Load embeddings
276
  try:
277
  print(f"Loading embeddings from {pb_emb_filename}...")
278
  pb_data = np.load(pb_emb_filename)
279
-
280
  # Check for different key names in the NPZ file
281
  if 'ids' in pb_data and 'embs' in pb_data:
282
  pb_dict = dict(zip(pb_data['ids'], pb_data['embs']))
@@ -289,10 +271,10 @@ class ShapeWordsDemo:
289
  else:
290
  print("Unexpected embedding file format")
291
  return None, []
292
-
293
  all_ids = sorted(list(pb_dict.keys()))
294
  print(f"Loaded {len(all_ids)} shape embeddings for {category}")
295
-
296
  # Cache the results
297
  self.category_embeddings[category] = (pb_dict, all_ids)
298
  return pb_dict, all_ids
@@ -301,90 +283,280 @@ class ShapeWordsDemo:
301
  print(f"Exception details: {str(e)}")
302
  return None, []
303
 
304
- def get_shape_preview(self, category, shape_idx, size=(300, 300)):
305
- """Get a preview image for a specific shape"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  if shape_idx is None or shape_idx < 0:
307
  return None
308
-
 
309
  pb_dict, all_ids = self.load_category_embeddings(category)
310
  if pb_dict is None or not all_ids or shape_idx >= len(all_ids):
311
  return None
 
312
  shape_id = all_ids[shape_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  try:
314
- # Get the shape image at the requested size
315
  preview_image = self.get_ulip_image(shape_id)
316
- preview_image = preview_image.resize(size)
317
- preview_with_text = self.draw_text(preview_image, f"Shape #{shape_idx}", size=30, location=(10, 10))
318
- return preview_with_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  except Exception as e:
320
  print(f"Error loading preview for {shape_id}: {e}")
321
- # Create an empty error image
322
- empty_img = Image.new('RGB', size, color='gray')
323
- error_text = f"Error loading Shape #{shape_idx}"
324
- return self.draw_text(empty_img, error_text, size=30, location=(10, 10))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
 
326
  def on_slider_change(self, shape_idx, category):
327
  """Update the preview when the slider changes"""
328
  max_idx = self.category_counts.get(category, 0) - 1
329
-
330
  # Get preview image
331
  preview_image = self.get_shape_preview(category, shape_idx)
332
- #
333
  # Update counter text
334
  counter_text = f"Shape {shape_idx} of {max_idx}"
335
-
336
  return preview_image, counter_text, shape_idx
337
 
338
  def prev_shape(self, current_idx, category):
339
  """Go to previous shape"""
340
  max_idx = self.category_counts.get(category, 0) - 1
341
  new_idx = max(0, current_idx - 1)
342
-
343
  # Get preview image
344
  preview_image = self.get_shape_preview(category, new_idx)
345
-
346
  # Update counter text
347
  counter_text = f"Shape {new_idx} of {max_idx}"
348
-
349
  return new_idx, preview_image, counter_text
350
 
351
  def next_shape(self, current_idx, category):
352
  """Go to next shape"""
353
  max_idx = self.category_counts.get(category, 0) - 1
354
  new_idx = min(max_idx, current_idx + 1)
355
-
356
  # Get preview image
357
  preview_image = self.get_shape_preview(category, new_idx)
358
-
359
  # Update counter text
360
  counter_text = f"Shape {new_idx} of {max_idx}"
361
-
362
  return new_idx, preview_image, counter_text
363
 
364
  def jump_to_start(self, category):
365
  """Jump to the first shape"""
366
  max_idx = self.category_counts.get(category, 0) - 1
367
  new_idx = 0
368
-
369
  # Get preview image
370
  preview_image = self.get_shape_preview(category, new_idx)
371
-
372
  # Update counter text
373
  counter_text = f"Shape {new_idx} of {max_idx}"
374
-
375
  return new_idx, preview_image, counter_text
376
 
377
  def jump_to_end(self, category):
378
  """Jump to the last shape"""
379
  max_idx = self.category_counts.get(category, 0) - 1
380
  new_idx = max_idx
381
-
382
  # Get preview image
383
  preview_image = self.get_shape_preview(category, new_idx)
384
-
385
  # Update counter text
386
  counter_text = f"Shape {new_idx} of {max_idx}"
387
-
388
  return new_idx, preview_image, counter_text
389
 
390
  def random_shape(self, category):
@@ -392,30 +564,49 @@ class ShapeWordsDemo:
392
  max_idx = self.category_counts.get(category, 0) - 1
393
  if max_idx <= 0:
394
  return 0, self.get_shape_preview(category, 0), f"Shape 0 of 0"
395
-
396
  # Generate random index
397
  random_idx = random.randint(0, max_idx)
398
-
399
  # Get preview image
400
  preview_image = self.get_shape_preview(category, random_idx)
401
-
402
  # Update counter text
403
  counter_text = f"Shape {random_idx} of {max_idx}"
404
-
405
  return random_idx, preview_image, counter_text
406
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
  def on_category_change(self, category):
408
  """Update the slider and preview when the category changes"""
409
  # Reset to the first shape
410
  current_idx = 0
411
  max_idx = self.category_counts.get(category, 0) - 1
412
-
413
  # Get preview image
414
  preview_image = self.get_shape_preview(category, current_idx)
415
-
416
  # Update counter text
417
  counter_text = f"Shape {current_idx} of {max_idx}"
418
-
419
  # Need to update the slider range
420
  new_slider = gr.Slider(
421
  minimum=0,
@@ -424,19 +615,20 @@ class ShapeWordsDemo:
424
  value=current_idx,
425
  label="Shape Index"
426
  )
427
-
428
  return new_slider, current_idx, preview_image, counter_text
429
 
430
  def get_guidance(self, test_prompt, category_name, guidance_emb):
431
- print("Getting guidance")
432
  print(test_prompt, category_name)
433
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
434
  prompt_tokens = torch.LongTensor(self.tokenizer.encode(test_prompt, padding='max_length')).to(device)
 
435
  with torch.no_grad():
436
  out = self.text_encoder(prompt_tokens.unsqueeze(0), output_attentions=True)
437
  prompt_emb = out.last_hidden_state.detach().clone()
438
-
439
-
440
  if len(guidance_emb.shape) == 1:
441
  guidance_emb = torch.FloatTensor(guidance_emb).unsqueeze(0).unsqueeze(0)
442
  else:
@@ -455,7 +647,7 @@ class ShapeWordsDemo:
455
  with torch.no_grad():
456
  guided_prompt_emb_cond = self.shape2clip_model(prompt_emb.float(), guidance_emb[:,:,:].float()).half()
457
  guided_prompt_emb = guided_prompt_emb_cond.clone()
458
-
459
  guided_prompt_emb[:,:1] = 0
460
  guided_prompt_emb[:,:chair_inds] = 0
461
  guided_prompt_emb[:,chair_inds] *= obj_strength
@@ -466,72 +658,76 @@ class ShapeWordsDemo:
466
 
467
  return fin_guidance, prompt_emb
468
 
469
- # For ZeroGPU compatibility, uncomment this decorator when using ZeroGPU
470
  @spaces.GPU(duration=120)
471
  def generate_images(self, prompt, category, selected_shape_idx, guidance_strength, seed):
 
 
 
 
 
 
 
 
472
  # Clear status text immediately
473
  status = ""
474
-
475
- # Check if the category is in the prompt
476
- if category not in prompt:
477
- # Add the category to the prompt
478
- prompt = f"{prompt} {category}"
479
- status = f"<div style='padding: 10px; background-color: #f0f7ff; border-left: 5px solid #3498db; margin-bottom: 10px;'>Note: Added '{category}' to your prompt since it was missing.</div>"
480
-
481
- # Verify that the prompt doesn't contain other conflicting categories
 
 
 
 
 
482
  for other_category in self.available_categories:
483
- if other_category != category:
484
- # Check with word boundaries to avoid partial matches
485
- # e.g., "dishwasher" shouldn't match "washer"
486
- if f" {other_category} " in f" {prompt} " or prompt == other_category:
487
- return [], f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Your prompt contains '{other_category}' but you selected '{category}'. Please use matching category in prompt and selection.</div>"
488
-
489
  # Load category embeddings if not already loaded
490
  pb_dict, all_ids = self.load_category_embeddings(category)
491
  if pb_dict is None or not all_ids:
492
- return [], f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Failed to load embeddings for {category}</div>"
493
-
 
494
  # Ensure shape index is valid
495
  if selected_shape_idx is None or selected_shape_idx < 0:
496
  selected_shape_idx = 0
497
-
498
  max_idx = len(all_ids) - 1
499
  selected_shape_idx = max(0, min(selected_shape_idx, max_idx))
500
  guidance_shape_id = all_ids[selected_shape_idx]
501
-
502
- # Set device and generator
503
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
504
  generator = torch.Generator(device=device).manual_seed(seed)
505
-
506
  results = []
507
-
508
- # Add status message for generation
509
- updating_status = f"<div style='padding: 10px; background-color: #e8f5e9; border-left: 5px solid #4caf50; margin-bottom: 10px;'>Generating images using Shape #{selected_shape_idx}...</div>"
510
-
511
  try:
512
- # For ZeroGPU, move models to GPU if not already there
513
- if hasattr(spaces, 'GPU'):
514
- self.pipeline = self.pipeline.to(device)
515
- self.shape2clip_model = self.shape2clip_model.to(device)
516
-
517
  # Generate base image (without guidance)
518
  with torch.no_grad():
519
  base_images = self.pipeline(
520
- prompt=prompt,
521
  num_inference_steps=50,
522
  num_images_per_prompt=1,
523
  generator=generator,
524
  guidance_scale=7.5
525
  ).images
526
-
527
  base_image = base_images[0]
528
  base_image = self.draw_text(base_image, "Unguided result")
529
  results.append(base_image)
530
  except Exception as e:
531
  print(f"Error generating base image: {e}")
532
- status = f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating base image: {str(e)}</div>"
533
  return results, status
534
- print('Base image done')
535
  try:
536
  # Get shape guidance image
537
  ulip_image = self.get_ulip_image(guidance_shape_id)
@@ -539,19 +735,18 @@ class ShapeWordsDemo:
539
  results.append(ulip_image)
540
  except Exception as e:
541
  print(f"Error getting guidance shape: {e}")
542
- status = f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error getting guidance shape: {str(e)}</div>"
543
  return results, status
544
- print("ULIP image done")
545
  try:
546
  # Get shape guidance embedding
547
  pb_emb = pb_dict[guidance_shape_id]
548
- print('Got pb emb')
549
- out_guidance, prompt_emb = self.get_guidance(prompt, category, pb_emb)
550
  except Exception as e:
551
  print(f"Error generating guidance: {e}")
552
- status = f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating guidance: {str(e)}</div>"
553
  return results, status
554
- print("Guidance done")
555
  try:
556
  # Generate guided image
557
  generator = torch.Generator(device=device).manual_seed(seed)
@@ -563,51 +758,21 @@ class ShapeWordsDemo:
563
  generator=generator,
564
  guidance_scale=7.5
565
  ).images
566
-
567
  guided_image = guided_images[0]
568
  guided_image = self.draw_text(guided_image, f"Guided result (λ={guidance_strength:.1f})")
569
  results.append(guided_image)
570
-
571
  # Success status
572
- status = f"<div style='padding: 10px; background-color: #e8f5e9; border-left: 5px solid #4caf50; margin-bottom: 10px;'>✓ Successfully generated images using Shape #{selected_shape_idx} from category '{category}'.</div>"
573
-
574
- # For ZeroGPU, optionally move models back to CPU to free resources
575
- if hasattr(spaces, 'GPU'):
576
- self.pipeline = self.pipeline.to('cpu')
577
- self.shape2clip_model = self.shape2clip_model.to('cpu')
578
- torch.cuda.empty_cache()
579
-
580
  except Exception as e:
581
  print(f"Error generating guided image: {e}")
582
- status = f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating guided image: {str(e)}</div>"
583
-
584
- return results, status
585
 
586
- def update_prompt_for_category(self, old_prompt, new_category):
587
- # Remove all existing categories from the prompt
588
- cleaned_prompt = old_prompt
589
- for cat in self.available_categories:
590
- # Skip the current category
591
- if cat == new_category:
592
- continue
593
-
594
- # Replace the category with a space, being careful about word boundaries
595
- cleaned_prompt = cleaned_prompt.replace(f" {cat} ", " ")
596
- cleaned_prompt = cleaned_prompt.replace(f" {cat}", "")
597
- cleaned_prompt = cleaned_prompt.replace(f"{cat} ", "")
598
- # Only do exact match for the whole prompt
599
- if cleaned_prompt == cat:
600
- cleaned_prompt = ""
601
-
602
- # Add the new category if it's not already in the cleaned prompt
603
- cleaned_prompt = cleaned_prompt.strip()
604
- if new_category not in cleaned_prompt:
605
- if cleaned_prompt:
606
- return f"{cleaned_prompt} {new_category}"
607
- else:
608
- return new_category
609
- else:
610
- return cleaned_prompt
611
 
612
  def on_demo_load(self):
613
  """Function to ensure initial image is loaded when demo starts"""
@@ -618,7 +783,7 @@ class ShapeWordsDemo:
618
  def create_ui(self):
619
  # Ensure chair is in available categories, otherwise use the first available
620
  default_category = "chair" if "chair" in self.available_categories else self.available_categories[0]
621
-
622
  with gr.Blocks(title="ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts") as demo:
623
  gr.Markdown("""
624
  # ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts
@@ -629,27 +794,35 @@ class ShapeWordsDemo:
629
  - **Paper**: [ArXiv](https://arxiv.org/abs/2412.02912)
630
  - **Publication**: Accepted to CVPR 2025
631
  """)
632
-
633
  with gr.Row():
634
  with gr.Column(scale=1):
635
  prompt = gr.Textbox(
636
- label="Prompt",
637
- placeholder="an aquarelle drawing of a chair",
638
- value=f"an aquarelle drawing of a {default_category}"
639
  )
640
-
 
 
 
 
 
 
 
 
641
  category = gr.Dropdown(
642
- label="Object Category",
643
  choices=self.available_categories,
644
  value=default_category
645
  )
646
-
647
  # Hidden field to store selected shape index
648
  selected_shape_idx = gr.Number(
649
  value=0,
650
  visible=False
651
  )
652
-
653
  # Create a slider for shape selection with preview
654
  with gr.Row():
655
  with gr.Column(scale=1):
@@ -662,47 +835,48 @@ class ShapeWordsDemo:
662
  label="Shape Index",
663
  interactive=True
664
  )
665
-
666
  # Display shape index counter
667
  shape_counter = gr.Markdown(f"Shape 0 of {self.category_counts.get(default_category, 0) - 1}")
668
-
669
  # Quick navigation buttons
670
  with gr.Row():
671
  jump_start_btn = gr.Button("⏮️ First", size="sm")
672
- random_btn = gr.Button("🎲 Random", size="sm", variant="secondary")
673
  jump_end_btn = gr.Button("Last ⏭️", size="sm")
674
-
675
  with gr.Row():
676
  prev_shape_btn = gr.Button("◀️ Previous", size="sm")
677
  next_shape_btn = gr.Button("Next ▶️", size="sm")
678
-
679
  with gr.Column(scale=1):
680
- # Preview image for the current shape
681
- current_shape_image = gr.Image(
682
- label="Selected Shape",
683
- height=300,
684
- width=300
 
685
  )
686
-
687
  guidance_strength = gr.Slider(
688
  minimum=0.0, maximum=1.0, step=0.1, value=0.9,
689
  label="Guidance Strength (λ)"
690
  )
691
-
692
  seed = gr.Slider(
693
  minimum=0, maximum=10000, step=1, value=42,
694
  label="Random Seed"
695
  )
696
-
697
  run_button = gr.Button("Generate Images", variant="primary")
698
-
699
  info = gr.Markdown("""
700
  **Note**: Higher guidance strength (λ) means stronger adherence to the 3D shape.
701
  Start with λ=0.9 for a good balance between shape and prompt adherence.
702
  """)
703
-
704
  status_text = gr.HTML("")
705
-
706
  with gr.Column(scale=2):
707
  gallery = gr.Gallery(
708
  label="Results",
@@ -711,84 +885,84 @@ class ShapeWordsDemo:
711
  columns=3,
712
  height="auto"
713
  )
714
-
715
  # Make sure the initial image is loaded when the demo starts
716
  demo.load(
717
  fn=self.on_demo_load,
718
  inputs=None,
719
- outputs=[current_shape_image]
720
  )
721
-
722
  # Connect slider to update preview
723
  shape_slider.change(
724
  fn=self.on_slider_change,
725
  inputs=[shape_slider, category],
726
- outputs=[current_shape_image, shape_counter, selected_shape_idx]
727
  )
728
-
729
  # Previous shape button
730
  prev_shape_btn.click(
731
  fn=self.prev_shape,
732
  inputs=[selected_shape_idx, category],
733
- outputs=[shape_slider, current_shape_image, shape_counter]
734
  )
735
-
736
  # Next shape button
737
  next_shape_btn.click(
738
  fn=self.next_shape,
739
  inputs=[selected_shape_idx, category],
740
- outputs=[shape_slider, current_shape_image, shape_counter]
741
  )
742
-
743
  # Jump to start button
744
  jump_start_btn.click(
745
  fn=self.jump_to_start,
746
  inputs=[category],
747
- outputs=[shape_slider, current_shape_image, shape_counter]
748
  )
749
-
750
  # Jump to end button
751
  jump_end_btn.click(
752
  fn=self.jump_to_end,
753
  inputs=[category],
754
- outputs=[shape_slider, current_shape_image, shape_counter]
755
  )
756
-
757
  # Random shape button
758
  random_btn.click(
759
  fn=self.random_shape,
760
  inputs=[category],
761
- outputs=[shape_slider, current_shape_image, shape_counter]
762
  )
763
-
 
 
 
 
 
 
 
764
  # Update the UI when category changes
765
  category.change(
766
  fn=self.on_category_change,
767
  inputs=[category],
768
- outputs=[shape_slider, selected_shape_idx, current_shape_image, shape_counter]
769
- )
770
-
771
- # Automatically update prompt when category changes
772
- category.change(
773
- fn=self.update_prompt_for_category,
774
- inputs=[prompt, category],
775
- outputs=[prompt]
776
  )
777
-
778
  # Clear status text before generating new images
779
  run_button.click(
780
  fn=lambda: None, # Empty function to clear the status
781
  inputs=None,
782
  outputs=[status_text]
783
  )
784
-
785
  # Generate images when button is clicked
786
  run_button.click(
787
  fn=self.generate_images,
788
  inputs=[prompt, category, selected_shape_idx, guidance_strength, seed],
789
  outputs=[gallery, status_text]
790
  )
791
-
792
  gr.Markdown("""
793
  ## Credits
794
 
@@ -807,7 +981,7 @@ class ShapeWordsDemo:
807
  }
808
  ```
809
  """)
810
-
811
  return demo
812
 
813
 
@@ -816,11 +990,12 @@ def main():
816
  parser = argparse.ArgumentParser(description="ShapeWords Gradio Demo")
817
  parser.add_argument('--share', action='store_true', help='Create a public link')
818
  args = parser.parse_args()
819
-
820
  # Create the demo app and UI
821
  app = ShapeWordsDemo()
822
  demo = app.create_ui()
823
  demo.launch(share=args.share)
824
 
 
825
  if __name__ == "__main__":
826
- main()
 
31
 
32
  This demo allows users to:
33
  1. Select a 3D object category
34
+ 2. Choose a specific 3D shape
35
  3. Enter a text prompt
36
  4. Generate images guided by the selected 3D shape
37
 
 
46
  from PIL import Image, ImageFont, ImageDraw
47
  from diffusers.utils import load_image
48
  from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
 
49
  import gdown
50
  import argparse
51
  import random
52
+ import spaces # for Hugging Face ZeroGPU deployment
53
+ import re
54
+ import plotly.graph_objects as go
55
+ from numpy.lib.user_array import container
56
 
57
+ # Only for Hugging Face hosting - Add the Hugging Face cache to persistent storage to avoid downloading safetensors every time the demo sleeps and wakes up
58
+ os.environ['HF_HOME'] = '/data/.huggingface'
59
 
60
  class ShapeWordsDemo:
61
  # Constants
62
  NAME2CAT = {
63
+ "chair": "03001627", "table": "04379243", "jar": "03593526", "skateboard": "04225987",
64
+ "car": "02958343", "bottle": "02876657", "tower": "04460130", "bookshelf": "02871439",
65
+ "camera": "02942699", "airplane": "02691156", "laptop": "03642806", "basket": "02801938",
66
+ "sofa": "04256520", "knife": "03624134", "can": "02946921", "rifle": "04090263",
67
+ "train": "04468005", "pillow": "03938244", "lamp": "03636649", "trash bin": "02747177",
68
+ "mailbox": "03710193", "watercraft": "04530566", "motorbike": "03790512",
69
+ "dishwasher": "03207941", "bench": "02828884", "pistol": "03948459", "rocket": "04099429",
70
+ "loudspeaker": "03691459", "file cabinet": "03337140", "bag": "02773838",
71
+ "cabinet": "02933112", "bed": "02818832", "birdhouse": "02843684", "display": "03211117",
72
+ "piano": "03928116", "earphone": "03261776", "telephone": "04401088", "stove": "04330267",
73
+ "microphone": "03759954", "bus": "02924116", "mug": "03797390", "remote": "04074963",
74
+ "bathtub": "02808440", "bowl": "02880940", "keyboard": "03085013", "guitar": "03467517",
75
+ "washer": "04554684", "bicycle": "02834778", "faucet": "03325088", "printer": "04004475",
76
+ "cap": "02954340", "phone": "02992529", "clock": "03046257", "helmet": "03513137",
77
  "microwave": "03761084", "plant": "03991062"
78
  }
79
 
 
88
  self.available_categories = []
89
  self.shape_thumbnail_cache = {} # Cache for shape thumbnails
90
  self.CAT2NAME = {v: k for k, v in self.NAME2CAT.items()}
91
+ self.category_point_clouds = {}
92
  # Initialize all models and data
93
  self.initialize_models()
94
 
95
  def draw_text(self, img, text, color=(10, 10, 10), size=80, location=(200, 30)):
96
  img = img.copy()
97
  draw = ImageDraw.Draw(img)
98
+
99
  try:
100
  font = ImageFont.truetype("Arial", size=size)
101
  except IOError:
102
  font = ImageFont.load_default()
103
+
104
  bbox = draw.textbbox(location, text, font=font)
105
  draw.rectangle(bbox, fill="white")
106
  draw.text(location, text, color, font=font)
107
+
108
  return img
109
 
110
  def get_ulip_image(self, guidance_shape_id, angle='036'):
111
  shape_id_ulip = guidance_shape_id.replace('_', '-')
112
  ulip_template = 'https://storage.googleapis.com/sfr-ulip-code-release-research/shapenet-55/only_rgb_depth_images/{}_r_{}_depth0001.png'
113
  ulip_path = ulip_template.format(shape_id_ulip, angle)
114
+
115
  try:
116
  ulip_image = load_image(ulip_path).resize((512, 512))
117
  return ulip_image
 
119
  print(f"Error loading image: {e}")
120
  return Image.new('RGB', (512, 512), color='gray')
121
 
 
 
 
 
 
122
  def initialize_models(self):
123
+ # device = DEVICE
124
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
125
+ print(f"Using device: {device} in initialize_models")
126
+
127
  # Download Shape2CLIP code if it doesn't exist
128
  if not os.path.exists("shapewords_paper_code"):
129
  print("Loading models file")
130
  os.system("git clone https://github.com/lodurality/shapewords_paper_code.git")
131
+
132
  # Import Shape2CLIP model
133
  sys.path.append("./shapewords_paper_code")
134
  from shapewords_paper_code.geometry_guidance_models import Shape2CLIP
135
+
136
  # Initialize the pipeline
137
  self.pipeline = StableDiffusionPipeline.from_pretrained(
138
+ "stabilityai/stable-diffusion-2-1-base",
139
  torch_dtype=torch.float16 if device.type == "cuda" else torch.float32
140
  )
141
+
142
  self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
143
+ self.pipeline.scheduler.config,
144
  algorithm_type="sde-dpmsolver++"
145
  )
146
+
 
 
 
 
 
 
 
 
 
 
 
 
147
  self.text_encoder = self.pipeline.text_encoder
148
  self.tokenizer = self.pipeline.tokenizer
149
+
150
  # Look for Shape2CLIP checkpoint in multiple locations
151
  checkpoint_paths = [
152
+ "./projection_model-0920192.pth",
153
+ "/data/projection_model-0920192.pth" # if using Hugging Face persistent storage look in a /data/ directory
154
  ]
155
+
156
  checkpoint_found = False
157
  checkpoint_path = None
158
  for path in checkpoint_paths:
 
161
  print(f"Found Shape2CLIP checkpoint at: {checkpoint_path}")
162
  checkpoint_found = True
163
  break
164
+
165
  # Download Shape2CLIP checkpoint if not found
166
  if not checkpoint_found:
167
  checkpoint_path = "projection_model-0920192.pth"
168
  print("Downloading Shape2CLIP model checkpoint...")
169
+ gdown.download("https://drive.google.com/uc?id=1nvEXnwMpNkRts6rxVqMZt8i9FZ40KjP7", checkpoint_path, quiet=False) # download in same directory as app.py
170
  print("Download complete")
171
+
172
  # Initialize Shape2CLIP model
173
  self.shape2clip_model = Shape2CLIP(depth=6, drop_path_rate=0.1, pb_dim=384)
174
  self.shape2clip_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
 
 
175
  self.shape2clip_model.eval()
176
+
177
  # Scan for available embeddings
178
  self.scan_available_embeddings()
179
 
180
  def scan_available_embeddings(self):
181
  self.available_categories = []
182
  self.category_counts = {}
183
+
184
+ # Try to find PointBert embeddings for all 55 ShapeNetCore shape categories
185
  for category, cat_id in self.NAME2CAT.items():
186
  possible_filenames = [
 
187
  f"{cat_id}_pb_embs.npz",
188
+ f"embeddings/{cat_id}_pb_embs.npz",
189
+ f"/data/shapenet_pointbert_tokens/{cat_id}_pb_embs.npz" # if using Hugging Face persistent storage look in a /data/shapenet_pointbert_tokens directory
 
190
  ]
191
+
192
  found_file = None
193
  for filename in possible_filenames:
194
  if os.path.exists(filename):
195
  found_file = filename
196
  break
197
+
198
  if found_file:
199
  try:
200
  pb_data = np.load(found_file)
 
207
  count = len(pb_data[keys[0]])
208
  else:
209
  count = 0
210
+
211
  if count > 0:
212
  self.available_categories.append(category)
213
  self.category_counts[category] = count
214
  print(f"Found {count} embeddings for category '{category}'")
215
  except Exception as e:
216
  print(f"Error loading embeddings for {category}: {e}")
217
+
 
 
 
 
218
  # Sort categories alphabetically
219
  self.available_categories.sort()
220
+
221
  print(f"Found {len(self.available_categories)} categories with embeddings")
222
  print(f"Available categories: {', '.join(self.available_categories)}")
223
+
224
+ # No embeddings found for any category - DEMO CANNOT RUN - but still load the interface with a default placeholder category, an error will be displayed when trying to generate images
225
+ if not self.available_categories:
226
+ self.available_categories = ["chair"] # Fallback
227
+ self.category_counts["chair"] = 50 # Default value
228
 
229
  def load_category_embeddings(self, category):
230
  if category in self.category_embeddings:
231
  return self.category_embeddings[category]
232
+
233
  if category not in self.NAME2CAT:
234
  return None, []
235
+
236
  cat_id = self.NAME2CAT[category]
237
+
238
  # Check for different possible embedding filenames and locations
239
  possible_filenames = [
240
+ f"{cat_id}_pb_embs.npz",
 
 
241
  f"embeddings/{cat_id}_pb_embs.npz",
242
+ f"/data/shapenet_pointbert_tokens/{cat_id}_pb_embs.npz" # if using Hugging Face persistent storage look in a /data/shapenet_pointbert_tokens directory
243
  ]
244
+
245
  # Find the first existing file
246
  pb_emb_filename = None
247
  for filename in possible_filenames:
 
249
  pb_emb_filename = filename
250
  print(f"Found embeddings file: {pb_emb_filename}")
251
  break
252
+
253
  if pb_emb_filename is None:
254
  print(f"No embeddings found for {category}")
255
  return None, []
256
+
257
  # Load embeddings
258
  try:
259
  print(f"Loading embeddings from {pb_emb_filename}...")
260
  pb_data = np.load(pb_emb_filename)
261
+
262
  # Check for different key names in the NPZ file
263
  if 'ids' in pb_data and 'embs' in pb_data:
264
  pb_dict = dict(zip(pb_data['ids'], pb_data['embs']))
 
271
  else:
272
  print("Unexpected embedding file format")
273
  return None, []
274
+
275
  all_ids = sorted(list(pb_dict.keys()))
276
  print(f"Loaded {len(all_ids)} shape embeddings for {category}")
277
+
278
  # Cache the results
279
  self.category_embeddings[category] = (pb_dict, all_ids)
280
  return pb_dict, all_ids
 
283
  print(f"Exception details: {str(e)}")
284
  return None, []
285
 
286
+ def load_category_point_clouds(self, category):
287
+ """Load all point clouds for a category from a single NPZ file"""
288
+ if category not in self.NAME2CAT:
289
+ return None
290
+
291
+ cat_id = self.NAME2CAT[category]
292
+
293
+ # Cache to avoid reloading
294
+ if category in self.category_point_clouds:
295
+ return self.category_point_clouds[category]
296
+
297
+ # Check for different possible point cloud filenames
298
+ possible_filenames = [
299
+ f"{cat_id}.npz",
300
+ f"point_clouds/{cat_id}_clouds.npz",
301
+ f"/point_clouds/{cat_id}_clouds.npz",
302
+ f"/data/point_clouds/{cat_id}_clouds.npz" # For Hugging Face persistent storage
303
+ ]
304
+
305
+ # Find the first existing file
306
+ pc_filename = None
307
+ for filename in possible_filenames:
308
+ if os.path.exists(filename):
309
+ pc_filename = filename
310
+ print(f"Found point cloud file: {pc_filename}")
311
+ break
312
+
313
+ if pc_filename is None:
314
+ print(f"No point cloud file found for category {category}")
315
+ return None
316
+
317
+ # Load point clouds
318
+ try:
319
+ print(f"Loading point clouds from {pc_filename}...")
320
+ pc_data = np.load(pc_filename, allow_pickle=True)
321
+
322
+ # Cache the loaded data
323
+ self.category_point_clouds[category] = pc_data
324
+
325
+ return pc_data
326
+ except Exception as e:
327
+ print(f"Error loading point clouds: {e}")
328
+ return None
329
+
330
+ def get_shape_preview(self, category, shape_idx):
331
+ """Get a 3D point cloud visualization for a specific shape"""
332
  if shape_idx is None or shape_idx < 0:
333
  return None
334
+
335
+ # Get shape ID
336
  pb_dict, all_ids = self.load_category_embeddings(category)
337
  if pb_dict is None or not all_ids or shape_idx >= len(all_ids):
338
  return None
339
+
340
  shape_id = all_ids[shape_idx]
341
+
342
+ # Load all point clouds for this category
343
+ pc_data = self.load_category_point_clouds(category)
344
+ if pc_data is None:
345
+ # Fallback to image if point clouds not available
346
+ return self.get_shape_image_preview(category, shape_idx, shape_id)
347
+
348
+ # Extract point cloud for this specific shape
349
+ try:
350
+ # Get the arrays from the npz file
351
+ ids = pc_data['ids']
352
+ clouds = pc_data['clouds']
353
+
354
+ matching_indices = np.where(ids == shape_id)[0]
355
+
356
+ # Check number of matches
357
+ if len(matching_indices) == 0:
358
+ # No matches found - log error and fall back to image
359
+ print(f"Error: Shape ID {shape_id} not found in point cloud data")
360
+ return self.get_shape_image_preview(category, shape_idx, shape_id)
361
+ elif len(matching_indices) > 1:
362
+ # Multiple matches found - unexpected data issue - we will get the first one
363
+ print(f"Warning: Multiple matches ({len(matching_indices)}) found for Shape ID {shape_id}. Using first match.")
364
+
365
+ # Get the corresponding point cloud
366
+ matching_idx = matching_indices[0]
367
+ points = clouds[matching_idx]
368
+
369
+ # Create 3D visualization
370
+ fig = self.get_shape_pointcloud_preview(points, title=f"Shape #{shape_idx}")
371
+ return fig
372
+
373
+ except Exception as e:
374
+ print(f"Error extracting point cloud for {shape_id}: {e}")
375
+ return self.get_shape_image_preview(category, shape_idx, shape_id)
376
+
377
+ def get_shape_image_preview(self, category, shape_idx, shape_id):
378
+ """Fallback to image preview if point cloud not available"""
379
  try:
 
380
  preview_image = self.get_ulip_image(shape_id)
381
+ preview_image = preview_image.resize((300, 300))
382
+ preview_with_text = self.draw_text(preview_image, f"Shape #{shape_idx}", size=80, location=(10, 10))
383
+
384
+ # Convert PIL image to plotly figure
385
+ fig = go.Figure()
386
+
387
+ # Need to convert PIL image to a format plotly can use
388
+ import io
389
+ import base64
390
+
391
+ # Convert PIL image to base64
392
+ buf = io.BytesIO()
393
+ preview_with_text.save(buf, format='PNG')
394
+ img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
395
+
396
+ # Add image to figure
397
+ fig.add_layout_image(
398
+ dict(
399
+ source=f"data:image/png;base64,{img_str}",
400
+ xref="paper", yref="paper",
401
+ x=0, y=1,
402
+ sizex=1, sizey=1,
403
+ sizing="contain",
404
+ layer="below"
405
+ )
406
+ )
407
+
408
+ fig.update_layout(
409
+ title=f"Shape #{shape_idx} (2D Preview - 3D not available)",
410
+ xaxis=dict(showgrid=False, zeroline=False, visible=False, range=[0, 1]),
411
+ yaxis=dict(showgrid=False, zeroline=False, visible=False, range=[0, 1], scaleanchor="x", scaleratio=1),
412
+ margin=dict(l=0, r=0, b=0, t=0),
413
+ height=450,
414
+ width=450,
415
+ plot_bgcolor='rgba(0,0,0,0)' # Transparent background
416
+ )
417
+
418
+ return fig
419
  except Exception as e:
420
  print(f"Error loading preview for {shape_id}: {e}")
421
+ # Create empty figure with error message
422
+ fig = go.Figure()
423
+ fig.update_layout(
424
+ title=f"Error loading Shape #{shape_idx}",
425
+ annotations=[dict(
426
+ text="Preview not available",
427
+ showarrow=False,
428
+ xref="paper", yref="paper",
429
+ x=0.5, y=0.5,
430
+ ont=dict(size=16, color="#E53935"), # Red error text
431
+ align="center"
432
+ )],
433
+ height=450,
434
+ width=450,
435
+ margin=dict(l=0, r=0, b=0, t=30, pad=0),
436
+ paper_bgcolor='rgba(0,0,0,0)',
437
+ plot_bgcolor='rgba(0,0,0,0)' # Transparent background
438
+ )
439
+ return fig
440
+
441
+ def get_shape_pointcloud_preview(self, points, title=None):
442
+ """Create a clean 3D point cloud visualization with Y as up axis"""
443
+ # Sample points for better performance (fewer points = smoother interaction)
444
+ sampled_points = points[::1] # Take every Nth point
445
+
446
+ # Create 3D scatter plot with fixed color
447
+ fig = go.Figure(data=[go.Scatter3d(
448
+ x=sampled_points[:, 0],
449
+ y=sampled_points[:, 1], # Use Z as Y (up axis)
450
+ z=sampled_points[:, 2], # Use Y as Z
451
+ mode='markers',
452
+ marker=dict(
453
+ size=2.5,
454
+ color='#4285F4', # Fixed blue color
455
+ opacity=1
456
+ )
457
+ )])
458
+
459
+ fig.update_layout(
460
+ title=dict(text=title,
461
+ xanchor='center',
462
+ x=0.5
463
+ ),
464
+ scene=dict(
465
+ # Remove all axes elements
466
+ xaxis=dict(visible=False, showticklabels=False, showgrid=False, zeroline=False, showline=False,
467
+ showbackground=False),
468
+ yaxis=dict(visible=False, showticklabels=False, showgrid=False, zeroline=False, showline=False,
469
+ showbackground=False),
470
+ zaxis=dict(visible=False, showticklabels=False, showgrid=False, zeroline=False, showline=False,
471
+ showbackground=False),
472
+ aspectmode='data' # Maintain data aspect ratio
473
+ ),
474
+ # Eliminate margins
475
+ margin=dict(l=0, r=0, b=0, t=30, pad=0),
476
+ autosize=True,
477
+ # Control modebar appearance through layout
478
+ modebar=dict(
479
+ bgcolor='white',
480
+ color='#333',
481
+ orientation='v', # Vertical orientation
482
+ activecolor='#009688'
483
+ ),
484
+ paper_bgcolor='rgba(0,0,0,0)', # Transparent background
485
+ )
486
+
487
+ # Better camera angle
488
+ fig.update_layout(
489
+ scene_camera=dict(
490
+ eye=dict(x=-1.5, y=0.5, z=-1.5),
491
+ up=dict(x=0, y=1, z=0), # Y is up
492
+ center=dict(x=0, y=0, z=0)
493
+ )
494
+ )
495
+
496
+ return fig
497
 
498
  def on_slider_change(self, shape_idx, category):
499
  """Update the preview when the slider changes"""
500
  max_idx = self.category_counts.get(category, 0) - 1
501
+
502
  # Get preview image
503
  preview_image = self.get_shape_preview(category, shape_idx)
504
+
505
  # Update counter text
506
  counter_text = f"Shape {shape_idx} of {max_idx}"
507
+
508
  return preview_image, counter_text, shape_idx
509
 
510
  def prev_shape(self, current_idx, category):
511
  """Go to previous shape"""
512
  max_idx = self.category_counts.get(category, 0) - 1
513
  new_idx = max(0, current_idx - 1)
514
+
515
  # Get preview image
516
  preview_image = self.get_shape_preview(category, new_idx)
517
+
518
  # Update counter text
519
  counter_text = f"Shape {new_idx} of {max_idx}"
520
+
521
  return new_idx, preview_image, counter_text
522
 
523
  def next_shape(self, current_idx, category):
524
  """Go to next shape"""
525
  max_idx = self.category_counts.get(category, 0) - 1
526
  new_idx = min(max_idx, current_idx + 1)
527
+
528
  # Get preview image
529
  preview_image = self.get_shape_preview(category, new_idx)
530
+
531
  # Update counter text
532
  counter_text = f"Shape {new_idx} of {max_idx}"
533
+
534
  return new_idx, preview_image, counter_text
535
 
536
  def jump_to_start(self, category):
537
  """Jump to the first shape"""
538
  max_idx = self.category_counts.get(category, 0) - 1
539
  new_idx = 0
540
+
541
  # Get preview image
542
  preview_image = self.get_shape_preview(category, new_idx)
543
+
544
  # Update counter text
545
  counter_text = f"Shape {new_idx} of {max_idx}"
546
+
547
  return new_idx, preview_image, counter_text
548
 
549
  def jump_to_end(self, category):
550
  """Jump to the last shape"""
551
  max_idx = self.category_counts.get(category, 0) - 1
552
  new_idx = max_idx
553
+
554
  # Get preview image
555
  preview_image = self.get_shape_preview(category, new_idx)
556
+
557
  # Update counter text
558
  counter_text = f"Shape {new_idx} of {max_idx}"
559
+
560
  return new_idx, preview_image, counter_text
561
 
562
  def random_shape(self, category):
 
564
  max_idx = self.category_counts.get(category, 0) - 1
565
  if max_idx <= 0:
566
  return 0, self.get_shape_preview(category, 0), f"Shape 0 of 0"
567
+
568
  # Generate random index
569
  random_idx = random.randint(0, max_idx)
570
+
571
  # Get preview image
572
  preview_image = self.get_shape_preview(category, random_idx)
573
+
574
  # Update counter text
575
  counter_text = f"Shape {random_idx} of {max_idx}"
576
+
577
  return random_idx, preview_image, counter_text
578
 
579
+ def random_prompt(self):
580
+ """Select a random prompt from the predefined list"""
581
+ prompts = [
582
+ 'a low poly 3d rendering of a [CATEGORY]',
583
+ 'an aquarelle drawing of a [CATEGORY]',
584
+ 'a photo of a [CATEGORY] on a beach',
585
+ 'a charcoal drawing of a [CATEGORY]',
586
+ 'a Hieronymus Bosch painting of a [CATEGORY]',
587
+ 'a [CATEGORY] under a tree',
588
+ 'A Kazimir Malevich painting of a [CATEGORY]',
589
+ 'a vector graphic of a [CATEGORY]',
590
+ 'a Claude Monet painting of a [CATEGORY]',
591
+ 'a Salvador Dali painting of a [CATEGORY]',
592
+ 'an Art Deco poster of a [CATEGORY]'
593
+ ]
594
+
595
+ # Get a random prompt
596
+ return random.choice(prompts)
597
+
598
  def on_category_change(self, category):
599
  """Update the slider and preview when the category changes"""
600
  # Reset to the first shape
601
  current_idx = 0
602
  max_idx = self.category_counts.get(category, 0) - 1
603
+
604
  # Get preview image
605
  preview_image = self.get_shape_preview(category, current_idx)
606
+
607
  # Update counter text
608
  counter_text = f"Shape {current_idx} of {max_idx}"
609
+
610
  # Need to update the slider range
611
  new_slider = gr.Slider(
612
  minimum=0,
 
615
  value=current_idx,
616
  label="Shape Index"
617
  )
618
+
619
  return new_slider, current_idx, preview_image, counter_text
620
 
621
  def get_guidance(self, test_prompt, category_name, guidance_emb):
 
622
  print(test_prompt, category_name)
623
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
624
+ print(f"Using device: {device} in get_guidance")
625
+
626
  prompt_tokens = torch.LongTensor(self.tokenizer.encode(test_prompt, padding='max_length')).to(device)
627
+
628
  with torch.no_grad():
629
  out = self.text_encoder(prompt_tokens.unsqueeze(0), output_attentions=True)
630
  prompt_emb = out.last_hidden_state.detach().clone()
631
+
 
632
  if len(guidance_emb.shape) == 1:
633
  guidance_emb = torch.FloatTensor(guidance_emb).unsqueeze(0).unsqueeze(0)
634
  else:
 
647
  with torch.no_grad():
648
  guided_prompt_emb_cond = self.shape2clip_model(prompt_emb.float(), guidance_emb[:,:,:].float()).half()
649
  guided_prompt_emb = guided_prompt_emb_cond.clone()
650
+
651
  guided_prompt_emb[:,:1] = 0
652
  guided_prompt_emb[:,:chair_inds] = 0
653
  guided_prompt_emb[:,chair_inds] *= obj_strength
 
658
 
659
  return fin_guidance, prompt_emb
660
 
 
661
  @spaces.GPU(duration=120)
662
  def generate_images(self, prompt, category, selected_shape_idx, guidance_strength, seed):
663
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
664
+ print(f"Using device: {device} in generate_images")
665
+
666
+ # Move models to gpu
667
+ if device.type == "cuda":
668
+ self.pipeline = self.pipeline.to(device)
669
+ self.shape2clip_model = self.shape2clip_model.to(device)
670
+
671
  # Clear status text immediately
672
  status = ""
673
+
674
+ # Replace [CATEGORY] with the selected category (case-insensitive)
675
+ category_pattern = re.compile(r'\[CATEGORY\]', re.IGNORECASE)
676
+ if re.search(category_pattern, prompt):
677
+ # Use re.sub for replacement to maintain the same casing pattern that was used
678
+ final_prompt = re.sub(category_pattern, category, prompt)
679
+ else:
680
+ # Fallback if user didn't use placeholder
681
+ final_prompt = f"{prompt} {category}"
682
+ status = status + f"<div style='padding: 10px; background-color: #f0f7ff; border-left: 5px solid #3498db; margin-bottom: 10px;'>Note: For better results, use [CATEGORY] in your prompt where you want '{category}' to appear, otherwise it is appended at the end of the prompt.</div>"
683
+
684
+ error = False
685
+ # Check if prompt contains any other categories
686
  for other_category in self.available_categories:
687
+ if re.search(r'\b' + re.escape(other_category) + r'\b', prompt):
688
+ status = status + f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Your prompt contains '{other_category}'. Please remove it and use [CATEGORY] instead.</div>"
689
+ error = True
690
+ if error:
691
+ return [], status
692
+
693
  # Load category embeddings if not already loaded
694
  pb_dict, all_ids = self.load_category_embeddings(category)
695
  if pb_dict is None or not all_ids:
696
+ status = status + f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Failed to load embeddings for {category}</div>"
697
+ return [], status
698
+
699
  # Ensure shape index is valid
700
  if selected_shape_idx is None or selected_shape_idx < 0:
701
  selected_shape_idx = 0
702
+
703
  max_idx = len(all_ids) - 1
704
  selected_shape_idx = max(0, min(selected_shape_idx, max_idx))
705
  guidance_shape_id = all_ids[selected_shape_idx]
706
+
707
+ # Set generator
 
708
  generator = torch.Generator(device=device).manual_seed(seed)
709
+
710
  results = []
711
+
 
 
 
712
  try:
 
 
 
 
 
713
  # Generate base image (without guidance)
714
  with torch.no_grad():
715
  base_images = self.pipeline(
716
+ prompt=final_prompt,
717
  num_inference_steps=50,
718
  num_images_per_prompt=1,
719
  generator=generator,
720
  guidance_scale=7.5
721
  ).images
722
+
723
  base_image = base_images[0]
724
  base_image = self.draw_text(base_image, "Unguided result")
725
  results.append(base_image)
726
  except Exception as e:
727
  print(f"Error generating base image: {e}")
728
+ status = status + f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating base image: {str(e)}</div>"
729
  return results, status
730
+
731
  try:
732
  # Get shape guidance image
733
  ulip_image = self.get_ulip_image(guidance_shape_id)
 
735
  results.append(ulip_image)
736
  except Exception as e:
737
  print(f"Error getting guidance shape: {e}")
738
+ status = status + f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error getting guidance shape: {str(e)}</div>"
739
  return results, status
740
+
741
  try:
742
  # Get shape guidance embedding
743
  pb_emb = pb_dict[guidance_shape_id]
744
+ out_guidance, prompt_emb = self.get_guidance(final_prompt, category, pb_emb)
 
745
  except Exception as e:
746
  print(f"Error generating guidance: {e}")
747
+ status = status + f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating guidance: {str(e)}</div>"
748
  return results, status
749
+
750
  try:
751
  # Generate guided image
752
  generator = torch.Generator(device=device).manual_seed(seed)
 
758
  generator=generator,
759
  guidance_scale=7.5
760
  ).images
761
+
762
  guided_image = guided_images[0]
763
  guided_image = self.draw_text(guided_image, f"Guided result (λ={guidance_strength:.1f})")
764
  results.append(guided_image)
765
+
766
  # Success status
767
+ status = status + f"<div style='padding: 10px; background-color: #e8f5e9; border-left: 5px solid #4caf50; margin-bottom: 10px;'>✓ Successfully generated images using Shape #{selected_shape_idx} from category '{category}'.</div>"
768
+
769
+ torch.cuda.empty_cache()
770
+
 
 
 
 
771
  except Exception as e:
772
  print(f"Error generating guided image: {e}")
773
+ status = status + f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating guided image: {str(e)}</div>"
 
 
774
 
775
+ return results, status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
776
 
777
  def on_demo_load(self):
778
  """Function to ensure initial image is loaded when demo starts"""
 
783
  def create_ui(self):
784
  # Ensure chair is in available categories, otherwise use the first available
785
  default_category = "chair" if "chair" in self.available_categories else self.available_categories[0]
786
+
787
  with gr.Blocks(title="ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts") as demo:
788
  gr.Markdown("""
789
  # ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts
 
794
  - **Paper**: [ArXiv](https://arxiv.org/abs/2412.02912)
795
  - **Publication**: Accepted to CVPR 2025
796
  """)
797
+
798
  with gr.Row():
799
  with gr.Column(scale=1):
800
  prompt = gr.Textbox(
801
+ label="Prompt (use [CATEGORY] for object type)",
802
+ placeholder="an aquarelle drawing of a [CATEGORY]",
803
+ value=f"an aquarelle drawing of a [CATEGORY]"
804
  )
805
+
806
+ # Add help text below the prompt
807
+ help_text = gr.Markdown("""
808
+ **Tip:** Use [CATEGORY] in your prompt where you want the selected object type to appear.
809
+ For example: "a watercolor painting of a [CATEGORY] in the forest"
810
+ """)
811
+
812
+ random_prompt_btn = gr.Button("🎲 Random Prompt", size="sm", variant="secondary")
813
+
814
  category = gr.Dropdown(
815
+ label="Object Category",
816
  choices=self.available_categories,
817
  value=default_category
818
  )
819
+
820
  # Hidden field to store selected shape index
821
  selected_shape_idx = gr.Number(
822
  value=0,
823
  visible=False
824
  )
825
+
826
  # Create a slider for shape selection with preview
827
  with gr.Row():
828
  with gr.Column(scale=1):
 
835
  label="Shape Index",
836
  interactive=True
837
  )
838
+
839
  # Display shape index counter
840
  shape_counter = gr.Markdown(f"Shape 0 of {self.category_counts.get(default_category, 0) - 1}")
841
+
842
  # Quick navigation buttons
843
  with gr.Row():
844
  jump_start_btn = gr.Button("⏮️ First", size="sm")
845
+ random_btn = gr.Button("🎲 Random Shape", size="sm", variant="secondary")
846
  jump_end_btn = gr.Button("Last ⏭️", size="sm")
847
+
848
  with gr.Row():
849
  prev_shape_btn = gr.Button("◀️ Previous", size="sm")
850
  next_shape_btn = gr.Button("Next ▶️", size="sm")
851
+
852
  with gr.Column(scale=1):
853
+ gr.Markdown("### Selected Shape (3D Point Cloud)")
854
+ current_shape_plot = gr.Plot(
855
+ label=None,
856
+ scale=1, # Take up available space
857
+ show_label=False,
858
+ #container=False
859
  )
860
+
861
  guidance_strength = gr.Slider(
862
  minimum=0.0, maximum=1.0, step=0.1, value=0.9,
863
  label="Guidance Strength (λ)"
864
  )
865
+
866
  seed = gr.Slider(
867
  minimum=0, maximum=10000, step=1, value=42,
868
  label="Random Seed"
869
  )
870
+
871
  run_button = gr.Button("Generate Images", variant="primary")
872
+
873
  info = gr.Markdown("""
874
  **Note**: Higher guidance strength (λ) means stronger adherence to the 3D shape.
875
  Start with λ=0.9 for a good balance between shape and prompt adherence.
876
  """)
877
+
878
  status_text = gr.HTML("")
879
+
880
  with gr.Column(scale=2):
881
  gallery = gr.Gallery(
882
  label="Results",
 
885
  columns=3,
886
  height="auto"
887
  )
888
+
889
  # Make sure the initial image is loaded when the demo starts
890
  demo.load(
891
  fn=self.on_demo_load,
892
  inputs=None,
893
+ outputs=[current_shape_plot]
894
  )
895
+
896
  # Connect slider to update preview
897
  shape_slider.change(
898
  fn=self.on_slider_change,
899
  inputs=[shape_slider, category],
900
+ outputs=[current_shape_plot, shape_counter, selected_shape_idx]
901
  )
902
+
903
  # Previous shape button
904
  prev_shape_btn.click(
905
  fn=self.prev_shape,
906
  inputs=[selected_shape_idx, category],
907
+ outputs=[shape_slider, current_shape_plot, shape_counter]
908
  )
909
+
910
  # Next shape button
911
  next_shape_btn.click(
912
  fn=self.next_shape,
913
  inputs=[selected_shape_idx, category],
914
+ outputs=[shape_slider, current_shape_plot, shape_counter]
915
  )
916
+
917
  # Jump to start button
918
  jump_start_btn.click(
919
  fn=self.jump_to_start,
920
  inputs=[category],
921
+ outputs=[shape_slider, current_shape_plot, shape_counter]
922
  )
923
+
924
  # Jump to end button
925
  jump_end_btn.click(
926
  fn=self.jump_to_end,
927
  inputs=[category],
928
+ outputs=[shape_slider, current_shape_plot, shape_counter]
929
  )
930
+
931
  # Random shape button
932
  random_btn.click(
933
  fn=self.random_shape,
934
  inputs=[category],
935
+ outputs=[shape_slider, current_shape_plot, shape_counter]
936
  )
937
+
938
+ # Connect the random prompt button
939
+ random_prompt_btn.click(
940
+ fn=self.random_prompt,
941
+ inputs=[],
942
+ outputs=[prompt]
943
+ )
944
+
945
  # Update the UI when category changes
946
  category.change(
947
  fn=self.on_category_change,
948
  inputs=[category],
949
+ outputs=[shape_slider, selected_shape_idx, current_shape_plot, shape_counter]
 
 
 
 
 
 
 
950
  )
951
+
952
  # Clear status text before generating new images
953
  run_button.click(
954
  fn=lambda: None, # Empty function to clear the status
955
  inputs=None,
956
  outputs=[status_text]
957
  )
958
+
959
  # Generate images when button is clicked
960
  run_button.click(
961
  fn=self.generate_images,
962
  inputs=[prompt, category, selected_shape_idx, guidance_strength, seed],
963
  outputs=[gallery, status_text]
964
  )
965
+
966
  gr.Markdown("""
967
  ## Credits
968
 
 
981
  }
982
  ```
983
  """)
984
+
985
  return demo
986
 
987
 
 
990
  parser = argparse.ArgumentParser(description="ShapeWords Gradio Demo")
991
  parser.add_argument('--share', action='store_true', help='Create a public link')
992
  args = parser.parse_args()
993
+
994
  # Create the demo app and UI
995
  app = ShapeWordsDemo()
996
  demo = app.create_ui()
997
  demo.launch(share=args.share)
998
 
999
+
1000
  if __name__ == "__main__":
1001
+ main()
shapewords_paper_code CHANGED
@@ -1 +1 @@
1
- Subproject commit 797c3704038c4ee8e9d9869ee7f743c1253ee081
 
1
+ Subproject commit e4ebe6c6541505c2e7bc1068186f7045b1bfb51a