Spaces:
Runtime error
Runtime error
added updated melinos code with point clouds and random prompts
Browse files- app.py +434 -259
- shapewords_paper_code +1 -1
app.py
CHANGED
|
@@ -31,7 +31,7 @@ Usage:
|
|
| 31 |
|
| 32 |
This demo allows users to:
|
| 33 |
1. Select a 3D object category
|
| 34 |
-
2. Choose a specific 3D shape
|
| 35 |
3. Enter a text prompt
|
| 36 |
4. Generate images guided by the selected 3D shape
|
| 37 |
|
|
@@ -46,32 +46,34 @@ import gradio as gr
|
|
| 46 |
from PIL import Image, ImageFont, ImageDraw
|
| 47 |
from diffusers.utils import load_image
|
| 48 |
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
|
| 49 |
-
#import open_clip
|
| 50 |
import gdown
|
| 51 |
import argparse
|
| 52 |
import random
|
| 53 |
-
import spaces
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
-
|
| 56 |
-
os.environ['
|
| 57 |
|
| 58 |
class ShapeWordsDemo:
|
| 59 |
# Constants
|
| 60 |
NAME2CAT = {
|
| 61 |
-
"chair": "03001627", "table": "04379243", "jar": "03593526", "skateboard": "04225987",
|
| 62 |
-
"car": "02958343", "bottle": "02876657", "tower": "04460130", "bookshelf": "02871439",
|
| 63 |
-
"camera": "02942699", "airplane": "02691156", "laptop": "03642806", "basket": "02801938",
|
| 64 |
-
"sofa": "04256520", "knife": "03624134", "can": "02946921", "rifle": "04090263",
|
| 65 |
-
"train": "04468005", "pillow": "03938244", "lamp": "03636649", "trash bin": "02747177",
|
| 66 |
-
"mailbox": "03710193", "watercraft": "04530566", "motorbike": "03790512",
|
| 67 |
-
"dishwasher": "03207941", "bench": "02828884", "pistol": "03948459", "rocket": "04099429",
|
| 68 |
-
"loudspeaker": "03691459", "file cabinet": "03337140", "bag": "02773838",
|
| 69 |
-
"cabinet": "02933112", "bed": "02818832", "birdhouse": "02843684", "display": "03211117",
|
| 70 |
-
"piano": "03928116", "earphone": "03261776", "telephone": "04401088", "stove": "04330267",
|
| 71 |
-
"microphone": "03759954", "bus": "02924116", "mug": "03797390", "remote": "04074963",
|
| 72 |
-
"bathtub": "02808440", "bowl": "02880940", "keyboard": "03085013", "guitar": "03467517",
|
| 73 |
-
"washer": "04554684", "bicycle": "02834778", "faucet": "03325088", "printer": "04004475",
|
| 74 |
-
"cap": "02954340", "phone": "02992529", "clock": "03046257", "helmet": "03513137",
|
| 75 |
"microwave": "03761084", "plant": "03991062"
|
| 76 |
}
|
| 77 |
|
|
@@ -86,30 +88,30 @@ class ShapeWordsDemo:
|
|
| 86 |
self.available_categories = []
|
| 87 |
self.shape_thumbnail_cache = {} # Cache for shape thumbnails
|
| 88 |
self.CAT2NAME = {v: k for k, v in self.NAME2CAT.items()}
|
| 89 |
-
|
| 90 |
# Initialize all models and data
|
| 91 |
self.initialize_models()
|
| 92 |
|
| 93 |
def draw_text(self, img, text, color=(10, 10, 10), size=80, location=(200, 30)):
|
| 94 |
img = img.copy()
|
| 95 |
draw = ImageDraw.Draw(img)
|
| 96 |
-
|
| 97 |
try:
|
| 98 |
font = ImageFont.truetype("Arial", size=size)
|
| 99 |
except IOError:
|
| 100 |
font = ImageFont.load_default()
|
| 101 |
-
|
| 102 |
bbox = draw.textbbox(location, text, font=font)
|
| 103 |
draw.rectangle(bbox, fill="white")
|
| 104 |
draw.text(location, text, color, font=font)
|
| 105 |
-
|
| 106 |
return img
|
| 107 |
|
| 108 |
def get_ulip_image(self, guidance_shape_id, angle='036'):
|
| 109 |
shape_id_ulip = guidance_shape_id.replace('_', '-')
|
| 110 |
ulip_template = 'https://storage.googleapis.com/sfr-ulip-code-release-research/shapenet-55/only_rgb_depth_images/{}_r_{}_depth0001.png'
|
| 111 |
ulip_path = ulip_template.format(shape_id_ulip, angle)
|
| 112 |
-
|
| 113 |
try:
|
| 114 |
ulip_image = load_image(ulip_path).resize((512, 512))
|
| 115 |
return ulip_image
|
|
@@ -117,56 +119,40 @@ class ShapeWordsDemo:
|
|
| 117 |
print(f"Error loading image: {e}")
|
| 118 |
return Image.new('RGB', (512, 512), color='gray')
|
| 119 |
|
| 120 |
-
def get_ulip_thumbnail(self, guidance_shape_id, angle='036', size=(150, 150)):
|
| 121 |
-
"""Get a thumbnail version of the ULIP image for use in the gallery"""
|
| 122 |
-
image = self.get_ulip_image(guidance_shape_id, angle)
|
| 123 |
-
return image.resize(size)
|
| 124 |
-
|
| 125 |
def initialize_models(self):
|
| 126 |
-
device =
|
| 127 |
-
|
| 128 |
-
|
|
|
|
| 129 |
# Download Shape2CLIP code if it doesn't exist
|
| 130 |
if not os.path.exists("shapewords_paper_code"):
|
| 131 |
print("Loading models file")
|
| 132 |
os.system("git clone https://github.com/lodurality/shapewords_paper_code.git")
|
| 133 |
-
|
| 134 |
# Import Shape2CLIP model
|
| 135 |
sys.path.append("./shapewords_paper_code")
|
| 136 |
from shapewords_paper_code.geometry_guidance_models import Shape2CLIP
|
| 137 |
-
|
| 138 |
# Initialize the pipeline
|
| 139 |
self.pipeline = StableDiffusionPipeline.from_pretrained(
|
| 140 |
-
"stabilityai/stable-diffusion-2-1-base",
|
| 141 |
torch_dtype=torch.float16 if device.type == "cuda" else torch.float32
|
| 142 |
)
|
| 143 |
-
|
| 144 |
self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
| 145 |
-
self.pipeline.scheduler.config,
|
| 146 |
algorithm_type="sde-dpmsolver++"
|
| 147 |
)
|
| 148 |
-
|
| 149 |
-
# Load CLIP model
|
| 150 |
-
#clip_model, _, preprocess = open_clip.create_model_and_transforms(
|
| 151 |
-
# 'ViT-H-14',
|
| 152 |
-
# pretrained='laion2b_s32b_b79k'
|
| 153 |
-
#)
|
| 154 |
-
|
| 155 |
-
# Move models to device if not using ZeroGPU
|
| 156 |
-
if device.type == "cuda":
|
| 157 |
-
self.pipeline = self.pipeline.to(device)
|
| 158 |
-
#self.pipeline.enable_model_cpu_offload()
|
| 159 |
-
|
| 160 |
-
#clip_tokenizer = open_clip.get_tokenizer('ViT-H-14')
|
| 161 |
self.text_encoder = self.pipeline.text_encoder
|
| 162 |
self.tokenizer = self.pipeline.tokenizer
|
| 163 |
-
|
| 164 |
# Look for Shape2CLIP checkpoint in multiple locations
|
| 165 |
checkpoint_paths = [
|
| 166 |
-
"
|
| 167 |
-
"/data/
|
| 168 |
]
|
| 169 |
-
|
| 170 |
checkpoint_found = False
|
| 171 |
checkpoint_path = None
|
| 172 |
for path in checkpoint_paths:
|
|
@@ -175,43 +161,40 @@ class ShapeWordsDemo:
|
|
| 175 |
print(f"Found Shape2CLIP checkpoint at: {checkpoint_path}")
|
| 176 |
checkpoint_found = True
|
| 177 |
break
|
| 178 |
-
|
| 179 |
# Download Shape2CLIP checkpoint if not found
|
| 180 |
if not checkpoint_found:
|
| 181 |
checkpoint_path = "projection_model-0920192.pth"
|
| 182 |
print("Downloading Shape2CLIP model checkpoint...")
|
| 183 |
-
gdown.download("1nvEXnwMpNkRts6rxVqMZt8i9FZ40KjP7", checkpoint_path, quiet=False)
|
| 184 |
print("Download complete")
|
| 185 |
-
|
| 186 |
# Initialize Shape2CLIP model
|
| 187 |
self.shape2clip_model = Shape2CLIP(depth=6, drop_path_rate=0.1, pb_dim=384)
|
| 188 |
self.shape2clip_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
|
| 189 |
-
if device.type == "cuda":
|
| 190 |
-
self.shape2clip_model = self.shape2clip_model.to(device)
|
| 191 |
self.shape2clip_model.eval()
|
| 192 |
-
|
| 193 |
# Scan for available embeddings
|
| 194 |
self.scan_available_embeddings()
|
| 195 |
|
| 196 |
def scan_available_embeddings(self):
|
| 197 |
self.available_categories = []
|
| 198 |
self.category_counts = {}
|
| 199 |
-
|
|
|
|
| 200 |
for category, cat_id in self.NAME2CAT.items():
|
| 201 |
possible_filenames = [
|
| 202 |
-
f"pointbert_shapenet_{cat_id}.npz",
|
| 203 |
f"{cat_id}_pb_embs.npz",
|
| 204 |
-
f"embeddings/
|
| 205 |
-
|
| 206 |
-
f"/data/shapenet_pointbert_tokens/{cat_id}_pb_embs.npz"
|
| 207 |
]
|
| 208 |
-
|
| 209 |
found_file = None
|
| 210 |
for filename in possible_filenames:
|
| 211 |
if os.path.exists(filename):
|
| 212 |
found_file = filename
|
| 213 |
break
|
| 214 |
-
|
| 215 |
if found_file:
|
| 216 |
try:
|
| 217 |
pb_data = np.load(found_file)
|
|
@@ -224,42 +207,41 @@ class ShapeWordsDemo:
|
|
| 224 |
count = len(pb_data[keys[0]])
|
| 225 |
else:
|
| 226 |
count = 0
|
| 227 |
-
|
| 228 |
if count > 0:
|
| 229 |
self.available_categories.append(category)
|
| 230 |
self.category_counts[category] = count
|
| 231 |
print(f"Found {count} embeddings for category '{category}'")
|
| 232 |
except Exception as e:
|
| 233 |
print(f"Error loading embeddings for {category}: {e}")
|
| 234 |
-
|
| 235 |
-
if not self.available_categories:
|
| 236 |
-
self.available_categories = ["chair"] # Fallback
|
| 237 |
-
self.category_counts["chair"] = 50 # Default value
|
| 238 |
-
|
| 239 |
# Sort categories alphabetically
|
| 240 |
self.available_categories.sort()
|
| 241 |
-
|
| 242 |
print(f"Found {len(self.available_categories)} categories with embeddings")
|
| 243 |
print(f"Available categories: {', '.join(self.available_categories)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
def load_category_embeddings(self, category):
|
| 246 |
if category in self.category_embeddings:
|
| 247 |
return self.category_embeddings[category]
|
| 248 |
-
|
| 249 |
if category not in self.NAME2CAT:
|
| 250 |
return None, []
|
| 251 |
-
|
| 252 |
cat_id = self.NAME2CAT[category]
|
| 253 |
-
|
| 254 |
# Check for different possible embedding filenames and locations
|
| 255 |
possible_filenames = [
|
| 256 |
-
f"
|
| 257 |
-
f"{cat_id}_pb_embs.npz",
|
| 258 |
-
f"embeddings/pointbert_shapenet_{cat_id}.npz",
|
| 259 |
f"embeddings/{cat_id}_pb_embs.npz",
|
| 260 |
-
f"/data/shapenet_pointbert_tokens/{cat_id}_pb_embs.npz"
|
| 261 |
]
|
| 262 |
-
|
| 263 |
# Find the first existing file
|
| 264 |
pb_emb_filename = None
|
| 265 |
for filename in possible_filenames:
|
|
@@ -267,16 +249,16 @@ class ShapeWordsDemo:
|
|
| 267 |
pb_emb_filename = filename
|
| 268 |
print(f"Found embeddings file: {pb_emb_filename}")
|
| 269 |
break
|
| 270 |
-
|
| 271 |
if pb_emb_filename is None:
|
| 272 |
print(f"No embeddings found for {category}")
|
| 273 |
return None, []
|
| 274 |
-
|
| 275 |
# Load embeddings
|
| 276 |
try:
|
| 277 |
print(f"Loading embeddings from {pb_emb_filename}...")
|
| 278 |
pb_data = np.load(pb_emb_filename)
|
| 279 |
-
|
| 280 |
# Check for different key names in the NPZ file
|
| 281 |
if 'ids' in pb_data and 'embs' in pb_data:
|
| 282 |
pb_dict = dict(zip(pb_data['ids'], pb_data['embs']))
|
|
@@ -289,10 +271,10 @@ class ShapeWordsDemo:
|
|
| 289 |
else:
|
| 290 |
print("Unexpected embedding file format")
|
| 291 |
return None, []
|
| 292 |
-
|
| 293 |
all_ids = sorted(list(pb_dict.keys()))
|
| 294 |
print(f"Loaded {len(all_ids)} shape embeddings for {category}")
|
| 295 |
-
|
| 296 |
# Cache the results
|
| 297 |
self.category_embeddings[category] = (pb_dict, all_ids)
|
| 298 |
return pb_dict, all_ids
|
|
@@ -301,90 +283,280 @@ class ShapeWordsDemo:
|
|
| 301 |
print(f"Exception details: {str(e)}")
|
| 302 |
return None, []
|
| 303 |
|
| 304 |
-
def
|
| 305 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
if shape_idx is None or shape_idx < 0:
|
| 307 |
return None
|
| 308 |
-
|
|
|
|
| 309 |
pb_dict, all_ids = self.load_category_embeddings(category)
|
| 310 |
if pb_dict is None or not all_ids or shape_idx >= len(all_ids):
|
| 311 |
return None
|
|
|
|
| 312 |
shape_id = all_ids[shape_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
try:
|
| 314 |
-
# Get the shape image at the requested size
|
| 315 |
preview_image = self.get_ulip_image(shape_id)
|
| 316 |
-
preview_image = preview_image.resize(
|
| 317 |
-
preview_with_text = self.draw_text(preview_image, f"Shape #{shape_idx}", size=
|
| 318 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
except Exception as e:
|
| 320 |
print(f"Error loading preview for {shape_id}: {e}")
|
| 321 |
-
# Create
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
|
| 326 |
def on_slider_change(self, shape_idx, category):
|
| 327 |
"""Update the preview when the slider changes"""
|
| 328 |
max_idx = self.category_counts.get(category, 0) - 1
|
| 329 |
-
|
| 330 |
# Get preview image
|
| 331 |
preview_image = self.get_shape_preview(category, shape_idx)
|
| 332 |
-
|
| 333 |
# Update counter text
|
| 334 |
counter_text = f"Shape {shape_idx} of {max_idx}"
|
| 335 |
-
|
| 336 |
return preview_image, counter_text, shape_idx
|
| 337 |
|
| 338 |
def prev_shape(self, current_idx, category):
|
| 339 |
"""Go to previous shape"""
|
| 340 |
max_idx = self.category_counts.get(category, 0) - 1
|
| 341 |
new_idx = max(0, current_idx - 1)
|
| 342 |
-
|
| 343 |
# Get preview image
|
| 344 |
preview_image = self.get_shape_preview(category, new_idx)
|
| 345 |
-
|
| 346 |
# Update counter text
|
| 347 |
counter_text = f"Shape {new_idx} of {max_idx}"
|
| 348 |
-
|
| 349 |
return new_idx, preview_image, counter_text
|
| 350 |
|
| 351 |
def next_shape(self, current_idx, category):
|
| 352 |
"""Go to next shape"""
|
| 353 |
max_idx = self.category_counts.get(category, 0) - 1
|
| 354 |
new_idx = min(max_idx, current_idx + 1)
|
| 355 |
-
|
| 356 |
# Get preview image
|
| 357 |
preview_image = self.get_shape_preview(category, new_idx)
|
| 358 |
-
|
| 359 |
# Update counter text
|
| 360 |
counter_text = f"Shape {new_idx} of {max_idx}"
|
| 361 |
-
|
| 362 |
return new_idx, preview_image, counter_text
|
| 363 |
|
| 364 |
def jump_to_start(self, category):
|
| 365 |
"""Jump to the first shape"""
|
| 366 |
max_idx = self.category_counts.get(category, 0) - 1
|
| 367 |
new_idx = 0
|
| 368 |
-
|
| 369 |
# Get preview image
|
| 370 |
preview_image = self.get_shape_preview(category, new_idx)
|
| 371 |
-
|
| 372 |
# Update counter text
|
| 373 |
counter_text = f"Shape {new_idx} of {max_idx}"
|
| 374 |
-
|
| 375 |
return new_idx, preview_image, counter_text
|
| 376 |
|
| 377 |
def jump_to_end(self, category):
|
| 378 |
"""Jump to the last shape"""
|
| 379 |
max_idx = self.category_counts.get(category, 0) - 1
|
| 380 |
new_idx = max_idx
|
| 381 |
-
|
| 382 |
# Get preview image
|
| 383 |
preview_image = self.get_shape_preview(category, new_idx)
|
| 384 |
-
|
| 385 |
# Update counter text
|
| 386 |
counter_text = f"Shape {new_idx} of {max_idx}"
|
| 387 |
-
|
| 388 |
return new_idx, preview_image, counter_text
|
| 389 |
|
| 390 |
def random_shape(self, category):
|
|
@@ -392,30 +564,49 @@ class ShapeWordsDemo:
|
|
| 392 |
max_idx = self.category_counts.get(category, 0) - 1
|
| 393 |
if max_idx <= 0:
|
| 394 |
return 0, self.get_shape_preview(category, 0), f"Shape 0 of 0"
|
| 395 |
-
|
| 396 |
# Generate random index
|
| 397 |
random_idx = random.randint(0, max_idx)
|
| 398 |
-
|
| 399 |
# Get preview image
|
| 400 |
preview_image = self.get_shape_preview(category, random_idx)
|
| 401 |
-
|
| 402 |
# Update counter text
|
| 403 |
counter_text = f"Shape {random_idx} of {max_idx}"
|
| 404 |
-
|
| 405 |
return random_idx, preview_image, counter_text
|
| 406 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
def on_category_change(self, category):
|
| 408 |
"""Update the slider and preview when the category changes"""
|
| 409 |
# Reset to the first shape
|
| 410 |
current_idx = 0
|
| 411 |
max_idx = self.category_counts.get(category, 0) - 1
|
| 412 |
-
|
| 413 |
# Get preview image
|
| 414 |
preview_image = self.get_shape_preview(category, current_idx)
|
| 415 |
-
|
| 416 |
# Update counter text
|
| 417 |
counter_text = f"Shape {current_idx} of {max_idx}"
|
| 418 |
-
|
| 419 |
# Need to update the slider range
|
| 420 |
new_slider = gr.Slider(
|
| 421 |
minimum=0,
|
|
@@ -424,19 +615,20 @@ class ShapeWordsDemo:
|
|
| 424 |
value=current_idx,
|
| 425 |
label="Shape Index"
|
| 426 |
)
|
| 427 |
-
|
| 428 |
return new_slider, current_idx, preview_image, counter_text
|
| 429 |
|
| 430 |
def get_guidance(self, test_prompt, category_name, guidance_emb):
|
| 431 |
-
print("Getting guidance")
|
| 432 |
print(test_prompt, category_name)
|
| 433 |
-
device = torch.device(
|
|
|
|
|
|
|
| 434 |
prompt_tokens = torch.LongTensor(self.tokenizer.encode(test_prompt, padding='max_length')).to(device)
|
|
|
|
| 435 |
with torch.no_grad():
|
| 436 |
out = self.text_encoder(prompt_tokens.unsqueeze(0), output_attentions=True)
|
| 437 |
prompt_emb = out.last_hidden_state.detach().clone()
|
| 438 |
-
|
| 439 |
-
|
| 440 |
if len(guidance_emb.shape) == 1:
|
| 441 |
guidance_emb = torch.FloatTensor(guidance_emb).unsqueeze(0).unsqueeze(0)
|
| 442 |
else:
|
|
@@ -455,7 +647,7 @@ class ShapeWordsDemo:
|
|
| 455 |
with torch.no_grad():
|
| 456 |
guided_prompt_emb_cond = self.shape2clip_model(prompt_emb.float(), guidance_emb[:,:,:].float()).half()
|
| 457 |
guided_prompt_emb = guided_prompt_emb_cond.clone()
|
| 458 |
-
|
| 459 |
guided_prompt_emb[:,:1] = 0
|
| 460 |
guided_prompt_emb[:,:chair_inds] = 0
|
| 461 |
guided_prompt_emb[:,chair_inds] *= obj_strength
|
|
@@ -466,72 +658,76 @@ class ShapeWordsDemo:
|
|
| 466 |
|
| 467 |
return fin_guidance, prompt_emb
|
| 468 |
|
| 469 |
-
# For ZeroGPU compatibility, uncomment this decorator when using ZeroGPU
|
| 470 |
@spaces.GPU(duration=120)
|
| 471 |
def generate_images(self, prompt, category, selected_shape_idx, guidance_strength, seed):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
# Clear status text immediately
|
| 473 |
status = ""
|
| 474 |
-
|
| 475 |
-
#
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 482 |
for other_category in self.available_categories:
|
| 483 |
-
if other_category
|
| 484 |
-
#
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
# Load category embeddings if not already loaded
|
| 490 |
pb_dict, all_ids = self.load_category_embeddings(category)
|
| 491 |
if pb_dict is None or not all_ids:
|
| 492 |
-
|
| 493 |
-
|
|
|
|
| 494 |
# Ensure shape index is valid
|
| 495 |
if selected_shape_idx is None or selected_shape_idx < 0:
|
| 496 |
selected_shape_idx = 0
|
| 497 |
-
|
| 498 |
max_idx = len(all_ids) - 1
|
| 499 |
selected_shape_idx = max(0, min(selected_shape_idx, max_idx))
|
| 500 |
guidance_shape_id = all_ids[selected_shape_idx]
|
| 501 |
-
|
| 502 |
-
# Set
|
| 503 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 504 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 505 |
-
|
| 506 |
results = []
|
| 507 |
-
|
| 508 |
-
# Add status message for generation
|
| 509 |
-
updating_status = f"<div style='padding: 10px; background-color: #e8f5e9; border-left: 5px solid #4caf50; margin-bottom: 10px;'>Generating images using Shape #{selected_shape_idx}...</div>"
|
| 510 |
-
|
| 511 |
try:
|
| 512 |
-
# For ZeroGPU, move models to GPU if not already there
|
| 513 |
-
if hasattr(spaces, 'GPU'):
|
| 514 |
-
self.pipeline = self.pipeline.to(device)
|
| 515 |
-
self.shape2clip_model = self.shape2clip_model.to(device)
|
| 516 |
-
|
| 517 |
# Generate base image (without guidance)
|
| 518 |
with torch.no_grad():
|
| 519 |
base_images = self.pipeline(
|
| 520 |
-
prompt=
|
| 521 |
num_inference_steps=50,
|
| 522 |
num_images_per_prompt=1,
|
| 523 |
generator=generator,
|
| 524 |
guidance_scale=7.5
|
| 525 |
).images
|
| 526 |
-
|
| 527 |
base_image = base_images[0]
|
| 528 |
base_image = self.draw_text(base_image, "Unguided result")
|
| 529 |
results.append(base_image)
|
| 530 |
except Exception as e:
|
| 531 |
print(f"Error generating base image: {e}")
|
| 532 |
-
status = f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating base image: {str(e)}</div>"
|
| 533 |
return results, status
|
| 534 |
-
|
| 535 |
try:
|
| 536 |
# Get shape guidance image
|
| 537 |
ulip_image = self.get_ulip_image(guidance_shape_id)
|
|
@@ -539,19 +735,18 @@ class ShapeWordsDemo:
|
|
| 539 |
results.append(ulip_image)
|
| 540 |
except Exception as e:
|
| 541 |
print(f"Error getting guidance shape: {e}")
|
| 542 |
-
status = f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error getting guidance shape: {str(e)}</div>"
|
| 543 |
return results, status
|
| 544 |
-
|
| 545 |
try:
|
| 546 |
# Get shape guidance embedding
|
| 547 |
pb_emb = pb_dict[guidance_shape_id]
|
| 548 |
-
|
| 549 |
-
out_guidance, prompt_emb = self.get_guidance(prompt, category, pb_emb)
|
| 550 |
except Exception as e:
|
| 551 |
print(f"Error generating guidance: {e}")
|
| 552 |
-
status = f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating guidance: {str(e)}</div>"
|
| 553 |
return results, status
|
| 554 |
-
|
| 555 |
try:
|
| 556 |
# Generate guided image
|
| 557 |
generator = torch.Generator(device=device).manual_seed(seed)
|
|
@@ -563,51 +758,21 @@ class ShapeWordsDemo:
|
|
| 563 |
generator=generator,
|
| 564 |
guidance_scale=7.5
|
| 565 |
).images
|
| 566 |
-
|
| 567 |
guided_image = guided_images[0]
|
| 568 |
guided_image = self.draw_text(guided_image, f"Guided result (λ={guidance_strength:.1f})")
|
| 569 |
results.append(guided_image)
|
| 570 |
-
|
| 571 |
# Success status
|
| 572 |
-
status = f"<div style='padding: 10px; background-color: #e8f5e9; border-left: 5px solid #4caf50; margin-bottom: 10px;'>✓ Successfully generated images using Shape #{selected_shape_idx} from category '{category}'.</div>"
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
self.pipeline = self.pipeline.to('cpu')
|
| 577 |
-
self.shape2clip_model = self.shape2clip_model.to('cpu')
|
| 578 |
-
torch.cuda.empty_cache()
|
| 579 |
-
|
| 580 |
except Exception as e:
|
| 581 |
print(f"Error generating guided image: {e}")
|
| 582 |
-
status = f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating guided image: {str(e)}</div>"
|
| 583 |
-
|
| 584 |
-
return results, status
|
| 585 |
|
| 586 |
-
|
| 587 |
-
# Remove all existing categories from the prompt
|
| 588 |
-
cleaned_prompt = old_prompt
|
| 589 |
-
for cat in self.available_categories:
|
| 590 |
-
# Skip the current category
|
| 591 |
-
if cat == new_category:
|
| 592 |
-
continue
|
| 593 |
-
|
| 594 |
-
# Replace the category with a space, being careful about word boundaries
|
| 595 |
-
cleaned_prompt = cleaned_prompt.replace(f" {cat} ", " ")
|
| 596 |
-
cleaned_prompt = cleaned_prompt.replace(f" {cat}", "")
|
| 597 |
-
cleaned_prompt = cleaned_prompt.replace(f"{cat} ", "")
|
| 598 |
-
# Only do exact match for the whole prompt
|
| 599 |
-
if cleaned_prompt == cat:
|
| 600 |
-
cleaned_prompt = ""
|
| 601 |
-
|
| 602 |
-
# Add the new category if it's not already in the cleaned prompt
|
| 603 |
-
cleaned_prompt = cleaned_prompt.strip()
|
| 604 |
-
if new_category not in cleaned_prompt:
|
| 605 |
-
if cleaned_prompt:
|
| 606 |
-
return f"{cleaned_prompt} {new_category}"
|
| 607 |
-
else:
|
| 608 |
-
return new_category
|
| 609 |
-
else:
|
| 610 |
-
return cleaned_prompt
|
| 611 |
|
| 612 |
def on_demo_load(self):
|
| 613 |
"""Function to ensure initial image is loaded when demo starts"""
|
|
@@ -618,7 +783,7 @@ class ShapeWordsDemo:
|
|
| 618 |
def create_ui(self):
|
| 619 |
# Ensure chair is in available categories, otherwise use the first available
|
| 620 |
default_category = "chair" if "chair" in self.available_categories else self.available_categories[0]
|
| 621 |
-
|
| 622 |
with gr.Blocks(title="ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts") as demo:
|
| 623 |
gr.Markdown("""
|
| 624 |
# ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts
|
|
@@ -629,27 +794,35 @@ class ShapeWordsDemo:
|
|
| 629 |
- **Paper**: [ArXiv](https://arxiv.org/abs/2412.02912)
|
| 630 |
- **Publication**: Accepted to CVPR 2025
|
| 631 |
""")
|
| 632 |
-
|
| 633 |
with gr.Row():
|
| 634 |
with gr.Column(scale=1):
|
| 635 |
prompt = gr.Textbox(
|
| 636 |
-
label="Prompt",
|
| 637 |
-
placeholder="an aquarelle drawing of a
|
| 638 |
-
value=f"an aquarelle drawing of a
|
| 639 |
)
|
| 640 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 641 |
category = gr.Dropdown(
|
| 642 |
-
label="Object Category",
|
| 643 |
choices=self.available_categories,
|
| 644 |
value=default_category
|
| 645 |
)
|
| 646 |
-
|
| 647 |
# Hidden field to store selected shape index
|
| 648 |
selected_shape_idx = gr.Number(
|
| 649 |
value=0,
|
| 650 |
visible=False
|
| 651 |
)
|
| 652 |
-
|
| 653 |
# Create a slider for shape selection with preview
|
| 654 |
with gr.Row():
|
| 655 |
with gr.Column(scale=1):
|
|
@@ -662,47 +835,48 @@ class ShapeWordsDemo:
|
|
| 662 |
label="Shape Index",
|
| 663 |
interactive=True
|
| 664 |
)
|
| 665 |
-
|
| 666 |
# Display shape index counter
|
| 667 |
shape_counter = gr.Markdown(f"Shape 0 of {self.category_counts.get(default_category, 0) - 1}")
|
| 668 |
-
|
| 669 |
# Quick navigation buttons
|
| 670 |
with gr.Row():
|
| 671 |
jump_start_btn = gr.Button("⏮️ First", size="sm")
|
| 672 |
-
random_btn = gr.Button("🎲 Random", size="sm", variant="secondary")
|
| 673 |
jump_end_btn = gr.Button("Last ⏭️", size="sm")
|
| 674 |
-
|
| 675 |
with gr.Row():
|
| 676 |
prev_shape_btn = gr.Button("◀️ Previous", size="sm")
|
| 677 |
next_shape_btn = gr.Button("Next ▶️", size="sm")
|
| 678 |
-
|
| 679 |
with gr.Column(scale=1):
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
label=
|
| 683 |
-
|
| 684 |
-
|
|
|
|
| 685 |
)
|
| 686 |
-
|
| 687 |
guidance_strength = gr.Slider(
|
| 688 |
minimum=0.0, maximum=1.0, step=0.1, value=0.9,
|
| 689 |
label="Guidance Strength (λ)"
|
| 690 |
)
|
| 691 |
-
|
| 692 |
seed = gr.Slider(
|
| 693 |
minimum=0, maximum=10000, step=1, value=42,
|
| 694 |
label="Random Seed"
|
| 695 |
)
|
| 696 |
-
|
| 697 |
run_button = gr.Button("Generate Images", variant="primary")
|
| 698 |
-
|
| 699 |
info = gr.Markdown("""
|
| 700 |
**Note**: Higher guidance strength (λ) means stronger adherence to the 3D shape.
|
| 701 |
Start with λ=0.9 for a good balance between shape and prompt adherence.
|
| 702 |
""")
|
| 703 |
-
|
| 704 |
status_text = gr.HTML("")
|
| 705 |
-
|
| 706 |
with gr.Column(scale=2):
|
| 707 |
gallery = gr.Gallery(
|
| 708 |
label="Results",
|
|
@@ -711,84 +885,84 @@ class ShapeWordsDemo:
|
|
| 711 |
columns=3,
|
| 712 |
height="auto"
|
| 713 |
)
|
| 714 |
-
|
| 715 |
# Make sure the initial image is loaded when the demo starts
|
| 716 |
demo.load(
|
| 717 |
fn=self.on_demo_load,
|
| 718 |
inputs=None,
|
| 719 |
-
outputs=[
|
| 720 |
)
|
| 721 |
-
|
| 722 |
# Connect slider to update preview
|
| 723 |
shape_slider.change(
|
| 724 |
fn=self.on_slider_change,
|
| 725 |
inputs=[shape_slider, category],
|
| 726 |
-
outputs=[
|
| 727 |
)
|
| 728 |
-
|
| 729 |
# Previous shape button
|
| 730 |
prev_shape_btn.click(
|
| 731 |
fn=self.prev_shape,
|
| 732 |
inputs=[selected_shape_idx, category],
|
| 733 |
-
outputs=[shape_slider,
|
| 734 |
)
|
| 735 |
-
|
| 736 |
# Next shape button
|
| 737 |
next_shape_btn.click(
|
| 738 |
fn=self.next_shape,
|
| 739 |
inputs=[selected_shape_idx, category],
|
| 740 |
-
outputs=[shape_slider,
|
| 741 |
)
|
| 742 |
-
|
| 743 |
# Jump to start button
|
| 744 |
jump_start_btn.click(
|
| 745 |
fn=self.jump_to_start,
|
| 746 |
inputs=[category],
|
| 747 |
-
outputs=[shape_slider,
|
| 748 |
)
|
| 749 |
-
|
| 750 |
# Jump to end button
|
| 751 |
jump_end_btn.click(
|
| 752 |
fn=self.jump_to_end,
|
| 753 |
inputs=[category],
|
| 754 |
-
outputs=[shape_slider,
|
| 755 |
)
|
| 756 |
-
|
| 757 |
# Random shape button
|
| 758 |
random_btn.click(
|
| 759 |
fn=self.random_shape,
|
| 760 |
inputs=[category],
|
| 761 |
-
outputs=[shape_slider,
|
| 762 |
)
|
| 763 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 764 |
# Update the UI when category changes
|
| 765 |
category.change(
|
| 766 |
fn=self.on_category_change,
|
| 767 |
inputs=[category],
|
| 768 |
-
outputs=[shape_slider, selected_shape_idx,
|
| 769 |
-
)
|
| 770 |
-
|
| 771 |
-
# Automatically update prompt when category changes
|
| 772 |
-
category.change(
|
| 773 |
-
fn=self.update_prompt_for_category,
|
| 774 |
-
inputs=[prompt, category],
|
| 775 |
-
outputs=[prompt]
|
| 776 |
)
|
| 777 |
-
|
| 778 |
# Clear status text before generating new images
|
| 779 |
run_button.click(
|
| 780 |
fn=lambda: None, # Empty function to clear the status
|
| 781 |
inputs=None,
|
| 782 |
outputs=[status_text]
|
| 783 |
)
|
| 784 |
-
|
| 785 |
# Generate images when button is clicked
|
| 786 |
run_button.click(
|
| 787 |
fn=self.generate_images,
|
| 788 |
inputs=[prompt, category, selected_shape_idx, guidance_strength, seed],
|
| 789 |
outputs=[gallery, status_text]
|
| 790 |
)
|
| 791 |
-
|
| 792 |
gr.Markdown("""
|
| 793 |
## Credits
|
| 794 |
|
|
@@ -807,7 +981,7 @@ class ShapeWordsDemo:
|
|
| 807 |
}
|
| 808 |
```
|
| 809 |
""")
|
| 810 |
-
|
| 811 |
return demo
|
| 812 |
|
| 813 |
|
|
@@ -816,11 +990,12 @@ def main():
|
|
| 816 |
parser = argparse.ArgumentParser(description="ShapeWords Gradio Demo")
|
| 817 |
parser.add_argument('--share', action='store_true', help='Create a public link')
|
| 818 |
args = parser.parse_args()
|
| 819 |
-
|
| 820 |
# Create the demo app and UI
|
| 821 |
app = ShapeWordsDemo()
|
| 822 |
demo = app.create_ui()
|
| 823 |
demo.launch(share=args.share)
|
| 824 |
|
|
|
|
| 825 |
if __name__ == "__main__":
|
| 826 |
-
main()
|
|
|
|
| 31 |
|
| 32 |
This demo allows users to:
|
| 33 |
1. Select a 3D object category
|
| 34 |
+
2. Choose a specific 3D shape
|
| 35 |
3. Enter a text prompt
|
| 36 |
4. Generate images guided by the selected 3D shape
|
| 37 |
|
|
|
|
| 46 |
from PIL import Image, ImageFont, ImageDraw
|
| 47 |
from diffusers.utils import load_image
|
| 48 |
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
|
|
|
|
| 49 |
import gdown
|
| 50 |
import argparse
|
| 51 |
import random
|
| 52 |
+
import spaces # for Hugging Face ZeroGPU deployment
|
| 53 |
+
import re
|
| 54 |
+
import plotly.graph_objects as go
|
| 55 |
+
from numpy.lib.user_array import container
|
| 56 |
|
| 57 |
+
# Only for Hugging Face hosting - Add the Hugging Face cache to persistent storage to avoid downloading safetensors every time the demo sleeps and wakes up
|
| 58 |
+
os.environ['HF_HOME'] = '/data/.huggingface'
|
| 59 |
|
| 60 |
class ShapeWordsDemo:
|
| 61 |
# Constants
|
| 62 |
NAME2CAT = {
|
| 63 |
+
"chair": "03001627", "table": "04379243", "jar": "03593526", "skateboard": "04225987",
|
| 64 |
+
"car": "02958343", "bottle": "02876657", "tower": "04460130", "bookshelf": "02871439",
|
| 65 |
+
"camera": "02942699", "airplane": "02691156", "laptop": "03642806", "basket": "02801938",
|
| 66 |
+
"sofa": "04256520", "knife": "03624134", "can": "02946921", "rifle": "04090263",
|
| 67 |
+
"train": "04468005", "pillow": "03938244", "lamp": "03636649", "trash bin": "02747177",
|
| 68 |
+
"mailbox": "03710193", "watercraft": "04530566", "motorbike": "03790512",
|
| 69 |
+
"dishwasher": "03207941", "bench": "02828884", "pistol": "03948459", "rocket": "04099429",
|
| 70 |
+
"loudspeaker": "03691459", "file cabinet": "03337140", "bag": "02773838",
|
| 71 |
+
"cabinet": "02933112", "bed": "02818832", "birdhouse": "02843684", "display": "03211117",
|
| 72 |
+
"piano": "03928116", "earphone": "03261776", "telephone": "04401088", "stove": "04330267",
|
| 73 |
+
"microphone": "03759954", "bus": "02924116", "mug": "03797390", "remote": "04074963",
|
| 74 |
+
"bathtub": "02808440", "bowl": "02880940", "keyboard": "03085013", "guitar": "03467517",
|
| 75 |
+
"washer": "04554684", "bicycle": "02834778", "faucet": "03325088", "printer": "04004475",
|
| 76 |
+
"cap": "02954340", "phone": "02992529", "clock": "03046257", "helmet": "03513137",
|
| 77 |
"microwave": "03761084", "plant": "03991062"
|
| 78 |
}
|
| 79 |
|
|
|
|
| 88 |
self.available_categories = []
|
| 89 |
self.shape_thumbnail_cache = {} # Cache for shape thumbnails
|
| 90 |
self.CAT2NAME = {v: k for k, v in self.NAME2CAT.items()}
|
| 91 |
+
self.category_point_clouds = {}
|
| 92 |
# Initialize all models and data
|
| 93 |
self.initialize_models()
|
| 94 |
|
| 95 |
def draw_text(self, img, text, color=(10, 10, 10), size=80, location=(200, 30)):
|
| 96 |
img = img.copy()
|
| 97 |
draw = ImageDraw.Draw(img)
|
| 98 |
+
|
| 99 |
try:
|
| 100 |
font = ImageFont.truetype("Arial", size=size)
|
| 101 |
except IOError:
|
| 102 |
font = ImageFont.load_default()
|
| 103 |
+
|
| 104 |
bbox = draw.textbbox(location, text, font=font)
|
| 105 |
draw.rectangle(bbox, fill="white")
|
| 106 |
draw.text(location, text, color, font=font)
|
| 107 |
+
|
| 108 |
return img
|
| 109 |
|
| 110 |
def get_ulip_image(self, guidance_shape_id, angle='036'):
|
| 111 |
shape_id_ulip = guidance_shape_id.replace('_', '-')
|
| 112 |
ulip_template = 'https://storage.googleapis.com/sfr-ulip-code-release-research/shapenet-55/only_rgb_depth_images/{}_r_{}_depth0001.png'
|
| 113 |
ulip_path = ulip_template.format(shape_id_ulip, angle)
|
| 114 |
+
|
| 115 |
try:
|
| 116 |
ulip_image = load_image(ulip_path).resize((512, 512))
|
| 117 |
return ulip_image
|
|
|
|
| 119 |
print(f"Error loading image: {e}")
|
| 120 |
return Image.new('RGB', (512, 512), color='gray')
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
def initialize_models(self):
|
| 123 |
+
# device = DEVICE
|
| 124 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 125 |
+
print(f"Using device: {device} in initialize_models")
|
| 126 |
+
|
| 127 |
# Download Shape2CLIP code if it doesn't exist
|
| 128 |
if not os.path.exists("shapewords_paper_code"):
|
| 129 |
print("Loading models file")
|
| 130 |
os.system("git clone https://github.com/lodurality/shapewords_paper_code.git")
|
| 131 |
+
|
| 132 |
# Import Shape2CLIP model
|
| 133 |
sys.path.append("./shapewords_paper_code")
|
| 134 |
from shapewords_paper_code.geometry_guidance_models import Shape2CLIP
|
| 135 |
+
|
| 136 |
# Initialize the pipeline
|
| 137 |
self.pipeline = StableDiffusionPipeline.from_pretrained(
|
| 138 |
+
"stabilityai/stable-diffusion-2-1-base",
|
| 139 |
torch_dtype=torch.float16 if device.type == "cuda" else torch.float32
|
| 140 |
)
|
| 141 |
+
|
| 142 |
self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
| 143 |
+
self.pipeline.scheduler.config,
|
| 144 |
algorithm_type="sde-dpmsolver++"
|
| 145 |
)
|
| 146 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
self.text_encoder = self.pipeline.text_encoder
|
| 148 |
self.tokenizer = self.pipeline.tokenizer
|
| 149 |
+
|
| 150 |
# Look for Shape2CLIP checkpoint in multiple locations
|
| 151 |
checkpoint_paths = [
|
| 152 |
+
"./projection_model-0920192.pth",
|
| 153 |
+
"/data/projection_model-0920192.pth" # if using Hugging Face persistent storage look in a /data/ directory
|
| 154 |
]
|
| 155 |
+
|
| 156 |
checkpoint_found = False
|
| 157 |
checkpoint_path = None
|
| 158 |
for path in checkpoint_paths:
|
|
|
|
| 161 |
print(f"Found Shape2CLIP checkpoint at: {checkpoint_path}")
|
| 162 |
checkpoint_found = True
|
| 163 |
break
|
| 164 |
+
|
| 165 |
# Download Shape2CLIP checkpoint if not found
|
| 166 |
if not checkpoint_found:
|
| 167 |
checkpoint_path = "projection_model-0920192.pth"
|
| 168 |
print("Downloading Shape2CLIP model checkpoint...")
|
| 169 |
+
gdown.download("https://drive.google.com/uc?id=1nvEXnwMpNkRts6rxVqMZt8i9FZ40KjP7", checkpoint_path, quiet=False) # download in same directory as app.py
|
| 170 |
print("Download complete")
|
| 171 |
+
|
| 172 |
# Initialize Shape2CLIP model
|
| 173 |
self.shape2clip_model = Shape2CLIP(depth=6, drop_path_rate=0.1, pb_dim=384)
|
| 174 |
self.shape2clip_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
|
|
|
|
|
|
|
| 175 |
self.shape2clip_model.eval()
|
| 176 |
+
|
| 177 |
# Scan for available embeddings
|
| 178 |
self.scan_available_embeddings()
|
| 179 |
|
| 180 |
def scan_available_embeddings(self):
|
| 181 |
self.available_categories = []
|
| 182 |
self.category_counts = {}
|
| 183 |
+
|
| 184 |
+
# Try to find PointBert embeddings for all 55 ShapeNetCore shape categories
|
| 185 |
for category, cat_id in self.NAME2CAT.items():
|
| 186 |
possible_filenames = [
|
|
|
|
| 187 |
f"{cat_id}_pb_embs.npz",
|
| 188 |
+
f"embeddings/{cat_id}_pb_embs.npz",
|
| 189 |
+
f"/data/shapenet_pointbert_tokens/{cat_id}_pb_embs.npz" # if using Hugging Face persistent storage look in a /data/shapenet_pointbert_tokens directory
|
|
|
|
| 190 |
]
|
| 191 |
+
|
| 192 |
found_file = None
|
| 193 |
for filename in possible_filenames:
|
| 194 |
if os.path.exists(filename):
|
| 195 |
found_file = filename
|
| 196 |
break
|
| 197 |
+
|
| 198 |
if found_file:
|
| 199 |
try:
|
| 200 |
pb_data = np.load(found_file)
|
|
|
|
| 207 |
count = len(pb_data[keys[0]])
|
| 208 |
else:
|
| 209 |
count = 0
|
| 210 |
+
|
| 211 |
if count > 0:
|
| 212 |
self.available_categories.append(category)
|
| 213 |
self.category_counts[category] = count
|
| 214 |
print(f"Found {count} embeddings for category '{category}'")
|
| 215 |
except Exception as e:
|
| 216 |
print(f"Error loading embeddings for {category}: {e}")
|
| 217 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
# Sort categories alphabetically
|
| 219 |
self.available_categories.sort()
|
| 220 |
+
|
| 221 |
print(f"Found {len(self.available_categories)} categories with embeddings")
|
| 222 |
print(f"Available categories: {', '.join(self.available_categories)}")
|
| 223 |
+
|
| 224 |
+
# No embeddings found for any category - DEMO CANNOT RUN - but still load the interface with a default placeholder category, an error will be displayed when trying to generate images
|
| 225 |
+
if not self.available_categories:
|
| 226 |
+
self.available_categories = ["chair"] # Fallback
|
| 227 |
+
self.category_counts["chair"] = 50 # Default value
|
| 228 |
|
| 229 |
def load_category_embeddings(self, category):
|
| 230 |
if category in self.category_embeddings:
|
| 231 |
return self.category_embeddings[category]
|
| 232 |
+
|
| 233 |
if category not in self.NAME2CAT:
|
| 234 |
return None, []
|
| 235 |
+
|
| 236 |
cat_id = self.NAME2CAT[category]
|
| 237 |
+
|
| 238 |
# Check for different possible embedding filenames and locations
|
| 239 |
possible_filenames = [
|
| 240 |
+
f"{cat_id}_pb_embs.npz",
|
|
|
|
|
|
|
| 241 |
f"embeddings/{cat_id}_pb_embs.npz",
|
| 242 |
+
f"/data/shapenet_pointbert_tokens/{cat_id}_pb_embs.npz" # if using Hugging Face persistent storage look in a /data/shapenet_pointbert_tokens directory
|
| 243 |
]
|
| 244 |
+
|
| 245 |
# Find the first existing file
|
| 246 |
pb_emb_filename = None
|
| 247 |
for filename in possible_filenames:
|
|
|
|
| 249 |
pb_emb_filename = filename
|
| 250 |
print(f"Found embeddings file: {pb_emb_filename}")
|
| 251 |
break
|
| 252 |
+
|
| 253 |
if pb_emb_filename is None:
|
| 254 |
print(f"No embeddings found for {category}")
|
| 255 |
return None, []
|
| 256 |
+
|
| 257 |
# Load embeddings
|
| 258 |
try:
|
| 259 |
print(f"Loading embeddings from {pb_emb_filename}...")
|
| 260 |
pb_data = np.load(pb_emb_filename)
|
| 261 |
+
|
| 262 |
# Check for different key names in the NPZ file
|
| 263 |
if 'ids' in pb_data and 'embs' in pb_data:
|
| 264 |
pb_dict = dict(zip(pb_data['ids'], pb_data['embs']))
|
|
|
|
| 271 |
else:
|
| 272 |
print("Unexpected embedding file format")
|
| 273 |
return None, []
|
| 274 |
+
|
| 275 |
all_ids = sorted(list(pb_dict.keys()))
|
| 276 |
print(f"Loaded {len(all_ids)} shape embeddings for {category}")
|
| 277 |
+
|
| 278 |
# Cache the results
|
| 279 |
self.category_embeddings[category] = (pb_dict, all_ids)
|
| 280 |
return pb_dict, all_ids
|
|
|
|
| 283 |
print(f"Exception details: {str(e)}")
|
| 284 |
return None, []
|
| 285 |
|
| 286 |
+
def load_category_point_clouds(self, category):
|
| 287 |
+
"""Load all point clouds for a category from a single NPZ file"""
|
| 288 |
+
if category not in self.NAME2CAT:
|
| 289 |
+
return None
|
| 290 |
+
|
| 291 |
+
cat_id = self.NAME2CAT[category]
|
| 292 |
+
|
| 293 |
+
# Cache to avoid reloading
|
| 294 |
+
if category in self.category_point_clouds:
|
| 295 |
+
return self.category_point_clouds[category]
|
| 296 |
+
|
| 297 |
+
# Check for different possible point cloud filenames
|
| 298 |
+
possible_filenames = [
|
| 299 |
+
f"{cat_id}.npz",
|
| 300 |
+
f"point_clouds/{cat_id}_clouds.npz",
|
| 301 |
+
f"/point_clouds/{cat_id}_clouds.npz",
|
| 302 |
+
f"/data/point_clouds/{cat_id}_clouds.npz" # For Hugging Face persistent storage
|
| 303 |
+
]
|
| 304 |
+
|
| 305 |
+
# Find the first existing file
|
| 306 |
+
pc_filename = None
|
| 307 |
+
for filename in possible_filenames:
|
| 308 |
+
if os.path.exists(filename):
|
| 309 |
+
pc_filename = filename
|
| 310 |
+
print(f"Found point cloud file: {pc_filename}")
|
| 311 |
+
break
|
| 312 |
+
|
| 313 |
+
if pc_filename is None:
|
| 314 |
+
print(f"No point cloud file found for category {category}")
|
| 315 |
+
return None
|
| 316 |
+
|
| 317 |
+
# Load point clouds
|
| 318 |
+
try:
|
| 319 |
+
print(f"Loading point clouds from {pc_filename}...")
|
| 320 |
+
pc_data = np.load(pc_filename, allow_pickle=True)
|
| 321 |
+
|
| 322 |
+
# Cache the loaded data
|
| 323 |
+
self.category_point_clouds[category] = pc_data
|
| 324 |
+
|
| 325 |
+
return pc_data
|
| 326 |
+
except Exception as e:
|
| 327 |
+
print(f"Error loading point clouds: {e}")
|
| 328 |
+
return None
|
| 329 |
+
|
| 330 |
+
def get_shape_preview(self, category, shape_idx):
|
| 331 |
+
"""Get a 3D point cloud visualization for a specific shape"""
|
| 332 |
if shape_idx is None or shape_idx < 0:
|
| 333 |
return None
|
| 334 |
+
|
| 335 |
+
# Get shape ID
|
| 336 |
pb_dict, all_ids = self.load_category_embeddings(category)
|
| 337 |
if pb_dict is None or not all_ids or shape_idx >= len(all_ids):
|
| 338 |
return None
|
| 339 |
+
|
| 340 |
shape_id = all_ids[shape_idx]
|
| 341 |
+
|
| 342 |
+
# Load all point clouds for this category
|
| 343 |
+
pc_data = self.load_category_point_clouds(category)
|
| 344 |
+
if pc_data is None:
|
| 345 |
+
# Fallback to image if point clouds not available
|
| 346 |
+
return self.get_shape_image_preview(category, shape_idx, shape_id)
|
| 347 |
+
|
| 348 |
+
# Extract point cloud for this specific shape
|
| 349 |
+
try:
|
| 350 |
+
# Get the arrays from the npz file
|
| 351 |
+
ids = pc_data['ids']
|
| 352 |
+
clouds = pc_data['clouds']
|
| 353 |
+
|
| 354 |
+
matching_indices = np.where(ids == shape_id)[0]
|
| 355 |
+
|
| 356 |
+
# Check number of matches
|
| 357 |
+
if len(matching_indices) == 0:
|
| 358 |
+
# No matches found - log error and fall back to image
|
| 359 |
+
print(f"Error: Shape ID {shape_id} not found in point cloud data")
|
| 360 |
+
return self.get_shape_image_preview(category, shape_idx, shape_id)
|
| 361 |
+
elif len(matching_indices) > 1:
|
| 362 |
+
# Multiple matches found - unexpected data issue - we will get the first one
|
| 363 |
+
print(f"Warning: Multiple matches ({len(matching_indices)}) found for Shape ID {shape_id}. Using first match.")
|
| 364 |
+
|
| 365 |
+
# Get the corresponding point cloud
|
| 366 |
+
matching_idx = matching_indices[0]
|
| 367 |
+
points = clouds[matching_idx]
|
| 368 |
+
|
| 369 |
+
# Create 3D visualization
|
| 370 |
+
fig = self.get_shape_pointcloud_preview(points, title=f"Shape #{shape_idx}")
|
| 371 |
+
return fig
|
| 372 |
+
|
| 373 |
+
except Exception as e:
|
| 374 |
+
print(f"Error extracting point cloud for {shape_id}: {e}")
|
| 375 |
+
return self.get_shape_image_preview(category, shape_idx, shape_id)
|
| 376 |
+
|
| 377 |
+
def get_shape_image_preview(self, category, shape_idx, shape_id):
|
| 378 |
+
"""Fallback to image preview if point cloud not available"""
|
| 379 |
try:
|
|
|
|
| 380 |
preview_image = self.get_ulip_image(shape_id)
|
| 381 |
+
preview_image = preview_image.resize((300, 300))
|
| 382 |
+
preview_with_text = self.draw_text(preview_image, f"Shape #{shape_idx}", size=80, location=(10, 10))
|
| 383 |
+
|
| 384 |
+
# Convert PIL image to plotly figure
|
| 385 |
+
fig = go.Figure()
|
| 386 |
+
|
| 387 |
+
# Need to convert PIL image to a format plotly can use
|
| 388 |
+
import io
|
| 389 |
+
import base64
|
| 390 |
+
|
| 391 |
+
# Convert PIL image to base64
|
| 392 |
+
buf = io.BytesIO()
|
| 393 |
+
preview_with_text.save(buf, format='PNG')
|
| 394 |
+
img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
|
| 395 |
+
|
| 396 |
+
# Add image to figure
|
| 397 |
+
fig.add_layout_image(
|
| 398 |
+
dict(
|
| 399 |
+
source=f"data:image/png;base64,{img_str}",
|
| 400 |
+
xref="paper", yref="paper",
|
| 401 |
+
x=0, y=1,
|
| 402 |
+
sizex=1, sizey=1,
|
| 403 |
+
sizing="contain",
|
| 404 |
+
layer="below"
|
| 405 |
+
)
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
fig.update_layout(
|
| 409 |
+
title=f"Shape #{shape_idx} (2D Preview - 3D not available)",
|
| 410 |
+
xaxis=dict(showgrid=False, zeroline=False, visible=False, range=[0, 1]),
|
| 411 |
+
yaxis=dict(showgrid=False, zeroline=False, visible=False, range=[0, 1], scaleanchor="x", scaleratio=1),
|
| 412 |
+
margin=dict(l=0, r=0, b=0, t=0),
|
| 413 |
+
height=450,
|
| 414 |
+
width=450,
|
| 415 |
+
plot_bgcolor='rgba(0,0,0,0)' # Transparent background
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
return fig
|
| 419 |
except Exception as e:
|
| 420 |
print(f"Error loading preview for {shape_id}: {e}")
|
| 421 |
+
# Create empty figure with error message
|
| 422 |
+
fig = go.Figure()
|
| 423 |
+
fig.update_layout(
|
| 424 |
+
title=f"Error loading Shape #{shape_idx}",
|
| 425 |
+
annotations=[dict(
|
| 426 |
+
text="Preview not available",
|
| 427 |
+
showarrow=False,
|
| 428 |
+
xref="paper", yref="paper",
|
| 429 |
+
x=0.5, y=0.5,
|
| 430 |
+
ont=dict(size=16, color="#E53935"), # Red error text
|
| 431 |
+
align="center"
|
| 432 |
+
)],
|
| 433 |
+
height=450,
|
| 434 |
+
width=450,
|
| 435 |
+
margin=dict(l=0, r=0, b=0, t=30, pad=0),
|
| 436 |
+
paper_bgcolor='rgba(0,0,0,0)',
|
| 437 |
+
plot_bgcolor='rgba(0,0,0,0)' # Transparent background
|
| 438 |
+
)
|
| 439 |
+
return fig
|
| 440 |
+
|
| 441 |
+
def get_shape_pointcloud_preview(self, points, title=None):
|
| 442 |
+
"""Create a clean 3D point cloud visualization with Y as up axis"""
|
| 443 |
+
# Sample points for better performance (fewer points = smoother interaction)
|
| 444 |
+
sampled_points = points[::1] # Take every Nth point
|
| 445 |
+
|
| 446 |
+
# Create 3D scatter plot with fixed color
|
| 447 |
+
fig = go.Figure(data=[go.Scatter3d(
|
| 448 |
+
x=sampled_points[:, 0],
|
| 449 |
+
y=sampled_points[:, 1], # Use Z as Y (up axis)
|
| 450 |
+
z=sampled_points[:, 2], # Use Y as Z
|
| 451 |
+
mode='markers',
|
| 452 |
+
marker=dict(
|
| 453 |
+
size=2.5,
|
| 454 |
+
color='#4285F4', # Fixed blue color
|
| 455 |
+
opacity=1
|
| 456 |
+
)
|
| 457 |
+
)])
|
| 458 |
+
|
| 459 |
+
fig.update_layout(
|
| 460 |
+
title=dict(text=title,
|
| 461 |
+
xanchor='center',
|
| 462 |
+
x=0.5
|
| 463 |
+
),
|
| 464 |
+
scene=dict(
|
| 465 |
+
# Remove all axes elements
|
| 466 |
+
xaxis=dict(visible=False, showticklabels=False, showgrid=False, zeroline=False, showline=False,
|
| 467 |
+
showbackground=False),
|
| 468 |
+
yaxis=dict(visible=False, showticklabels=False, showgrid=False, zeroline=False, showline=False,
|
| 469 |
+
showbackground=False),
|
| 470 |
+
zaxis=dict(visible=False, showticklabels=False, showgrid=False, zeroline=False, showline=False,
|
| 471 |
+
showbackground=False),
|
| 472 |
+
aspectmode='data' # Maintain data aspect ratio
|
| 473 |
+
),
|
| 474 |
+
# Eliminate margins
|
| 475 |
+
margin=dict(l=0, r=0, b=0, t=30, pad=0),
|
| 476 |
+
autosize=True,
|
| 477 |
+
# Control modebar appearance through layout
|
| 478 |
+
modebar=dict(
|
| 479 |
+
bgcolor='white',
|
| 480 |
+
color='#333',
|
| 481 |
+
orientation='v', # Vertical orientation
|
| 482 |
+
activecolor='#009688'
|
| 483 |
+
),
|
| 484 |
+
paper_bgcolor='rgba(0,0,0,0)', # Transparent background
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
# Better camera angle
|
| 488 |
+
fig.update_layout(
|
| 489 |
+
scene_camera=dict(
|
| 490 |
+
eye=dict(x=-1.5, y=0.5, z=-1.5),
|
| 491 |
+
up=dict(x=0, y=1, z=0), # Y is up
|
| 492 |
+
center=dict(x=0, y=0, z=0)
|
| 493 |
+
)
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
return fig
|
| 497 |
|
| 498 |
def on_slider_change(self, shape_idx, category):
|
| 499 |
"""Update the preview when the slider changes"""
|
| 500 |
max_idx = self.category_counts.get(category, 0) - 1
|
| 501 |
+
|
| 502 |
# Get preview image
|
| 503 |
preview_image = self.get_shape_preview(category, shape_idx)
|
| 504 |
+
|
| 505 |
# Update counter text
|
| 506 |
counter_text = f"Shape {shape_idx} of {max_idx}"
|
| 507 |
+
|
| 508 |
return preview_image, counter_text, shape_idx
|
| 509 |
|
| 510 |
def prev_shape(self, current_idx, category):
|
| 511 |
"""Go to previous shape"""
|
| 512 |
max_idx = self.category_counts.get(category, 0) - 1
|
| 513 |
new_idx = max(0, current_idx - 1)
|
| 514 |
+
|
| 515 |
# Get preview image
|
| 516 |
preview_image = self.get_shape_preview(category, new_idx)
|
| 517 |
+
|
| 518 |
# Update counter text
|
| 519 |
counter_text = f"Shape {new_idx} of {max_idx}"
|
| 520 |
+
|
| 521 |
return new_idx, preview_image, counter_text
|
| 522 |
|
| 523 |
def next_shape(self, current_idx, category):
|
| 524 |
"""Go to next shape"""
|
| 525 |
max_idx = self.category_counts.get(category, 0) - 1
|
| 526 |
new_idx = min(max_idx, current_idx + 1)
|
| 527 |
+
|
| 528 |
# Get preview image
|
| 529 |
preview_image = self.get_shape_preview(category, new_idx)
|
| 530 |
+
|
| 531 |
# Update counter text
|
| 532 |
counter_text = f"Shape {new_idx} of {max_idx}"
|
| 533 |
+
|
| 534 |
return new_idx, preview_image, counter_text
|
| 535 |
|
| 536 |
def jump_to_start(self, category):
|
| 537 |
"""Jump to the first shape"""
|
| 538 |
max_idx = self.category_counts.get(category, 0) - 1
|
| 539 |
new_idx = 0
|
| 540 |
+
|
| 541 |
# Get preview image
|
| 542 |
preview_image = self.get_shape_preview(category, new_idx)
|
| 543 |
+
|
| 544 |
# Update counter text
|
| 545 |
counter_text = f"Shape {new_idx} of {max_idx}"
|
| 546 |
+
|
| 547 |
return new_idx, preview_image, counter_text
|
| 548 |
|
| 549 |
def jump_to_end(self, category):
|
| 550 |
"""Jump to the last shape"""
|
| 551 |
max_idx = self.category_counts.get(category, 0) - 1
|
| 552 |
new_idx = max_idx
|
| 553 |
+
|
| 554 |
# Get preview image
|
| 555 |
preview_image = self.get_shape_preview(category, new_idx)
|
| 556 |
+
|
| 557 |
# Update counter text
|
| 558 |
counter_text = f"Shape {new_idx} of {max_idx}"
|
| 559 |
+
|
| 560 |
return new_idx, preview_image, counter_text
|
| 561 |
|
| 562 |
def random_shape(self, category):
|
|
|
|
| 564 |
max_idx = self.category_counts.get(category, 0) - 1
|
| 565 |
if max_idx <= 0:
|
| 566 |
return 0, self.get_shape_preview(category, 0), f"Shape 0 of 0"
|
| 567 |
+
|
| 568 |
# Generate random index
|
| 569 |
random_idx = random.randint(0, max_idx)
|
| 570 |
+
|
| 571 |
# Get preview image
|
| 572 |
preview_image = self.get_shape_preview(category, random_idx)
|
| 573 |
+
|
| 574 |
# Update counter text
|
| 575 |
counter_text = f"Shape {random_idx} of {max_idx}"
|
| 576 |
+
|
| 577 |
return random_idx, preview_image, counter_text
|
| 578 |
|
| 579 |
+
def random_prompt(self):
|
| 580 |
+
"""Select a random prompt from the predefined list"""
|
| 581 |
+
prompts = [
|
| 582 |
+
'a low poly 3d rendering of a [CATEGORY]',
|
| 583 |
+
'an aquarelle drawing of a [CATEGORY]',
|
| 584 |
+
'a photo of a [CATEGORY] on a beach',
|
| 585 |
+
'a charcoal drawing of a [CATEGORY]',
|
| 586 |
+
'a Hieronymus Bosch painting of a [CATEGORY]',
|
| 587 |
+
'a [CATEGORY] under a tree',
|
| 588 |
+
'A Kazimir Malevich painting of a [CATEGORY]',
|
| 589 |
+
'a vector graphic of a [CATEGORY]',
|
| 590 |
+
'a Claude Monet painting of a [CATEGORY]',
|
| 591 |
+
'a Salvador Dali painting of a [CATEGORY]',
|
| 592 |
+
'an Art Deco poster of a [CATEGORY]'
|
| 593 |
+
]
|
| 594 |
+
|
| 595 |
+
# Get a random prompt
|
| 596 |
+
return random.choice(prompts)
|
| 597 |
+
|
| 598 |
def on_category_change(self, category):
|
| 599 |
"""Update the slider and preview when the category changes"""
|
| 600 |
# Reset to the first shape
|
| 601 |
current_idx = 0
|
| 602 |
max_idx = self.category_counts.get(category, 0) - 1
|
| 603 |
+
|
| 604 |
# Get preview image
|
| 605 |
preview_image = self.get_shape_preview(category, current_idx)
|
| 606 |
+
|
| 607 |
# Update counter text
|
| 608 |
counter_text = f"Shape {current_idx} of {max_idx}"
|
| 609 |
+
|
| 610 |
# Need to update the slider range
|
| 611 |
new_slider = gr.Slider(
|
| 612 |
minimum=0,
|
|
|
|
| 615 |
value=current_idx,
|
| 616 |
label="Shape Index"
|
| 617 |
)
|
| 618 |
+
|
| 619 |
return new_slider, current_idx, preview_image, counter_text
|
| 620 |
|
| 621 |
def get_guidance(self, test_prompt, category_name, guidance_emb):
|
|
|
|
| 622 |
print(test_prompt, category_name)
|
| 623 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 624 |
+
print(f"Using device: {device} in get_guidance")
|
| 625 |
+
|
| 626 |
prompt_tokens = torch.LongTensor(self.tokenizer.encode(test_prompt, padding='max_length')).to(device)
|
| 627 |
+
|
| 628 |
with torch.no_grad():
|
| 629 |
out = self.text_encoder(prompt_tokens.unsqueeze(0), output_attentions=True)
|
| 630 |
prompt_emb = out.last_hidden_state.detach().clone()
|
| 631 |
+
|
|
|
|
| 632 |
if len(guidance_emb.shape) == 1:
|
| 633 |
guidance_emb = torch.FloatTensor(guidance_emb).unsqueeze(0).unsqueeze(0)
|
| 634 |
else:
|
|
|
|
| 647 |
with torch.no_grad():
|
| 648 |
guided_prompt_emb_cond = self.shape2clip_model(prompt_emb.float(), guidance_emb[:,:,:].float()).half()
|
| 649 |
guided_prompt_emb = guided_prompt_emb_cond.clone()
|
| 650 |
+
|
| 651 |
guided_prompt_emb[:,:1] = 0
|
| 652 |
guided_prompt_emb[:,:chair_inds] = 0
|
| 653 |
guided_prompt_emb[:,chair_inds] *= obj_strength
|
|
|
|
| 658 |
|
| 659 |
return fin_guidance, prompt_emb
|
| 660 |
|
|
|
|
| 661 |
@spaces.GPU(duration=120)
|
| 662 |
def generate_images(self, prompt, category, selected_shape_idx, guidance_strength, seed):
|
| 663 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 664 |
+
print(f"Using device: {device} in generate_images")
|
| 665 |
+
|
| 666 |
+
# Move models to gpu
|
| 667 |
+
if device.type == "cuda":
|
| 668 |
+
self.pipeline = self.pipeline.to(device)
|
| 669 |
+
self.shape2clip_model = self.shape2clip_model.to(device)
|
| 670 |
+
|
| 671 |
# Clear status text immediately
|
| 672 |
status = ""
|
| 673 |
+
|
| 674 |
+
# Replace [CATEGORY] with the selected category (case-insensitive)
|
| 675 |
+
category_pattern = re.compile(r'\[CATEGORY\]', re.IGNORECASE)
|
| 676 |
+
if re.search(category_pattern, prompt):
|
| 677 |
+
# Use re.sub for replacement to maintain the same casing pattern that was used
|
| 678 |
+
final_prompt = re.sub(category_pattern, category, prompt)
|
| 679 |
+
else:
|
| 680 |
+
# Fallback if user didn't use placeholder
|
| 681 |
+
final_prompt = f"{prompt} {category}"
|
| 682 |
+
status = status + f"<div style='padding: 10px; background-color: #f0f7ff; border-left: 5px solid #3498db; margin-bottom: 10px;'>Note: For better results, use [CATEGORY] in your prompt where you want '{category}' to appear, otherwise it is appended at the end of the prompt.</div>"
|
| 683 |
+
|
| 684 |
+
error = False
|
| 685 |
+
# Check if prompt contains any other categories
|
| 686 |
for other_category in self.available_categories:
|
| 687 |
+
if re.search(r'\b' + re.escape(other_category) + r'\b', prompt):
|
| 688 |
+
status = status + f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Your prompt contains '{other_category}'. Please remove it and use [CATEGORY] instead.</div>"
|
| 689 |
+
error = True
|
| 690 |
+
if error:
|
| 691 |
+
return [], status
|
| 692 |
+
|
| 693 |
# Load category embeddings if not already loaded
|
| 694 |
pb_dict, all_ids = self.load_category_embeddings(category)
|
| 695 |
if pb_dict is None or not all_ids:
|
| 696 |
+
status = status + f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Failed to load embeddings for {category}</div>"
|
| 697 |
+
return [], status
|
| 698 |
+
|
| 699 |
# Ensure shape index is valid
|
| 700 |
if selected_shape_idx is None or selected_shape_idx < 0:
|
| 701 |
selected_shape_idx = 0
|
| 702 |
+
|
| 703 |
max_idx = len(all_ids) - 1
|
| 704 |
selected_shape_idx = max(0, min(selected_shape_idx, max_idx))
|
| 705 |
guidance_shape_id = all_ids[selected_shape_idx]
|
| 706 |
+
|
| 707 |
+
# Set generator
|
|
|
|
| 708 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 709 |
+
|
| 710 |
results = []
|
| 711 |
+
|
|
|
|
|
|
|
|
|
|
| 712 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 713 |
# Generate base image (without guidance)
|
| 714 |
with torch.no_grad():
|
| 715 |
base_images = self.pipeline(
|
| 716 |
+
prompt=final_prompt,
|
| 717 |
num_inference_steps=50,
|
| 718 |
num_images_per_prompt=1,
|
| 719 |
generator=generator,
|
| 720 |
guidance_scale=7.5
|
| 721 |
).images
|
| 722 |
+
|
| 723 |
base_image = base_images[0]
|
| 724 |
base_image = self.draw_text(base_image, "Unguided result")
|
| 725 |
results.append(base_image)
|
| 726 |
except Exception as e:
|
| 727 |
print(f"Error generating base image: {e}")
|
| 728 |
+
status = status + f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating base image: {str(e)}</div>"
|
| 729 |
return results, status
|
| 730 |
+
|
| 731 |
try:
|
| 732 |
# Get shape guidance image
|
| 733 |
ulip_image = self.get_ulip_image(guidance_shape_id)
|
|
|
|
| 735 |
results.append(ulip_image)
|
| 736 |
except Exception as e:
|
| 737 |
print(f"Error getting guidance shape: {e}")
|
| 738 |
+
status = status + f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error getting guidance shape: {str(e)}</div>"
|
| 739 |
return results, status
|
| 740 |
+
|
| 741 |
try:
|
| 742 |
# Get shape guidance embedding
|
| 743 |
pb_emb = pb_dict[guidance_shape_id]
|
| 744 |
+
out_guidance, prompt_emb = self.get_guidance(final_prompt, category, pb_emb)
|
|
|
|
| 745 |
except Exception as e:
|
| 746 |
print(f"Error generating guidance: {e}")
|
| 747 |
+
status = status + f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating guidance: {str(e)}</div>"
|
| 748 |
return results, status
|
| 749 |
+
|
| 750 |
try:
|
| 751 |
# Generate guided image
|
| 752 |
generator = torch.Generator(device=device).manual_seed(seed)
|
|
|
|
| 758 |
generator=generator,
|
| 759 |
guidance_scale=7.5
|
| 760 |
).images
|
| 761 |
+
|
| 762 |
guided_image = guided_images[0]
|
| 763 |
guided_image = self.draw_text(guided_image, f"Guided result (λ={guidance_strength:.1f})")
|
| 764 |
results.append(guided_image)
|
| 765 |
+
|
| 766 |
# Success status
|
| 767 |
+
status = status + f"<div style='padding: 10px; background-color: #e8f5e9; border-left: 5px solid #4caf50; margin-bottom: 10px;'>✓ Successfully generated images using Shape #{selected_shape_idx} from category '{category}'.</div>"
|
| 768 |
+
|
| 769 |
+
torch.cuda.empty_cache()
|
| 770 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 771 |
except Exception as e:
|
| 772 |
print(f"Error generating guided image: {e}")
|
| 773 |
+
status = status + f"<div style='padding: 10px; background-color: #ffebee; border-left: 5px solid #e74c3c; font-weight: bold; margin-bottom: 10px;'>⚠️ ERROR: Error generating guided image: {str(e)}</div>"
|
|
|
|
|
|
|
| 774 |
|
| 775 |
+
return results, status
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 776 |
|
| 777 |
def on_demo_load(self):
|
| 778 |
"""Function to ensure initial image is loaded when demo starts"""
|
|
|
|
| 783 |
def create_ui(self):
|
| 784 |
# Ensure chair is in available categories, otherwise use the first available
|
| 785 |
default_category = "chair" if "chair" in self.available_categories else self.available_categories[0]
|
| 786 |
+
|
| 787 |
with gr.Blocks(title="ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts") as demo:
|
| 788 |
gr.Markdown("""
|
| 789 |
# ShapeWords: Guiding Text-to-Image Synthesis with 3D Shape-Aware Prompts
|
|
|
|
| 794 |
- **Paper**: [ArXiv](https://arxiv.org/abs/2412.02912)
|
| 795 |
- **Publication**: Accepted to CVPR 2025
|
| 796 |
""")
|
| 797 |
+
|
| 798 |
with gr.Row():
|
| 799 |
with gr.Column(scale=1):
|
| 800 |
prompt = gr.Textbox(
|
| 801 |
+
label="Prompt (use [CATEGORY] for object type)",
|
| 802 |
+
placeholder="an aquarelle drawing of a [CATEGORY]",
|
| 803 |
+
value=f"an aquarelle drawing of a [CATEGORY]"
|
| 804 |
)
|
| 805 |
+
|
| 806 |
+
# Add help text below the prompt
|
| 807 |
+
help_text = gr.Markdown("""
|
| 808 |
+
**Tip:** Use [CATEGORY] in your prompt where you want the selected object type to appear.
|
| 809 |
+
For example: "a watercolor painting of a [CATEGORY] in the forest"
|
| 810 |
+
""")
|
| 811 |
+
|
| 812 |
+
random_prompt_btn = gr.Button("🎲 Random Prompt", size="sm", variant="secondary")
|
| 813 |
+
|
| 814 |
category = gr.Dropdown(
|
| 815 |
+
label="Object Category",
|
| 816 |
choices=self.available_categories,
|
| 817 |
value=default_category
|
| 818 |
)
|
| 819 |
+
|
| 820 |
# Hidden field to store selected shape index
|
| 821 |
selected_shape_idx = gr.Number(
|
| 822 |
value=0,
|
| 823 |
visible=False
|
| 824 |
)
|
| 825 |
+
|
| 826 |
# Create a slider for shape selection with preview
|
| 827 |
with gr.Row():
|
| 828 |
with gr.Column(scale=1):
|
|
|
|
| 835 |
label="Shape Index",
|
| 836 |
interactive=True
|
| 837 |
)
|
| 838 |
+
|
| 839 |
# Display shape index counter
|
| 840 |
shape_counter = gr.Markdown(f"Shape 0 of {self.category_counts.get(default_category, 0) - 1}")
|
| 841 |
+
|
| 842 |
# Quick navigation buttons
|
| 843 |
with gr.Row():
|
| 844 |
jump_start_btn = gr.Button("⏮️ First", size="sm")
|
| 845 |
+
random_btn = gr.Button("🎲 Random Shape", size="sm", variant="secondary")
|
| 846 |
jump_end_btn = gr.Button("Last ⏭️", size="sm")
|
| 847 |
+
|
| 848 |
with gr.Row():
|
| 849 |
prev_shape_btn = gr.Button("◀️ Previous", size="sm")
|
| 850 |
next_shape_btn = gr.Button("Next ▶️", size="sm")
|
| 851 |
+
|
| 852 |
with gr.Column(scale=1):
|
| 853 |
+
gr.Markdown("### Selected Shape (3D Point Cloud)")
|
| 854 |
+
current_shape_plot = gr.Plot(
|
| 855 |
+
label=None,
|
| 856 |
+
scale=1, # Take up available space
|
| 857 |
+
show_label=False,
|
| 858 |
+
#container=False
|
| 859 |
)
|
| 860 |
+
|
| 861 |
guidance_strength = gr.Slider(
|
| 862 |
minimum=0.0, maximum=1.0, step=0.1, value=0.9,
|
| 863 |
label="Guidance Strength (λ)"
|
| 864 |
)
|
| 865 |
+
|
| 866 |
seed = gr.Slider(
|
| 867 |
minimum=0, maximum=10000, step=1, value=42,
|
| 868 |
label="Random Seed"
|
| 869 |
)
|
| 870 |
+
|
| 871 |
run_button = gr.Button("Generate Images", variant="primary")
|
| 872 |
+
|
| 873 |
info = gr.Markdown("""
|
| 874 |
**Note**: Higher guidance strength (λ) means stronger adherence to the 3D shape.
|
| 875 |
Start with λ=0.9 for a good balance between shape and prompt adherence.
|
| 876 |
""")
|
| 877 |
+
|
| 878 |
status_text = gr.HTML("")
|
| 879 |
+
|
| 880 |
with gr.Column(scale=2):
|
| 881 |
gallery = gr.Gallery(
|
| 882 |
label="Results",
|
|
|
|
| 885 |
columns=3,
|
| 886 |
height="auto"
|
| 887 |
)
|
| 888 |
+
|
| 889 |
# Make sure the initial image is loaded when the demo starts
|
| 890 |
demo.load(
|
| 891 |
fn=self.on_demo_load,
|
| 892 |
inputs=None,
|
| 893 |
+
outputs=[current_shape_plot]
|
| 894 |
)
|
| 895 |
+
|
| 896 |
# Connect slider to update preview
|
| 897 |
shape_slider.change(
|
| 898 |
fn=self.on_slider_change,
|
| 899 |
inputs=[shape_slider, category],
|
| 900 |
+
outputs=[current_shape_plot, shape_counter, selected_shape_idx]
|
| 901 |
)
|
| 902 |
+
|
| 903 |
# Previous shape button
|
| 904 |
prev_shape_btn.click(
|
| 905 |
fn=self.prev_shape,
|
| 906 |
inputs=[selected_shape_idx, category],
|
| 907 |
+
outputs=[shape_slider, current_shape_plot, shape_counter]
|
| 908 |
)
|
| 909 |
+
|
| 910 |
# Next shape button
|
| 911 |
next_shape_btn.click(
|
| 912 |
fn=self.next_shape,
|
| 913 |
inputs=[selected_shape_idx, category],
|
| 914 |
+
outputs=[shape_slider, current_shape_plot, shape_counter]
|
| 915 |
)
|
| 916 |
+
|
| 917 |
# Jump to start button
|
| 918 |
jump_start_btn.click(
|
| 919 |
fn=self.jump_to_start,
|
| 920 |
inputs=[category],
|
| 921 |
+
outputs=[shape_slider, current_shape_plot, shape_counter]
|
| 922 |
)
|
| 923 |
+
|
| 924 |
# Jump to end button
|
| 925 |
jump_end_btn.click(
|
| 926 |
fn=self.jump_to_end,
|
| 927 |
inputs=[category],
|
| 928 |
+
outputs=[shape_slider, current_shape_plot, shape_counter]
|
| 929 |
)
|
| 930 |
+
|
| 931 |
# Random shape button
|
| 932 |
random_btn.click(
|
| 933 |
fn=self.random_shape,
|
| 934 |
inputs=[category],
|
| 935 |
+
outputs=[shape_slider, current_shape_plot, shape_counter]
|
| 936 |
)
|
| 937 |
+
|
| 938 |
+
# Connect the random prompt button
|
| 939 |
+
random_prompt_btn.click(
|
| 940 |
+
fn=self.random_prompt,
|
| 941 |
+
inputs=[],
|
| 942 |
+
outputs=[prompt]
|
| 943 |
+
)
|
| 944 |
+
|
| 945 |
# Update the UI when category changes
|
| 946 |
category.change(
|
| 947 |
fn=self.on_category_change,
|
| 948 |
inputs=[category],
|
| 949 |
+
outputs=[shape_slider, selected_shape_idx, current_shape_plot, shape_counter]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 950 |
)
|
| 951 |
+
|
| 952 |
# Clear status text before generating new images
|
| 953 |
run_button.click(
|
| 954 |
fn=lambda: None, # Empty function to clear the status
|
| 955 |
inputs=None,
|
| 956 |
outputs=[status_text]
|
| 957 |
)
|
| 958 |
+
|
| 959 |
# Generate images when button is clicked
|
| 960 |
run_button.click(
|
| 961 |
fn=self.generate_images,
|
| 962 |
inputs=[prompt, category, selected_shape_idx, guidance_strength, seed],
|
| 963 |
outputs=[gallery, status_text]
|
| 964 |
)
|
| 965 |
+
|
| 966 |
gr.Markdown("""
|
| 967 |
## Credits
|
| 968 |
|
|
|
|
| 981 |
}
|
| 982 |
```
|
| 983 |
""")
|
| 984 |
+
|
| 985 |
return demo
|
| 986 |
|
| 987 |
|
|
|
|
| 990 |
parser = argparse.ArgumentParser(description="ShapeWords Gradio Demo")
|
| 991 |
parser.add_argument('--share', action='store_true', help='Create a public link')
|
| 992 |
args = parser.parse_args()
|
| 993 |
+
|
| 994 |
# Create the demo app and UI
|
| 995 |
app = ShapeWordsDemo()
|
| 996 |
demo = app.create_ui()
|
| 997 |
demo.launch(share=args.share)
|
| 998 |
|
| 999 |
+
|
| 1000 |
if __name__ == "__main__":
|
| 1001 |
+
main()
|
shapewords_paper_code
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
Subproject commit
|
|
|
|
| 1 |
+
Subproject commit e4ebe6c6541505c2e7bc1068186f7045b1bfb51a
|