"""
ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts
=======================================================================
A Gradio web interface for the ShapeWords paper, allowing users to generate
images guided by 3D shape information.
Author: Melinos Averkiou
Date: 24 March 2025
Version: 1.5
Paper: "ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts"
arXiv: https://arxiv.org/abs/2412.02912
Project Page: https://lodurality.github.io/shapewords/
Citation:
@misc{petrov2024shapewords,
title={ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts},
author={Dmitry Petrov and Pradyumn Goyal and Divyansh Shivashok and Yuanming Tao and Melinos Averkiou and Evangelos Kalogerakis},
year={2024},
eprint={2412.02912},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2412.02912},
}
License: MIT License
Usage:
python app.py [--share]
This demo allows users to:
1. Select a 3D object category from ShapeNetCore
2. Choose a specific 3D shape using a slider or the navigation buttons (including a random shape button)
3. Enter a text prompt or pick a random one
4. Generate images guided by the selected 3D shape and the text prompt
The code is structured as a class and is compatible with Hugging Face ZeroGPU deployment.
"""
import os
# Only for Hugging Face hosting - point the Hugging Face cache at persistent storage to avoid re-downloading safetensors every time the Space sleeps and wakes up.
# This must be set before importing gradio/diffusers, which read HF_HOME via huggingface_hub at import time.
os.environ['HF_HOME'] = '/data/.huggingface'
import sys
import numpy as np
import torch
import gradio as gr
from PIL import Image, ImageFont, ImageDraw
from diffusers.utils import load_image
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
import gdown
import argparse
import random
import spaces # for Hugging Face ZeroGPU deployment
import re
import plotly.graph_objects as go
import shutil
class ShapeWordsDemo:
# Constants
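# NAME2CAT maps human-readable category names to ShapeNetCore synset IDs
# (WordNet offsets); these IDs name the embedding and point cloud files on disk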
NAME2CAT = {
"chair": "03001627", "table": "04379243", "jar": "03593526", "skateboard": "04225987",
"car": "02958343", "bottle": "02876657", "tower": "04460130", "bookshelf": "02871439",
"camera": "02942699", "airplane": "02691156", "laptop": "03642806", "basket": "02801938",
"sofa": "04256520", "knife": "03624134", "can": "02946921", "rifle": "04090263",
"train": "04468005", "pillow": "03938244", "lamp": "03636649", "trash bin": "02747177",
"mailbox": "03710193", "watercraft": "04530566", "motorbike": "03790512",
"dishwasher": "03207941", "bench": "02828884", "pistol": "03948459", "rocket": "04099429",
"loudspeaker": "03691459", "file cabinet": "03337140", "bag": "02773838",
"cabinet": "02933112", "bed": "02818832", "birdhouse": "02843684", "display": "03211117",
"piano": "03928116", "earphone": "03261776", "telephone": "04401088", "stove": "04330267",
"microphone": "03759954", "bus": "02924116", "mug": "03797390", "remote": "04074963",
"bathtub": "02808440", "bowl": "02880940", "keyboard": "03085013", "guitar": "03467517",
"washer": "04554684", "bicycle": "02834778", "faucet": "03325088", "printer": "04004475",
"cap": "02954340", "phone": "02992529", "clock": "03046257", "helmet": "03513137",
"microwave": "03761084", "plant": "03991062"
}
PREDEFINED_PROMPTS = [
'a low poly 3d rendering of a [CATEGORY]',
'an aquarelle drawing of a [CATEGORY]',
'a photo of a [CATEGORY] on a beach',
'a charcoal drawing of a [CATEGORY]',
'a Hieronymus Bosch painting of a [CATEGORY]',
'a [CATEGORY] under a tree',
'A Kazimir Malevich painting of a [CATEGORY]',
'a vector graphic of a [CATEGORY]',
'a Claude Monet painting of a [CATEGORY]',
'a Salvador Dali painting of a [CATEGORY]',
'an Art Deco poster of a [CATEGORY]'
]
def __init__(self):
# Initialize class attributes
self.pipeline = None
self.shape2clip_model = None
self.text_encoder = None
self.tokenizer = None
self.category_embeddings = {}
self.category_counts = {}
self.available_categories = []
self.shape_thumbnail_cache = {} # Cache for shape thumbnails
self.CAT2NAME = {v: k for k, v in self.NAME2CAT.items()}
self.category_point_clouds = {}
# Initialize all models and data
self.initialize_models()
def initialize_models(self):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device} in initialize_models")
# Download Shape2CLIP code if it doesn't exist
if not os.path.exists("shapewords_paper_code/geometry_guidance_models.py"):
shutil.rmtree("shapewords_paper_code/", ignore_errors=True)
print("Loading models file")
os.system("git clone https://github.com/lodurality/shapewords_paper_code.git")
# Import Shape2CLIP model
sys.path.append("./shapewords_paper_code")
from shapewords_paper_code.geometry_guidance_models import Shape2CLIP
# Initialize the pipeline
self.pipeline = StableDiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1-base",
torch_dtype=torch.float16 if device.type == "cuda" else torch.float32
)
self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
self.pipeline.scheduler.config,
algorithm_type="sde-dpmsolver++"
)
self.text_encoder = self.pipeline.text_encoder
self.tokenizer = self.pipeline.tokenizer
# Look for Shape2CLIP checkpoint in multiple locations
checkpoint_paths = [
"./projection_model-0920192.pth",
"/data/projection_model-0920192.pth" # if using Hugging Face persistent storage look in a /data/ directory
]
checkpoint_found = False
checkpoint_path = None
for path in checkpoint_paths:
if os.path.exists(path):
checkpoint_path = path
print(f"Found Shape2CLIP checkpoint at: {checkpoint_path}")
checkpoint_found = True
break
# Download Shape2CLIP checkpoint if not found
if not checkpoint_found:
checkpoint_path = "projection_model-0920192.pth"
print("Downloading Shape2CLIP model checkpoint...")
gdown.download("https://drive.google.com/uc?id=1nvEXnwMpNkRts6rxVqMZt8i9FZ40KjP7", checkpoint_path, quiet=False) # download in same directory as app.py
print("Download complete")
# Initialize Shape2CLIP model
self.shape2clip_model = Shape2CLIP(depth=6, drop_path_rate=0.1, pb_dim=384)
self.shape2clip_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
self.shape2clip_model.eval()
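# Note: models are deliberately left on the CPU at init time; generate_images()
# moves them to the GPU on demand, which is what Hugging Face ZeroGPU expects,
# since a GPU is only attached inside @spaces.GPU-decorated calls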
# Scan for available embeddings
self.scan_available_embeddings()
def scan_available_embeddings(self):
self.available_categories = []
self.category_counts = {}
# Try to find PointBERT embeddings for each ShapeNetCore shape category
for category, cat_id in self.NAME2CAT.items():
possible_filenames = [
f"{cat_id}_pb_embs.npz",
f"embeddings/{cat_id}_pb_embs.npz",
f"/data/shapenet_pointbert_tokens/{cat_id}_pb_embs.npz" # if using Hugging Face persistent storage look in a /data/shapenet_pointbert_tokens directory
]
found_file = None
for filename in possible_filenames:
if os.path.exists(filename):
found_file = filename
break
if found_file:
try:
pb_data = np.load(found_file)
if 'ids' in pb_data:
count = len(pb_data['ids'])
else:
# Try to infer the correct keys
keys = list(pb_data.keys())
if len(keys) >= 1:
count = len(pb_data[keys[0]])
else:
count = 0
if count > 0:
self.available_categories.append(category)
self.category_counts[category] = count
print(f"Found {count} embeddings for category '{category}'")
except Exception as e:
print(f"Error loading embeddings for {category}: {e}")
# Sort categories alphabetically
self.available_categories.sort()
print(f"Found {len(self.available_categories)} categories with embeddings")
print(f"Available categories: {', '.join(self.available_categories)}")
# No embeddings were found for any category, so the demo cannot generate images.
# Still load the interface with a placeholder category; an error is shown when the user tries to generate.
if not self.available_categories:
self.available_categories = ["chair"] # Fallback
self.category_counts["chair"] = 50 # Default value
def load_category_embeddings(self, category):
if category in self.category_embeddings:
return self.category_embeddings[category]
if category not in self.NAME2CAT:
return None, []
cat_id = self.NAME2CAT[category]
# Check for different possible embedding filenames and locations
possible_filenames = [
f"{cat_id}_pb_embs.npz",
f"embeddings/{cat_id}_pb_embs.npz",
f"/data/shapenet_pointbert_tokens/{cat_id}_pb_embs.npz" # if using Hugging Face persistent storage look in a /data/shapenet_pointbert_tokens directory
]
# Find the first existing file
pb_emb_filename = None
for filename in possible_filenames:
if os.path.exists(filename):
pb_emb_filename = filename
print(f"Found embeddings file: {pb_emb_filename}")
break
if pb_emb_filename is None:
print(f"No embeddings found for {category}")
return None, []
# Load embeddings
try:
print(f"Loading embeddings from {pb_emb_filename}...")
pb_data = np.load(pb_emb_filename)
# Check for different key names in the NPZ file
if 'ids' in pb_data and 'embs' in pb_data:
pb_dict = dict(zip(pb_data['ids'], pb_data['embs']))
else:
# Try to infer the correct keys
keys = list(pb_data.keys())
if len(keys) >= 2:
# Assume first key is for IDs and second is for embeddings
pb_dict = dict(zip(pb_data[keys[0]], pb_data[keys[1]]))
else:
print("Unexpected embedding file format")
return None, []
all_ids = sorted(list(pb_dict.keys()))
print(f"Loaded {len(all_ids)} shape embeddings for {category}")
# Cache the results
self.category_embeddings[category] = (pb_dict, all_ids)
return pb_dict, all_ids
except Exception as e:
print(f"Error loading embeddings: {e}")
print(f"Exception details: {str(e)}")
return None, []
def load_category_point_clouds(self, category):
"""Load all point clouds for a category from a single NPZ file"""
if category not in self.NAME2CAT:
return None
cat_id = self.NAME2CAT[category]
# Cache to avoid reloading
if category in self.category_point_clouds:
return self.category_point_clouds[category]
# Check for different possible point cloud filenames
possible_filenames = [
f"{cat_id}.npz",
f"point_clouds/{cat_id}_clouds.npz",
f"/point_clouds/{cat_id}_clouds.npz",
f"/data/point_clouds/{cat_id}_clouds.npz" # For Hugging Face persistent storage
]
# Find the first existing file
pc_filename = None
for filename in possible_filenames:
if os.path.exists(filename):
pc_filename = filename
print(f"Found point cloud file: {pc_filename}")
break
if pc_filename is None:
print(f"No point cloud file found for category {category}")
return None
# Load point clouds
try:
print(f"Loading point clouds from {pc_filename}...")
pc_data_map = np.load(pc_filename, allow_pickle=False)
pc_data = {'ids': pc_data_map['ids'], 'clouds': pc_data_map['clouds']}
# Cache the loaded data
self.category_point_clouds[category] = pc_data
return pc_data
except Exception as e:
print(f"Error loading point clouds: {e}")
return None
def get_shape_preview(self, category, shape_idx):
"""Get a 3D point cloud visualization for a specific shape"""
if shape_idx is None or shape_idx < 0:
return None
# Get shape ID
pb_dict, all_ids = self.load_category_embeddings(category)
if pb_dict is None or not all_ids or shape_idx >= len(all_ids):
return None
shape_id = all_ids[shape_idx]
# Load all point clouds for this category
pc_data = self.load_category_point_clouds(category)
if pc_data is None:
# Fallback to image if point clouds not available
return self.get_shape_image_preview(category, shape_idx, shape_id)
# Extract point cloud for this specific shape
try:
# Get the arrays from the npz file
ids = pc_data['ids']
clouds = pc_data['clouds']
matching_indices = np.where(ids == shape_id)[0]
# Check number of matches
if len(matching_indices) == 0:
# No matches found - log error and fall back to image
print(f"Error: Shape ID {shape_id} not found in point cloud data")
return self.get_shape_image_preview(category, shape_idx, shape_id)
elif len(matching_indices) > 1:
# Multiple matches found (unexpected data issue); use the first one
print(f"Warning: Multiple matches ({len(matching_indices)}) found for Shape ID {shape_id}. Using first match.")
# Get the corresponding point cloud
matching_idx = matching_indices[0]
points = clouds[matching_idx]
# Create 3D visualization
fig = self.get_shape_pointcloud_preview(points, title=f"Shape #{shape_idx}")
return fig
except Exception as e:
print(f"Error extracting point cloud for {shape_id}: {e}")
return self.get_shape_image_preview(category, shape_idx, shape_id)
def get_shape_image_preview(self, category, shape_idx, shape_id):
"""Fallback to image preview if point cloud not available"""
try:
preview_image = self.get_ulip_image(shape_id)
preview_image = preview_image.resize((300, 300))
# Convert PIL image to plotly figure
fig = go.Figure()
# Need to convert PIL image to a format plotly can use
import io
import base64
# Convert PIL image to base64
buf = io.BytesIO()
preview_image.save(buf, format='PNG')
img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
# Add image to figure
fig.add_layout_image(
dict(
source=f"data:image/png;base64,{img_str}",
xref="paper", yref="paper",
x=0, y=1,
sizex=1, sizey=1,
sizing="contain",
layer="below"
)
)
fig.update_layout(
title=f"Shape 2D Preview - 3D not available",
xaxis=dict(showgrid=False, zeroline=False, visible=False, range=[0, 1]),
yaxis=dict(showgrid=False, zeroline=False, visible=False, range=[0, 1], scaleanchor="x", scaleratio=1),
margin=dict(l=0, r=0, b=0, t=0),
plot_bgcolor='rgba(0,0,0,0)' # Transparent background
)
return fig
except Exception as e:
print(f"Error loading preview for {shape_id}: {e}")
# Create empty figure with error message
fig = go.Figure()
fig.update_layout(
title=f"Error loading Shape #{shape_idx}",
annotations=[dict(
text="Preview not available",
showarrow=False,
xref="paper", yref="paper",
x=0.5, y=0.5,
font=dict(size=16, color="#E53935"), # Red error text
align="center"
)],
margin=dict(l=0, r=0, b=0, t=0, pad=0),
plot_bgcolor='rgba(0,0,0,0)' # Transparent background
)
return fig
def get_shape_pointcloud_preview(self, points, title=None):
"""Create a clean 3D point cloud visualization with Y as up axis"""
# Optionally subsample points for performance; a stride of 1 keeps every point
# (increase the stride to draw fewer points for smoother interaction)
sampled_points = points[::1]
# Create 3D scatter plot with fixed color
fig = go.Figure(data=[go.Scatter3d(
x=sampled_points[:, 0],
y=sampled_points[:, 1],
z=sampled_points[:, 2],
mode='markers',
marker=dict(
size=2.5,
color='#4285F4', # Fixed blue color
opacity=1
)
)])
fig.update_layout(
title=None,
scene=dict(
# Remove all axes elements
xaxis=dict(visible=False, showticklabels=False, showgrid=False, zeroline=False, showline=False,
showbackground=False),
yaxis=dict(visible=False, showticklabels=False, showgrid=False, zeroline=False, showline=False,
showbackground=False),
zaxis=dict(visible=False, showticklabels=False, showgrid=False, zeroline=False, showline=False,
showbackground=False),
aspectmode='data' # Maintain data aspect ratio
),
# Eliminate margins
margin=dict(l=0, r=0, b=0, t=0, pad=0),
autosize=True,
# Control modebar appearance through layout
modebar=dict(
bgcolor='white',
color='#333',
orientation='v', # Vertical orientation
activecolor='#009688'
),
paper_bgcolor='rgba(0,0,0,0)', # Transparent background
)
# Better camera angle
fig.update_layout(
scene_camera=dict(
eye=dict(x=-1.5, y=0.5, z=-1.5),
up=dict(x=0, y=1, z=0), # Y is up
center=dict(x=0, y=0, z=0)
)
)
return fig
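# 2D fallback preview: the ULIP release hosts pre-rendered RGB/depth views of
# ShapeNet shapes (see the URL template below); 'angle' selects one of the
# pre-rendered viewpoints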
def get_ulip_image(self, guidance_shape_id, angle='036'):
shape_id_ulip = guidance_shape_id.replace('_', '-')
ulip_template = 'https://storage.googleapis.com/sfr-ulip-code-release-research/shapenet-55/only_rgb_depth_images/{}_r_{}_depth0001.png'
ulip_path = ulip_template.format(shape_id_ulip, angle)
try:
ulip_image = load_image(ulip_path).resize((512, 512))
return ulip_image
except Exception as e:
print(f"Error loading image: {e}")
return Image.new('RGB', (512, 512), color='gray')
def on_slider_change(self, shape_idx, category):
"""Update the preview when the slider changes"""
max_idx = self.category_counts.get(category, 0) - 1
# Get shape preview
shape_preview = self.get_shape_preview(category, shape_idx)
# Update counter text
counter_text = f"Shape {shape_idx} of {max_idx}"
return shape_preview, counter_text, shape_idx
def prev_shape(self, current_idx):
"""Go to previous shape"""
new_idx = max(0, current_idx - 1)
return new_idx
def next_shape(self, current_idx, category):
"""Go to next shape"""
max_idx = self.category_counts.get(category, 0) - 1
new_idx = min(max_idx, current_idx + 1)
return new_idx
def jump_to_start(self):
"""Jump to the first shape"""
return 0
def jump_to_end(self, category):
"""Jump to the last shape"""
max_idx = self.category_counts.get(category, 0) - 1
return max_idx
def random_shape(self, category):
"""Select a random shape from the category"""
max_idx = self.category_counts.get(category, 0) - 1
if max_idx <= 0:
return 0
# Generate random index
random_idx = random.randint(0, max_idx)
return random_idx
def random_prompt(self):
"""Select a random prompt from the predefined list"""
return random.choice(self.PREDEFINED_PROMPTS)
def on_category_change(self, category):
"""Update the slider and preview when the category changes"""
# Reset to the first shape
current_idx = 0
max_idx = self.category_counts.get(category, 0) - 1
# Get preview image
preview_image = self.get_shape_preview(category, current_idx)
# Update counter text
counter_text = f"Shape {current_idx} of {max_idx}"
# Need to update the slider range
new_slider = gr.Slider(
minimum=0,
maximum=max_idx,
step=1,
value=current_idx,
label="Shape Index"
)
return new_slider, current_idx, preview_image, counter_text
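# ShapeWords guidance, following the paper's released code: Shape2CLIP takes
# the CLIP text embedding of the prompt together with PointBERT shape tokens
# and predicts a per-token guidance offset. get_guidance() masks this offset
# so it only touches the category word and the EOS token; generate_images()
# then conditions Stable Diffusion on
#     prompt_emb + guidance_strength * offset
# so the guidance strength (lambda) interpolates between the plain prompt
# (lambda = 0) and the fully shape-guided embedding (lambda = 1).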
def get_guidance(self, test_prompt, category_name, guidance_emb):
print(test_prompt, category_name)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device} in get_guidance")
prompt_tokens = torch.LongTensor(self.tokenizer.encode(test_prompt, padding='max_length')).to(device)
with torch.no_grad():
out = self.text_encoder(prompt_tokens.unsqueeze(0), output_attentions=True)
prompt_emb = out.last_hidden_state.detach().clone()
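# The PointBERT embedding may be a single global vector (1-D) or a token
# sequence (2-D); normalize both cases to a batched (1, T, D) tensor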
if len(guidance_emb.shape) == 1:
guidance_emb = torch.FloatTensor(guidance_emb).unsqueeze(0).unsqueeze(0)
else:
guidance_emb = torch.FloatTensor(guidance_emb).unsqueeze(0)
guidance_emb = guidance_emb.to(device)
eos_inds = torch.where(prompt_tokens.unsqueeze(0) == 49407)[1] # 49407 is the CLIP end-of-text (EOS) token id
# Locate the category word inside the prompt ([-2] skips the EOS token that the tokenizer appends)
obj_word = category_name
obj_word_token = self.tokenizer.encode(obj_word)[-2]
obj_inds = torch.where(prompt_tokens.unsqueeze(0) == obj_word_token)[1]
eos_strength = 0.8
obj_strength = 1.0
self.shape2clip_model.eval()
with torch.no_grad():
guided_prompt_emb_cond = self.shape2clip_model(prompt_emb.float(), guidance_emb[:,:,:].float()).half()
# Mask the guidance offset so that only the category-word and EOS positions survive
# (rescaled by obj_strength and eos_strength); all other token positions are zeroed
guided_prompt_emb = guided_prompt_emb_cond.clone()
guided_prompt_emb[:,:1] = 0
guided_prompt_emb[:,:obj_inds] = 0
guided_prompt_emb[:,obj_inds] *= obj_strength
guided_prompt_emb[:,eos_inds+1:] = 0
guided_prompt_emb[:,eos_inds] *= eos_strength
guided_prompt_emb[:,obj_inds+1:eos_inds] = 0
fin_guidance = guided_prompt_emb
return fin_guidance, prompt_emb
@spaces.GPU(duration=120)
def generate_images(self, prompt, category, selected_shape_idx, guidance_strength, seed):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device} in generate_images")
# Move models to gpu
if device.type == "cuda":
self.pipeline = self.pipeline.to(device)
self.shape2clip_model = self.shape2clip_model.to(device)
# Clear status text immediately
status = ""
# Replace [CATEGORY] with the selected category (case-insensitive)
category_pattern = re.compile(r'\[CATEGORY\]', re.IGNORECASE)
if re.search(category_pattern, prompt):
# Replace every [CATEGORY] placeholder (matched case-insensitively) with the selected category
final_prompt = re.sub(category_pattern, category, prompt)
else:
# Fallback if user didn't use placeholder
final_prompt = f"{prompt} {category}"
status += self.create_status_message(
f"Warning! For better results, use [CATEGORY] in your prompt where you want '{category}' to appear, otherwise it is appended at the end of the prompt.",
"info"
)
error = False
# Check if prompt contains any other categories
for other_category in self.available_categories:
if re.search(r'\b' + re.escape(other_category) + r'\b', prompt, re.IGNORECASE):
status += self.create_status_message(
f"Error! Your prompt contains '{other_category}'. Please remove it and use [CATEGORY] instead.",
"error"
)
error = True
if error:
return [], status
# Load category embeddings if not already loaded
pb_dict, all_ids = self.load_category_embeddings(category)
if pb_dict is None or not all_ids:
status += self.create_status_message(
f"Error! Unable to load embeddings for {category}",
"error"
)
return [], status
# Ensure shape index is valid
if selected_shape_idx is None or selected_shape_idx < 0:
selected_shape_idx = 0
max_idx = len(all_ids) - 1
selected_shape_idx = max(0, min(selected_shape_idx, max_idx))
guidance_shape_id = all_ids[selected_shape_idx]
# Set generator
generator = torch.Generator(device=device).manual_seed(seed)
results = []
try:
# Generate base image (without guidance)
with torch.no_grad():
base_images = self.pipeline(
prompt=final_prompt,
num_inference_steps=50,
num_images_per_prompt=1,
generator=generator,
guidance_scale=7.5
).images
results.append((base_images[0], "Unguided Result"))
except Exception as e:
print(f"Error generating base image: {e}")
status += self.create_status_message(
f"Error! Unable to generate base image: {str(e)}",
"error"
)
return results, status
try:
# Get shape guidance embedding
pb_emb = pb_dict[guidance_shape_id]
out_guidance, prompt_emb = self.get_guidance(final_prompt, category, pb_emb)
except Exception as e:
print(f"Error generating guidance: {e}")
status += self.create_status_message(
f"Error! Unable to generate guidance: {str(e)}",
"error"
)
return results, status
try:
# Generate guided image
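# Condition on the blended embedding prompt_emb + lambda * guidance; the
# generator is re-seeded with the same seed so that lambda = 0 exactly
# reproduces the unguided image above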
generator = torch.Generator(device=device).manual_seed(seed)
with torch.no_grad():
guided_images = self.pipeline(
prompt_embeds=prompt_emb + guidance_strength * out_guidance,
num_inference_steps=50,
num_images_per_prompt=1,
generator=generator,
guidance_scale=7.5
).images
results.append((guided_images[0], f"Guided Result (λ = {guidance_strength})"))
# Success status
status += self.create_status_message(
f"Success! Generated image guided by Shape #{selected_shape_idx} from category '{category}'.",
"success"
)
torch.cuda.empty_cache()
except Exception as e:
print(f"Error generating guided image: {e}")
status += self.create_status_message(
f"Error! Unable to generate guided image: {str(e)}",
"error"
)
return results, status
return results, status
def create_status_message(self, content, type_="info"):
# Define styles for different message types
styles = {
"info": {
"bg": "rgba(33, 150, 243, 0.15)",
"border": "#2196F3",
"icon": "ℹ️",
"title": "NOTE: "
},
"error": {
"bg": "rgba(244, 67, 54, 0.15)",
"border": "#F44336",
"icon": "❌",
"title": "ERROR: "
},
"success": {
"bg": "rgba(76, 175, 80, 0.15)",
"border": "#4CAF50",
"icon": "✅",
"title": "SUCCESS: "
},
"waiting": {
"bg": "rgba(255, 193, 7, 1)",
"border": "#FFC107",
"icon": "⏳",
"title": "PROCESSING: "
}
}
style = styles.get(type_, styles["info"])
font_weight = "bold" if type_ == "waiting" else "normal"
animation_style = "animation: pulse 1.5s infinite;" if type_ == "waiting" else ""
return f"""
<div style='
padding: 12px;
background-color: {style["bg"]};
border-left: 5px solid {style["border"]};
margin-bottom: 12px;
border-radius: 4px;
display: flex;
align-items: flex-start;
gap: 8px;
box-shadow: 0 1px 3px rgba(0,0,0,0.12);
font-weight: {font_weight};
{animation_style}
'>
<style>
@keyframes pulse {{
0%, 100% {{ opacity: 1; }}
50% {{ opacity: 0.7; }}
}}
</style>
<div style='font-size: 18px; line-height: 1.2;'>{style["icon"]}</div>
<div>{content}</div>
</div>
"""
def on_demo_load(self):
"""Function to ensure initial image is loaded when demo starts"""
default_category = "chair" if "chair" in self.available_categories else self.available_categories[0]
initial_img = self.get_shape_preview(default_category, 0)
return initial_img
def create_ui(self):
# Ensure chair is in available categories, otherwise use the first available
default_category = "chair" if "chair" in self.available_categories else self.available_categories[0]
with gr.Blocks(title="ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts",
theme=gr.themes.Soft(
primary_hue="orange",
secondary_hue="blue",
font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "Consolas", "monospace"],
),
css="""
/* Base styles */
.container { max-width: 1400px; margin: 0 auto; }
/* Title headers */
.title { text-align: center; font-size: 26px; font-weight: 600; margin-bottom: 3px; }
.subtitle { text-align: center; font-size: 16px; margin-bottom: 3px; }
.authors { text-align: center; font-size: 15px; margin-bottom: 3px; }
.affiliations { text-align: center; font-size: 13px; margin-bottom: 3px; }
/* Instructions Accordion */
button.instructions-accordion > span,
.instructions-accordion button > span {
font-size: 17px !important;
font-weight: 600 !important;
}
.instructions-accordion + div p,
.instructions-accordion + div li,
.instructions-text p,
.instructions-text li {
font-size: 14px !important;
}
/* Section Headers */
.step-header,.settings-header {
font-size: 18px;
font-weight: 600;
margin-top: 5px;
margin-bottom: 5px;
}
.sub-header {
margin-top: 5px;
margin-bottom: 5px;
padding-left: 5px;
}
/* Buttons for project page, paper, code etc*/
.buttons-container { margin: 0 auto 10px; }
.buttons-row { display: flex; justify-content: center; gap: 10px; flex-wrap: nowrap; }
.nav-button {
display: inline-block;
padding: 6px 12px;
background-color: #363636;
color: white !important;
text-decoration: none;
border-radius: 20px;
font-weight: 500;
font-size: 14px;
transition: background-color 0.2s;
text-align: center;
white-space: nowrap;
}
.nav-button:hover { background-color: #505050; }
.nav-button.disabled {
opacity: 0.6;
cursor: not-allowed;
}
/* Prompt design section elements */
.category-dropdown .wrap { font-size: 16px; }
.prompt-input { flex-grow: 1; }
.prompt-button {
align-self: center; /* Vertical centering */
margin-left: auto; /* Horizontal centering */
margin-right: auto;
display: block; /* Makes margins work for centering */
}
/* Shape selection section elements */
.shape-navigation {
display: flex;
justify-content: center;
align-items: center;
margin: 10px auto;
gap: 15px;
max-width: 320px;
}
.shape-navigation button {
min-width: 40px;
max-width: 60px;
width: auto;
padding: 6px 10px;
}
.nav-icon-btn { font-size: 18px; }
/* Generate button */
.generate-button {
font-size: 18px !important;
padding: 12px !important;
margin: 15px 0 !important;
background: linear-gradient(135deg, #f97316, #fb923c) !important;
}
/* Results section elements */
.results-gallery { min-height: 100px; max-height: 500px; display: flex; align-items: center; justify-content: center; }
.results-gallery .grid-container { display: flex; align-items: center; }
/* About section elements */
.about-section { font-size: 16px; margin-top: 40px; padding: 20px; border-top: 1px solid rgba(128, 128, 128, 0.2); }
/* Responsive adjustments for mobile mode*/
@media (max-width: 768px) {
.shape-navigation {
max-width: 100%;
gap: 5px;
}
.shape-navigation button {
min-width: 36px;
padding: 6px 0;
font-size: 16px;
}
.buttons-row {
gap: 5px;
}
.nav-button {
padding: 5px 8px;
font-size: 13px;
}
.results-gallery {
max-height: 320px;
}
}
/* Dark mode overrides */
@media (prefers-color-scheme: dark) {
.nav-button {
background-color: #505050;
}
.nav-button:hover {
background-color: #666666;
}
}
""") as demo:
# Header with title and links
gr.Markdown("# ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts",
elem_classes="title")
gr.Markdown("### CVPR 2025", elem_classes="subtitle")
gr.Markdown(
"Dmitry Petrov<sup>1</sup>, Pradyumn Goyal<sup>1</sup>, Divyansh Shivashok<sup>1</sup>, Yuanming Tao<sup>1</sup>, Melinos Averkiou<sup>2,3</sup>, Evangelos Kalogerakis<sup>1,2,4</sup>",
elem_classes="authors")
gr.Markdown(
"<sup>1</sup>UMass Amherst <sup>2</sup>CYENS CoE <sup>3</sup>University of Cyprus <sup>4</sup>TU Crete",
elem_classes="affiliations")
# Navigation buttons
with gr.Row():
with gr.Column(scale=3):
pass # Empty space for alignment
with gr.Column(scale=2, elem_classes="buttons-container"):
gr.HTML("""
<div class="buttons-row">
<a href="https://arxiv.org/abs/2412.02912" target="_blank" class="nav-button">
arXiv
</a>
<a href="https://lodurality.github.io/shapewords/" target="_blank" class="nav-button">
Project Page
</a>
<a href="#" target="_blank" class="nav-button disabled">
Code
</a>
<a href="#" target="_blank" class="nav-button disabled">
Data
</a>
</div>
""")
with gr.Column(scale=3):
pass # Empty space for alignment
# Add instructions
with gr.Accordion("📋 Instructions", open=True, elem_classes="instructions-accordion"):
gr.Markdown("""
1️⃣ Select a shape category from the dropdown menu (55 categories in total). We recommend trying the chair (default), car, lamp and bottle categories.
2️⃣ Create a text prompt using **[CATEGORY]** as a placeholder, or use the **"Random Prompt"** button to pick from a small set of predefined prompts
3️⃣ Adjust the **guidance strength** to control the shape's influence. The default value of 0.9 gives the best balance between prompt and shape adherence; a value of 0.0 yields the unguided result, based only on the input prompt.
4️⃣ (optional) Choose a **random seed**. For a fixed combination of input prompt and random seed, the unguided image is always the same.
5️⃣ Choose a **guidance 3D shape** using the slider, the navigation buttons or the random-shape button. Shapes come from the ShapeNet dataset (~55K shapes across all categories).
6️⃣ Click the **Generate Images** button at the bottom to create images that follow both your text prompt and the selected 3D shape's geometry.
""", elem_classes="instructions-text")
# Hidden field to store selected shape index
selected_shape_idx = gr.Number(value=0, visible=False)
# Prompt Design (full width)
with gr.Group():
gr.Markdown("### 📝 Prompt Design", elem_classes="step-header")
with gr.Row():
category = gr.Dropdown(
label="1️⃣ Shape Category",
choices=self.available_categories,
value=default_category,
container=True,
elem_classes="category-dropdown",
scale=2
)
prompt = gr.Textbox(
label="2️⃣ Text Prompt - Use [CATEGORY] as a placeholder, e.g. 'a [CATEGORY] under a tree'",
placeholder="an aquarelle drawing of a [CATEGORY]",
value="an aquarelle drawing of a [CATEGORY]",
lines=1,
scale=5,
elem_classes="prompt-input"
)
random_prompt_btn = gr.Button("🎲 Random\nPrompt",
size="lg",
scale=1,
elem_classes="prompt-button")
# Generation Settings (full width)
with gr.Group():
gr.Markdown("### ⚙️ Generation Settings", elem_classes="settings-header")
with gr.Row():
with gr.Column():
guidance_strength = gr.Slider(
minimum=0.0, maximum=1.0, step=0.1, value=0.9,
label="3️⃣ Guidance Strength (λ) - Higher λ = stronger shape adherence"
)
with gr.Column():
seed = gr.Slider(
minimum=0, maximum=10000, step=1, value=42,
label="4️⃣ Random Seed - (optional) Change for different variations"
)
# Middle section - Shape Selection and Results side by side
with gr.Row(equal_height=True):
# Left column - Shape Selection
with gr.Column():
with gr.Group():
gr.Markdown("### 🔍 Shape Selection", elem_classes="step-header")
shape_slider = gr.Slider(
minimum=0,
maximum=self.category_counts.get(default_category, 0) - 1,
step=1,
value=0,
label="5️⃣ Shape Index - Choose a 3D shape to guide image generation",
interactive=True
)
shape_counter = gr.Markdown(f"Shape 0 of {self.category_counts.get(default_category, 0) - 1}", elem_classes="sub-header")
current_shape_plot = gr.Plot(show_label=False)
# Navigation buttons - Icons only for better mobile compatibility
with gr.Row(elem_classes="shape-navigation"):
jump_start_btn = gr.Button("⏮️", size="sm", elem_classes="nav-icon-btn")
prev_shape_btn = gr.Button("◀️", size="sm", elem_classes="nav-icon-btn")
random_btn = gr.Button("🎲", size="sm", variant="secondary", elem_classes="nav-icon-btn")
next_shape_btn = gr.Button("▶️", size="sm", elem_classes="nav-icon-btn")
jump_end_btn = gr.Button("⏭️", size="sm", elem_classes="nav-icon-btn")
# Right column - Results
with gr.Column():
with gr.Group():
gr.Markdown("### 🖼️ Generated Results Preview", elem_classes="step-header")
gallery = gr.Gallery(
label="Results",
show_label=False,
elem_id="results_gallery",
columns=2,
elem_classes="results-gallery"
)
# Generate button (full width)
with gr.Row():
run_button = gr.Button(" 6️⃣ ✨ Generate Images guided by Selected Shape", variant="primary", size="lg",
elem_classes="generate-button")
# Status message (full width)
with gr.Row():
status_text = gr.HTML("", elem_classes="status-message")
# About section at the bottom of the page
with gr.Group(elem_classes="about-section"):
gr.Markdown("""
### About ShapeWords
ShapeWords combines target 3D shape information with text prompts to guide image synthesis.
### How It Works
1. Select a shape category from the dropdown menu (55 categories in total). We recommend trying the chair (default), car, lamp and bottle categories.
2. Create a text prompt using **[CATEGORY]** as a placeholder, or use the **"Random Prompt"** button to pick from a small set of predefined prompts
3. Adjust the **guidance strength** to control the shape's influence. The default value of 0.9 gives the best balance between prompt and shape adherence; a value of 0.0 yields the unguided result, based only on the input prompt.
4. (optional) Choose a **random seed**. For a fixed combination of input prompt and random seed, the unguided image is always the same.
5. Choose a **guidance 3D shape** using the slider, the navigation buttons or the random-shape button. Shapes come from the ShapeNet dataset (~55K shapes across all categories).
6. Click the **Generate Images** button at the bottom to create images that follow both your text prompt and the selected 3D shape's geometry.
### Citation
```
@misc{petrov2024shapewords,
title={ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts},
author={Dmitry Petrov and Pradyumn Goyal and Divyansh Shivashok and Yuanming Tao and Melinos Averkiou and Evangelos Kalogerakis},
year={2024},
eprint={2412.02912},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2412.02912},
}
```
""")
# Make sure the initial image is loaded when the demo starts
demo.load(
fn=self.on_demo_load,
inputs=None,
outputs=[current_shape_plot]
)
# Connect slider to update preview
shape_slider.change(
fn=self.on_slider_change,
inputs=[shape_slider, category],
outputs=[current_shape_plot, shape_counter, selected_shape_idx]
)
# Previous shape button
prev_shape_btn.click(
fn=self.prev_shape,
inputs=[selected_shape_idx],
outputs=[shape_slider]
)
# Next shape button
next_shape_btn.click(
fn=self.next_shape,
inputs=[selected_shape_idx, category],
outputs=[shape_slider]
)
# Jump to start button
jump_start_btn.click(
fn=self.jump_to_start,
inputs=None,
outputs=[shape_slider]
)
# Jump to end button
jump_end_btn.click(
fn=self.jump_to_end,
inputs=[category],
outputs=[shape_slider]
)
# Random shape button
random_btn.click(
fn=self.random_shape,
inputs=[category],
outputs=[shape_slider]
)
# Connect the random prompt button
random_prompt_btn.click(
fn=self.random_prompt,
inputs=[],
outputs=[prompt]
)
# Update the UI when category changes
category.change(
fn=self.on_category_change,
inputs=[category],
outputs=[shape_slider, selected_shape_idx, current_shape_plot, shape_counter]
)
# Update status text when generating
run_button.click(
fn=lambda: self.create_status_message("Generating images...", "waiting"),
inputs=None,
outputs=[status_text]
)
# Generate images when button is clicked
run_button.click(
fn=self.generate_images,
inputs=[prompt, category, selected_shape_idx, guidance_strength, seed],
outputs=[gallery, status_text]
)
return demo
# Main function and entry point
def main():
parser = argparse.ArgumentParser(description="ShapeWords Gradio Demo")
parser.add_argument('--share', action='store_true', help='Create a public link')
args = parser.parse_args()
# Create the demo app and UI
app = ShapeWordsDemo()
demo = app.create_ui()
demo.launch(share=args.share)
if __name__ == "__main__":
main()