dmpetrov commited on
Commit
fe7f119
·
1 Parent(s): e97ae32

pushed new interface updated by Melinos

Browse files
Files changed (1) hide show
  1. app.py +331 -179
app.py CHANGED
@@ -6,8 +6,8 @@ A Gradio web interface for the ShapeWords paper, allowing users to generate
6
  images guided by 3D shape information.
7
 
8
  Author: Melinos Averkiou
9
- Date: 11 March 2025
10
- Version: 1.0
11
 
12
  Paper: "ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts"
13
  arXiv: https://arxiv.org/abs/2412.02912
@@ -30,10 +30,10 @@ Usage:
30
  python app.py [--share]
31
 
32
  This demo allows users to:
33
- 1. Select a 3D object category
34
- 2. Choose a specific 3D shape
35
- 3. Enter a text prompt
36
- 4. Generate images guided by the selected 3D shape
37
 
38
  The code is structured as a class and is compatible with Hugging Face ZeroGPU deployment.
39
  """
@@ -89,35 +89,11 @@ class ShapeWordsDemo:
89
  self.shape_thumbnail_cache = {} # Cache for shape thumbnails
90
  self.CAT2NAME = {v: k for k, v in self.NAME2CAT.items()}
91
  self.category_point_clouds = {}
 
92
  # Initialize all models and data
93
  self.initialize_models()
94
 
95
- def draw_text(self, img, text, color=(10, 10, 10), size=80, location=(200, 30)):
96
- img = img.copy()
97
- draw = ImageDraw.Draw(img)
98
 
99
- try:
100
- font = ImageFont.truetype("Arial", size=size)
101
- except IOError:
102
- font = ImageFont.load_default()
103
-
104
- bbox = draw.textbbox(location, text, font=font)
105
- draw.rectangle(bbox, fill="white")
106
- draw.text(location, text, color, font=font)
107
-
108
- return img
109
-
110
- def get_ulip_image(self, guidance_shape_id, angle='036'):
111
- shape_id_ulip = guidance_shape_id.replace('_', '-')
112
- ulip_template = 'https://storage.googleapis.com/sfr-ulip-code-release-research/shapenet-55/only_rgb_depth_images/{}_r_{}_depth0001.png'
113
- ulip_path = ulip_template.format(shape_id_ulip, angle)
114
-
115
- try:
116
- ulip_image = load_image(ulip_path).resize((512, 512))
117
- return ulip_image
118
- except Exception as e:
119
- print(f"Error loading image: {e}")
120
- return Image.new('RGB', (512, 512), color='gray')
121
 
122
  def initialize_models(self):
123
  # device = DEVICE
@@ -181,12 +157,12 @@ class ShapeWordsDemo:
181
  self.available_categories = []
182
  self.category_counts = {}
183
 
184
- # Try to find PointBert embeddings for all 55 ShapeNetCore shape categories
185
  for category, cat_id in self.NAME2CAT.items():
186
  possible_filenames = [
187
  f"{cat_id}_pb_embs.npz",
188
  f"embeddings/{cat_id}_pb_embs.npz",
189
- f"/data/shapenet_pointbert_tokens/{cat_id}_pb_embs.npz" # if using Hugging Face persistent storage look in a /data/shapenet_pointbert_tokens directory
190
  ]
191
 
192
  found_file = None
@@ -317,8 +293,9 @@ class ShapeWordsDemo:
317
  # Load point clouds
318
  try:
319
  print(f"Loading point clouds from {pc_filename}...")
 
320
  pc_data_map = np.load(pc_filename, allow_pickle=False)
321
- pc_data = {'ids':pc_data_map['ids'], 'clouds': pc_data_map['clouds']}
322
 
323
  # Cache the loaded data
324
  self.category_point_clouds[category] = pc_data
@@ -380,7 +357,6 @@ class ShapeWordsDemo:
380
  try:
381
  preview_image = self.get_ulip_image(shape_id)
382
  preview_image = preview_image.resize((300, 300))
383
- preview_with_text = self.draw_text(preview_image, f"Shape #{shape_idx}", size=80, location=(10, 10))
384
 
385
  # Convert PIL image to plotly figure
386
  fig = go.Figure()
@@ -391,7 +367,7 @@ class ShapeWordsDemo:
391
 
392
  # Convert PIL image to base64
393
  buf = io.BytesIO()
394
- preview_with_text.save(buf, format='PNG')
395
  img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
396
 
397
  # Add image to figure
@@ -407,12 +383,10 @@ class ShapeWordsDemo:
407
  )
408
 
409
  fig.update_layout(
410
- title=f"Shape #{shape_idx} (2D Preview - 3D not available)",
411
  xaxis=dict(showgrid=False, zeroline=False, visible=False, range=[0, 1]),
412
  yaxis=dict(showgrid=False, zeroline=False, visible=False, range=[0, 1], scaleanchor="x", scaleratio=1),
413
  margin=dict(l=0, r=0, b=0, t=0),
414
- height=450,
415
- width=450,
416
  plot_bgcolor='rgba(0,0,0,0)' # Transparent background
417
  )
418
 
@@ -431,10 +405,7 @@ class ShapeWordsDemo:
431
  ont=dict(size=16, color="#E53935"), # Red error text
432
  align="center"
433
  )],
434
- height=450,
435
- width=450,
436
- margin=dict(l=0, r=0, b=0, t=30, pad=0),
437
- paper_bgcolor='rgba(0,0,0,0)',
438
  plot_bgcolor='rgba(0,0,0,0)' # Transparent background
439
  )
440
  return fig
@@ -458,10 +429,7 @@ class ShapeWordsDemo:
458
  )])
459
 
460
  fig.update_layout(
461
- title=dict(text=title,
462
- xanchor='center',
463
- x=0.5
464
- ),
465
  scene=dict(
466
  # Remove all axes elements
467
  xaxis=dict(visible=False, showticklabels=False, showgrid=False, zeroline=False, showline=False,
@@ -473,7 +441,7 @@ class ShapeWordsDemo:
473
  aspectmode='data' # Maintain data aspect ratio
474
  ),
475
  # Eliminate margins
476
- margin=dict(l=0, r=0, b=0, t=30, pad=0),
477
  autosize=True,
478
  # Control modebar appearance through layout
479
  modebar=dict(
@@ -496,17 +464,40 @@ class ShapeWordsDemo:
496
 
497
  return fig
498
 
 
 
 
 
 
 
 
 
 
 
 
 
499
  def on_slider_change(self, shape_idx, category):
500
  """Update the preview when the slider changes"""
501
  max_idx = self.category_counts.get(category, 0) - 1
502
 
503
- # Get preview image
504
- preview_image = self.get_shape_preview(category, shape_idx)
505
 
506
  # Update counter text
507
  counter_text = f"Shape {shape_idx} of {max_idx}"
508
 
509
- return preview_image, counter_text, shape_idx
 
 
 
 
 
 
 
 
 
 
 
510
 
511
  def prev_shape(self, current_idx, category):
512
  """Go to previous shape"""
@@ -519,6 +510,9 @@ class ShapeWordsDemo:
519
  # Update counter text
520
  counter_text = f"Shape {new_idx} of {max_idx}"
521
 
 
 
 
522
  return new_idx, preview_image, counter_text
523
 
524
  def next_shape(self, current_idx, category):
@@ -532,6 +526,9 @@ class ShapeWordsDemo:
532
  # Update counter text
533
  counter_text = f"Shape {new_idx} of {max_idx}"
534
 
 
 
 
535
  return new_idx, preview_image, counter_text
536
 
537
  def jump_to_start(self, category):
@@ -545,6 +542,9 @@ class ShapeWordsDemo:
545
  # Update counter text
546
  counter_text = f"Shape {new_idx} of {max_idx}"
547
 
 
 
 
548
  return new_idx, preview_image, counter_text
549
 
550
  def jump_to_end(self, category):
@@ -558,6 +558,9 @@ class ShapeWordsDemo:
558
  # Update counter text
559
  counter_text = f"Shape {new_idx} of {max_idx}"
560
 
 
 
 
561
  return new_idx, preview_image, counter_text
562
 
563
  def random_shape(self, category):
@@ -575,6 +578,9 @@ class ShapeWordsDemo:
575
  # Update counter text
576
  counter_text = f"Shape {random_idx} of {max_idx}"
577
 
 
 
 
578
  return random_idx, preview_image, counter_text
579
 
580
  def random_prompt(self):
@@ -721,24 +727,12 @@ class ShapeWordsDemo:
721
  guidance_scale=7.5
722
  ).images
723
 
724
- base_image = base_images[0]
725
- base_image = self.draw_text(base_image, "Unguided result")
726
- results.append(base_image)
727
  except Exception as e:
728
  print(f"Error generating base image: {e}")
729
  status = status + f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating base image: {str(e)}</div>"
730
  return results, status
731
 
732
- try:
733
- # Get shape guidance image
734
- ulip_image = self.get_ulip_image(guidance_shape_id)
735
- ulip_image = self.draw_text(ulip_image, "Guidance shape")
736
- results.append(ulip_image)
737
- except Exception as e:
738
- print(f"Error getting guidance shape: {e}")
739
- status = status + f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error getting guidance shape: {str(e)}</div>"
740
- return results, status
741
-
742
  try:
743
  # Get shape guidance embedding
744
  pb_emb = pb_dict[guidance_shape_id]
@@ -760,9 +754,7 @@ class ShapeWordsDemo:
760
  guidance_scale=7.5
761
  ).images
762
 
763
- guided_image = guided_images[0]
764
- guided_image = self.draw_text(guided_image, f"Guided result (λ={guidance_strength:.1f})")
765
- results.append(guided_image)
766
 
767
  # Success status
768
  status = status + f"<div style='padding: 10px; background-color: #e8f5e9; border-left: 5px solid #4caf50; margin-bottom: 10px;'>✓ Successfully generated images using Shape #{selected_shape_idx} from category '{category}'.</div>"
@@ -785,107 +777,272 @@ class ShapeWordsDemo:
785
  # Ensure chair is in available categories, otherwise use the first available
786
  default_category = "chair" if "chair" in self.available_categories else self.available_categories[0]
787
 
788
- with gr.Blocks(title="ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts") as demo:
789
- gr.Markdown("""
790
- # ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts
791
-
792
- ShapeWords incorporates target 3D shape information with text prompts to guide image synthesis.
793
-
794
- - **Website**: [ShapeWords Project Page](https://lodurality.github.io/shapewords/)
795
- - **Paper**: [ArXiv](https://arxiv.org/abs/2412.02912)
796
- - **Publication**: Accepted to CVPR 2025
797
- """)
798
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
799
  with gr.Row():
800
- with gr.Column(scale=1):
801
- prompt = gr.Textbox(
802
- label="Prompt (use [CATEGORY] for object type)",
803
- placeholder="an aquarelle drawing of a [CATEGORY]",
804
- value=f"an aquarelle drawing of a [CATEGORY]"
805
- )
806
-
807
- # Add help text below the prompt
808
- help_text = gr.Markdown("""
809
- **Tip:** Use [CATEGORY] in your prompt where you want the selected object type to appear.
810
- For example: "a watercolor painting of a [CATEGORY] in the forest"
 
 
 
 
 
 
 
811
  """)
 
 
812
 
813
- random_prompt_btn = gr.Button("🎲 Random Prompt", size="sm", variant="secondary")
 
814
 
 
 
 
 
 
815
  category = gr.Dropdown(
816
  label="Object Category",
817
  choices=self.available_categories,
818
- value=default_category
819
- )
820
-
821
- # Hidden field to store selected shape index
822
- selected_shape_idx = gr.Number(
823
- value=0,
824
- visible=False
825
- )
826
-
827
- # Create a slider for shape selection with preview
828
- with gr.Row():
829
- with gr.Column(scale=1):
830
- # Slider for shape selection
831
- shape_slider = gr.Slider(
832
- minimum=0,
833
- maximum=self.category_counts.get(default_category, 0) - 1,
834
- step=1,
835
- value=0,
836
- label="Shape Index",
837
- interactive=True
838
- )
839
-
840
- # Display shape index counter
841
- shape_counter = gr.Markdown(f"Shape 0 of {self.category_counts.get(default_category, 0) - 1}")
842
-
843
- # Quick navigation buttons
844
- with gr.Row():
845
- jump_start_btn = gr.Button("⏮️ First", size="sm")
846
- random_btn = gr.Button("🎲 Random Shape", size="sm", variant="secondary")
847
- jump_end_btn = gr.Button("Last ⏭️", size="sm")
848
-
849
- with gr.Row():
850
- prev_shape_btn = gr.Button("◀️ Previous", size="sm")
851
- next_shape_btn = gr.Button("Next ▶️", size="sm")
852
-
853
- with gr.Column(scale=1):
854
- gr.Markdown("### Selected Shape (3D Point Cloud)")
855
- current_shape_plot = gr.Plot(
856
- label=None,
857
- scale=1, # Take up available space
858
- show_label=False,
859
- #container=False
860
- )
861
-
862
- guidance_strength = gr.Slider(
863
- minimum=0.0, maximum=1.0, step=0.1, value=0.9,
864
- label="Guidance Strength (λ)"
865
  )
866
 
867
- seed = gr.Slider(
868
- minimum=0, maximum=10000, step=1, value=42,
869
- label="Random Seed"
 
 
 
 
870
  )
871
 
872
- run_button = gr.Button("Generate Images", variant="primary")
873
-
874
- info = gr.Markdown("""
875
- **Note**: Higher guidance strength (λ) means stronger adherence to the 3D shape.
876
- Start with λ=0.9 for a good balance between shape and prompt adherence.
877
- """)
878
-
879
- status_text = gr.HTML("")
880
-
881
- with gr.Column(scale=2):
882
- gallery = gr.Gallery(
883
- label="Results",
884
- show_label=True,
885
- elem_id="results_gallery",
886
- columns=3,
887
- height="auto"
888
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
889
 
890
  # Make sure the initial image is loaded when the demo starts
891
  demo.load(
@@ -950,39 +1107,34 @@ class ShapeWordsDemo:
950
  outputs=[shape_slider, selected_shape_idx, current_shape_plot, shape_counter]
951
  )
952
 
953
- # Clear status text before generating new images
954
  run_button.click(
955
- fn=lambda: None, # Empty function to clear the status
 
 
 
956
  inputs=None,
957
  outputs=[status_text]
958
  )
959
 
960
  # Generate images when button is clicked
961
  run_button.click(
962
- fn=self.generate_images,
 
 
 
 
 
 
 
 
 
 
 
963
  inputs=[prompt, category, selected_shape_idx, guidance_strength, seed],
964
  outputs=[gallery, status_text]
965
  )
966
 
967
- gr.Markdown("""
968
- ## Credits
969
-
970
- This demo is based on the ShapeWords paper by Petrov et al. (2024) accepted to CVPR 2025.
971
-
972
- If you use this in your work, please cite:
973
- ```
974
- @misc{petrov2024shapewords,
975
- title={ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts},
976
- author={Dmitry Petrov and Pradyumn Goyal and Divyansh Shivashok and Yuanming Tao and Melinos Averkiou and Evangelos Kalogerakis},
977
- year={2024},
978
- eprint={2412.02912},
979
- archivePrefix={arXiv},
980
- primaryClass={cs.CV},
981
- url={https://arxiv.org/abs/2412.02912},
982
- }
983
- ```
984
- """)
985
-
986
  return demo
987
 
988
 
 
6
  images guided by 3D shape information.
7
 
8
  Author: Melinos Averkiou
9
+ Date: 24 March 2025
10
+ Version: 1.5
11
 
12
  Paper: "ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts"
13
  arXiv: https://arxiv.org/abs/2412.02912
 
30
  python app.py [--share]
31
 
32
  This demo allows users to:
33
+ 1. Select a 3D object category from ShapeNetCore
34
+ 2. Choose a specific 3D shape using a slider or the navigation buttons (including a random shape button)
35
+ 3. Enter a text prompt or pick a random one
36
+ 4. Generate images guided by the selected 3D shape and the text prompt
37
 
38
  The code is structured as a class and is compatible with Hugging Face ZeroGPU deployment.
39
  """
 
89
  self.shape_thumbnail_cache = {} # Cache for shape thumbnails
90
  self.CAT2NAME = {v: k for k, v in self.NAME2CAT.items()}
91
  self.category_point_clouds = {}
92
+ self.from_navigation = False
93
  # Initialize all models and data
94
  self.initialize_models()
95
 
 
 
 
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  def initialize_models(self):
99
  # device = DEVICE
 
157
  self.available_categories = []
158
  self.category_counts = {}
159
 
160
+ # Try to find PointBert embeddings for all 55 ShapeNetCore shape categories
161
  for category, cat_id in self.NAME2CAT.items():
162
  possible_filenames = [
163
  f"{cat_id}_pb_embs.npz",
164
  f"embeddings/{cat_id}_pb_embs.npz",
165
+ f"/data/shapenet_pointbert_tokens/{cat_id}_pb_embs.npz" # if using Hugging Face persistent storage look in a /data/shapenet_pointbert_tokens directory
166
  ]
167
 
168
  found_file = None
 
293
  # Load point clouds
294
  try:
295
  print(f"Loading point clouds from {pc_filename}...")
296
+
297
  pc_data_map = np.load(pc_filename, allow_pickle=False)
298
+ pc_data = {'ids':pc_data_map['ids'], 'clouds': pc_data_map['clouds']}
299
 
300
  # Cache the loaded data
301
  self.category_point_clouds[category] = pc_data
 
357
  try:
358
  preview_image = self.get_ulip_image(shape_id)
359
  preview_image = preview_image.resize((300, 300))
 
360
 
361
  # Convert PIL image to plotly figure
362
  fig = go.Figure()
 
367
 
368
  # Convert PIL image to base64
369
  buf = io.BytesIO()
370
+ preview_image.save(buf, format='PNG')
371
  img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
372
 
373
  # Add image to figure
 
383
  )
384
 
385
  fig.update_layout(
386
+ title=f"Shape 2D Preview - 3D not available",
387
  xaxis=dict(showgrid=False, zeroline=False, visible=False, range=[0, 1]),
388
  yaxis=dict(showgrid=False, zeroline=False, visible=False, range=[0, 1], scaleanchor="x", scaleratio=1),
389
  margin=dict(l=0, r=0, b=0, t=0),
 
 
390
  plot_bgcolor='rgba(0,0,0,0)' # Transparent background
391
  )
392
 
 
405
  ont=dict(size=16, color="#E53935"), # Red error text
406
  align="center"
407
  )],
408
+ margin=dict(l=0, r=0, b=0, t=0, pad=0),
 
 
 
409
  plot_bgcolor='rgba(0,0,0,0)' # Transparent background
410
  )
411
  return fig
 
429
  )])
430
 
431
  fig.update_layout(
432
+ title=None,
 
 
 
433
  scene=dict(
434
  # Remove all axes elements
435
  xaxis=dict(visible=False, showticklabels=False, showgrid=False, zeroline=False, showline=False,
 
441
  aspectmode='data' # Maintain data aspect ratio
442
  ),
443
  # Eliminate margins
444
+ margin=dict(l=0, r=0, b=0, t=0, pad=0),
445
  autosize=True,
446
  # Control modebar appearance through layout
447
  modebar=dict(
 
464
 
465
  return fig
466
 
467
+ def get_ulip_image(self, guidance_shape_id, angle='036'):
468
+ shape_id_ulip = guidance_shape_id.replace('_', '-')
469
+ ulip_template = 'https://storage.googleapis.com/sfr-ulip-code-release-research/shapenet-55/only_rgb_depth_images/{}_r_{}_depth0001.png'
470
+ ulip_path = ulip_template.format(shape_id_ulip, angle)
471
+
472
+ try:
473
+ ulip_image = load_image(ulip_path).resize((512, 512))
474
+ return ulip_image
475
+ except Exception as e:
476
+ print(f"Error loading image: {e}")
477
+ return Image.new('RGB', (512, 512), color='gray')
478
+
479
  def on_slider_change(self, shape_idx, category):
480
  """Update the preview when the slider changes"""
481
  max_idx = self.category_counts.get(category, 0) - 1
482
 
483
+ # Get shape preview
484
+ shape_preview = self.get_shape_preview(category, shape_idx)
485
 
486
  # Update counter text
487
  counter_text = f"Shape {shape_idx} of {max_idx}"
488
 
489
+ return shape_preview, counter_text, shape_idx
490
+
491
+ def on_slider_change_no_update(self, shape_idx, category):
492
+ """Handle slider change without updating the plot (used when navigation buttons are clicked)"""
493
+ if self.from_navigation:
494
+ self.from_navigation = False
495
+ # Return the same values without recalculating
496
+ max_idx = self.category_counts.get(category, 0) - 1
497
+ return None, f"Shape {shape_idx} of {max_idx}", shape_idx
498
+ else:
499
+ # Normal processing when slider is moved directly
500
+ return self.on_slider_change(shape_idx, category)
501
 
502
  def prev_shape(self, current_idx, category):
503
  """Go to previous shape"""
 
510
  # Update counter text
511
  counter_text = f"Shape {new_idx} of {max_idx}"
512
 
513
+ # Set a flag to indicate this update came from navigation
514
+ self.from_navigation = True
515
+
516
  return new_idx, preview_image, counter_text
517
 
518
  def next_shape(self, current_idx, category):
 
526
  # Update counter text
527
  counter_text = f"Shape {new_idx} of {max_idx}"
528
 
529
+ # Set a flag to indicate this update came from navigation
530
+ self.from_navigation = True
531
+
532
  return new_idx, preview_image, counter_text
533
 
534
  def jump_to_start(self, category):
 
542
  # Update counter text
543
  counter_text = f"Shape {new_idx} of {max_idx}"
544
 
545
+ # Set a flag to indicate this update came from navigation
546
+ self.from_navigation = True
547
+
548
  return new_idx, preview_image, counter_text
549
 
550
  def jump_to_end(self, category):
 
558
  # Update counter text
559
  counter_text = f"Shape {new_idx} of {max_idx}"
560
 
561
+ # Set a flag to indicate this update came from navigation
562
+ self.from_navigation = True
563
+
564
  return new_idx, preview_image, counter_text
565
 
566
  def random_shape(self, category):
 
578
  # Update counter text
579
  counter_text = f"Shape {random_idx} of {max_idx}"
580
 
581
+ # Set a flag to indicate this update came from navigation
582
+ self.from_navigation = True
583
+
584
  return random_idx, preview_image, counter_text
585
 
586
  def random_prompt(self):
 
727
  guidance_scale=7.5
728
  ).images
729
 
730
+ results.append(base_images[0])
 
 
731
  except Exception as e:
732
  print(f"Error generating base image: {e}")
733
  status = status + f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating base image: {str(e)}</div>"
734
  return results, status
735
 
 
 
 
 
 
 
 
 
 
 
736
  try:
737
  # Get shape guidance embedding
738
  pb_emb = pb_dict[guidance_shape_id]
 
754
  guidance_scale=7.5
755
  ).images
756
 
757
+ results.append(guided_images[0])
 
 
758
 
759
  # Success status
760
  status = status + f"<div style='padding: 10px; background-color: #e8f5e9; border-left: 5px solid #4caf50; margin-bottom: 10px;'>✓ Successfully generated images using Shape #{selected_shape_idx} from category '{category}'.</div>"
 
777
  # Ensure chair is in available categories, otherwise use the first available
778
  default_category = "chair" if "chair" in self.available_categories else self.available_categories[0]
779
 
780
+ with gr.Blocks(title="ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts",
781
+ theme=gr.themes.Soft(
782
+ primary_hue="orange",
783
+ secondary_hue="blue",
784
+ font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
785
+ font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "Consolas", "monospace"],
786
+ ),
787
+ css="""
788
+ /* Base styles */
789
+ .container { max-width: 1400px; margin: 0 auto; }
790
+
791
+ /* Typography */
792
+ .title { text-align: center; font-size: 26px; font-weight: 600; margin-bottom: 3px; }
793
+ .subtitle { text-align: center; font-size: 16px; margin-bottom: 3px; }
794
+ .authors { text-align: center; font-size: 15px; margin-bottom: 3px; }
795
+ .affiliations { text-align: center; font-size: 13px; margin-bottom: 5px; }
796
+
797
+ /* Buttons */
798
+ .buttons-container { margin: 0 auto 10px; }
799
+ .buttons-row { display: flex; justify-content: center; gap: 10px; }
800
+ .nav-button {
801
+ display: inline-block;
802
+ padding: 6px 12px;
803
+ background-color: #363636;
804
+ color: white !important;
805
+ text-decoration: none;
806
+ border-radius: 20px;
807
+ font-weight: 500;
808
+ font-size: 14px;
809
+ transition: background-color 0.2s;
810
+ text-align: center;
811
+ }
812
+ .nav-button:hover { background-color: #505050; }
813
+ .nav-button.disabled {
814
+ opacity: 0.6;
815
+ cursor: not-allowed;
816
+ }
817
+
818
+ /* Form elements */
819
+ .prompt-text { font-size: 16px; }
820
+ .instruction-text { font-size: 15px; padding: 10px; border-radius: 8px; background-color: rgba(255, 165, 0, 0.1); }
821
+ .shape-navigation {
822
+ display: flex;
823
+ justify-content: center;
824
+ align-items: center;
825
+ margin: 10px auto;
826
+ gap: 15px;
827
+ max-width: 320px;
828
+ }
829
+ .shape-navigation button {
830
+ min-width: 40px;
831
+ max-width: 60px;
832
+ width: auto;
833
+ padding: 6px 10px;
834
+ }
835
+ .nav-icon-btn { font-size: 18px; }
836
+ .category-dropdown .wrap { font-size: 16px; }
837
+ .generate-button { font-size: 18px !important; padding: 12px !important; margin: 15px 0 !important; }
838
+ .slider-label { font-size: 16px; }
839
+ .slider-text { font-size: 14px; margin-top: 5px; }
840
+ .about-section { font-size: 16px; margin-top: 40px; padding: 20px; border-top: 1px solid rgba(128, 128, 128, 0.2); }
841
+ .status-message { background-color: rgba(0, 128, 0, 0.1); color: #006400; padding: 10px; border-radius: 4px; margin-top: 10px; }
842
+ .prompt-container { display: flex; align-items: center; }
843
+ .prompt-input { flex-grow: 1; }
844
+ .prompt-button { margin-left: 10px; align-self: center; }
845
+ .results-gallery { min-height: 100px; max-height: 500px; }
846
+
847
+ /* Responsive adjustments */
848
+ @media (max-width: 768px) {
849
+ .shape-navigation {
850
+ max-width: 100%;
851
+ gap: 5px;
852
+ }
853
+ .shape-navigation button {
854
+ min-width: 36px;
855
+ padding: 6px 0;
856
+ font-size: 16px;
857
+ }
858
+ .buttons-row {
859
+ flex-wrap: wrap;
860
+ }
861
+ .nav-button {
862
+ margin-bottom: 5px;
863
+ }
864
+ .results-gallery {
865
+ max-height: 320px;
866
+ }
867
+ }
868
+
869
+ /* Dark mode overrides */
870
+ @media (prefers-color-scheme: dark) {
871
+ .nav-button {
872
+ background-color: #505050;
873
+ }
874
+ .nav-button:hover {
875
+ background-color: #666666;
876
+ }
877
+ }
878
+ """) as demo:
879
+ # Header with title and links
880
+ gr.Markdown("# ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts",
881
+ elem_classes="title")
882
+ gr.Markdown("### CVPR 2025", elem_classes="subtitle")
883
+ gr.Markdown(
884
+ "Dmitry Petrov<sup>1</sup>, Pradyumn Goyal<sup>1</sup>, Divyansh Shivashok<sup>1</sup>, Yuanming Tao<sup>1</sup>, Melinos Averkiou<sup>2,3</sup>, Evangelos Kalogerakis<sup>1,2,4</sup>",
885
+ elem_classes="authors")
886
+ gr.Markdown(
887
+ "<sup>1</sup>UMass Amherst <sup>2</sup>CYENS CoE <sup>3</sup>University of Cyprus <sup>4</sup>TU Crete",
888
+ elem_classes="affiliations")
889
+
890
+ # Navigation buttons
891
  with gr.Row():
892
+ with gr.Column(scale=3):
893
+ pass # Empty space for alignment
894
+ with gr.Column(scale=2, elem_classes="buttons-container"):
895
+ gr.HTML("""
896
+ <div class="buttons-row">
897
+ <a href="https://arxiv.org/abs/2412.02912" target="_blank" class="nav-button">
898
+ arXiv
899
+ </a>
900
+ <a href="https://lodurality.github.io/shapewords/" target="_blank" class="nav-button">
901
+ Project Page
902
+ </a>
903
+ <a href="#" target="_blank" class="nav-button disabled">
904
+ Code
905
+ </a>
906
+ <a href="#" target="_blank" class="nav-button disabled">
907
+ Data
908
+ </a>
909
+ </div>
910
  """)
911
+ with gr.Column(scale=3):
912
+ pass # Empty space for alignment
913
 
914
+ # Hidden field to store selected shape index
915
+ selected_shape_idx = gr.Number(value=0, visible=False)
916
 
917
+ # Prompt Design (full width)
918
+ with gr.Group():
919
+ gr.Markdown("### 📝 Prompt Design")
920
+
921
+ with gr.Row():
922
  category = gr.Dropdown(
923
  label="Object Category",
924
  choices=self.available_categories,
925
+ value=default_category,
926
+ container=True,
927
+ elem_classes="category-dropdown",
928
+ scale=2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
929
  )
930
 
931
+ prompt = gr.Textbox(
932
+ label="Prompt",
933
+ placeholder="an aquarelle drawing of a [CATEGORY]",
934
+ value="an aquarelle drawing of a [CATEGORY]",
935
+ lines=1,
936
+ scale=5,
937
+ elem_classes="prompt-input"
938
  )
939
 
940
+ random_prompt_btn = gr.Button("🎲 Random\nPrompt",
941
+ size="lg",
942
+ scale=1,
943
+ elem_classes="prompt-button")
944
+
945
+ gr.Markdown("""
946
+ Use [CATEGORY] in your prompt where you want the selected object type to appear.
947
+ For example: "a watercolor painting of a [CATEGORY] in the forest"
948
+ """, elem_classes="instruction-text")
949
+
950
+ # Middle section - Shape Selection and Results side by side
951
+ with gr.Row(equal_height=False):
952
+ # Left column - Shape Selection
953
+ with gr.Column():
954
+ with gr.Group():
955
+ gr.Markdown("### 🔍 Shape Selection")
956
+
957
+ shape_slider = gr.Slider(
958
+ minimum=0,
959
+ maximum=self.category_counts.get(default_category, 0) - 1,
960
+ step=1,
961
+ value=0,
962
+ label="Shape Index",
963
+ interactive=True
964
+ )
965
+
966
+ shape_counter = gr.Markdown(f"Shape 0 of {self.category_counts.get(default_category, 0) - 1}")
967
+
968
+ gr.Markdown("### Selected Shape (3D Point Cloud)")
969
+
970
+ current_shape_plot = gr.Plot(show_label=False)
971
+
972
+ # Navigation buttons - Icons only for better mobile compatibility
973
+ with gr.Row(elem_classes="shape-navigation"):
974
+ jump_start_btn = gr.Button("⏮️", size="sm", elem_classes="nav-icon-btn")
975
+ prev_shape_btn = gr.Button("◀️", size="sm", elem_classes="nav-icon-btn")
976
+ random_btn = gr.Button("🎲", size="sm", variant="secondary", elem_classes="nav-icon-btn")
977
+ next_shape_btn = gr.Button("▶️", size="sm", elem_classes="nav-icon-btn")
978
+ jump_end_btn = gr.Button("⏭️", size="sm", elem_classes="nav-icon-btn")
979
+
980
+ # Right column - Results
981
+ with gr.Column():
982
+ with gr.Group():
983
+ gr.Markdown("### 🖼️ Generated Results")
984
+ gallery = gr.Gallery(
985
+ label="Results",
986
+ show_label=False,
987
+ elem_id="results_gallery",
988
+ columns=2,
989
+ height="auto",
990
+ object_fit="contain",
991
+ elem_classes="results-gallery"
992
+ )
993
+
994
+ # Generate button (full width)
995
+ with gr.Row():
996
+ run_button = gr.Button("✨ Generate Images", variant="primary", size="lg",
997
+ elem_classes="generate-button")
998
+
999
+ # Generation Settings (full width)
1000
+ with gr.Group():
1001
+ gr.Markdown("### ⚙️ Generation Settings")
1002
+
1003
+ with gr.Row():
1004
+ with gr.Column():
1005
+ guidance_strength = gr.Slider(
1006
+ minimum=0.0, maximum=1.0, step=0.1, value=0.9,
1007
+ label="Guidance Strength (λ) - Higher λ = stronger shape adherence"
1008
+ )
1009
+ with gr.Column():
1010
+ seed = gr.Slider(
1011
+ minimum=0, maximum=10000, step=1, value=42,
1012
+ label="Random Seed"
1013
+ )
1014
+
1015
+ status_text = gr.HTML("", elem_classes="status-message")
1016
+
1017
+ # About section at the bottom of the page
1018
+ with gr.Group(elem_classes="about-section"):
1019
+ gr.Markdown("""
1020
+ ## About ShapeWords
1021
+
1022
+ ShapeWords incorporates target 3D shape information with text prompts to guide image synthesis.
1023
+
1024
+ ### How It Works
1025
+ 1. Select an object category from the dropdown menu
1026
+ 2. Browse through available 3D shapes using the slider or navigation buttons
1027
+ 3. Create a text prompt using [CATEGORY] as a placeholder
1028
+ 4. Adjust guidance strength to control shape influence
1029
+ 5. Click Generate to create images that follow both your text prompt and the selected 3D shape
1030
+
1031
+ ### Citation
1032
+ ```
1033
+ @misc{petrov2024shapewords,
1034
+ title={ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts},
1035
+ author={Dmitry Petrov and Pradyumn Goyal and Divyansh Shivashok and Yuanming Tao and Melinos Averkiou and Evangelos Kalogerakis},
1036
+ year={2024},
1037
+ eprint={2412.02912},
1038
+ archivePrefix={arXiv},
1039
+ primaryClass={cs.CV},
1040
+ url={https://arxiv.org/abs/2412.02912},
1041
+ }
1042
+ ```
1043
+ """)
1044
+
1045
+ # Connect components
1046
 
1047
  # Make sure the initial image is loaded when the demo starts
1048
  demo.load(
 
1107
  outputs=[shape_slider, selected_shape_idx, current_shape_plot, shape_counter]
1108
  )
1109
 
1110
+ # Update status text when generating
1111
  run_button.click(
1112
+ fn=lambda: """<div style='color: #00cc00; background-color: #1a1a1a;
1113
+ border: 1px solid #2a2a2a; padding: 10px; border-radius: 4px;
1114
+ margin-top: 10px; font-weight: bold;'>
1115
+ Generating images...</div>""",
1116
  inputs=None,
1117
  outputs=[status_text]
1118
  )
1119
 
1120
  # Generate images when button is clicked
1121
  run_button.click(
1122
+ fn=lambda p, c, s_idx, g, seed: [
1123
+ [
1124
+ (img, caption) for img, caption in zip(
1125
+ self.generate_images(p, c, s_idx, g, seed)[0],
1126
+ [f"Unguided Result", f"Guided Result (λ = {g})"]
1127
+ )
1128
+ ], # Gallery images with captions
1129
+ f"""<div style="color: #00cc00; background-color: #1a1a1a;
1130
+ border: 1px solid #2a2a2a; padding: 10px; border-radius: 4px;
1131
+ margin-top: 10px; font-weight: bold;">
1132
+ ✓ Successfully generated images using Shape #{s_idx} from category '{c}'.</div>"""
1133
+ ],
1134
  inputs=[prompt, category, selected_shape_idx, guidance_strength, seed],
1135
  outputs=[gallery, status_text]
1136
  )
1137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1138
  return demo
1139
 
1140