Create app.py
app.py
ADDED
import gradio as gr
import os
import shutil  # For copying downloaded model files into the output directory
import zipfile
import time
import uuid  # For unique filenames

import torch  # Needed for torch_dtype / device selection below

# --- LLM/Model Setup ---
from transformers import pipeline as transformers_pipeline  # For local list generation
from huggingface_hub import InferenceClient  # For prompt refinement via API
from diffusers import DiffusionPipeline, EulerDiscreteScheduler  # For image generation
from gradio_client import Client as GradioClient  # For 3D generation

# --- Configuration ---
# Consider making these configurable in the UI later
LIST_GENERATION_MODEL = "google/flan-t5-base"  # Or another suitable small model
PROMPT_REFINEMENT_MODEL_API = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # Or another instruct model via Inference API
IMAGE_GENERATION_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"  # Or "runwayml/stable-diffusion-v1-5"
HUNYUAN_SPACE_ID = "tencent/Hunyuan3D-2"
OUTPUT_DIR = "outputs"
MODELS_SUBDIR = "3d_models"
IMAGES_SUBDIR = "image_previews"
ZIP_FILENAME = "3d_collection.zip"
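
# The imports above imply roughly these pip dependencies for the Space
# (a sketch, not a pinned requirements.txt; the exact package set and
# versions are assumptions):
#
#   gradio
#   gradio_client
#   transformers
#   torch
#   diffusers
#   accelerate      # commonly required by diffusers pipelines
#   huggingface_hub
#   safetensors     # for use_safetensors=True
#   Pillow          # for saving generated images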

# --- Initialize Clients/Pipelines (can be slow, consider loading on demand if needed) ---

# Use HF Token from Space secrets if available/needed for Inference API
HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)

# Basic List Generator (local)
try:
    list_generator = transformers_pipeline("text2text-generation", model=LIST_GENERATION_MODEL)
except Exception as e:
    print(f"Warning: Could not load local list generator {LIST_GENERATION_MODEL}: {e}")
    list_generator = None

# Prompt Refiner (API)
try:
    if not HF_TOKEN:
        print("Warning: HUGGINGFACE_TOKEN not set. Inference API calls might be rate-limited or fail.")
    prompt_refiner = InferenceClient(model=PROMPT_REFINEMENT_MODEL_API, token=HF_TOKEN)
except Exception as e:
    print(f"Warning: Could not initialize InferenceClient for {PROMPT_REFINEMENT_MODEL_API}: {e}")
    prompt_refiner = None

# Image Generator (local with Diffusers - requires a GPU on the Space for reasonable speed).
# Consider an image-generation API service instead if running on CPU hardware.
try:
    # DiffusionPipeline auto-selects the correct pipeline class for the model
    # (SDXL needs StableDiffusionXLPipeline, which StableDiffusionPipeline cannot load).
    # float16 only makes sense on GPU, so fall back to float32 on CPU.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    image_pipeline = DiffusionPipeline.from_pretrained(IMAGE_GENERATION_MODEL, torch_dtype=dtype, use_safetensors=True)
    # Move to GPU when available (check Space hardware)
    if torch.cuda.is_available():
        image_pipeline.to("cuda")
    image_pipeline.scheduler = EulerDiscreteScheduler.from_config(image_pipeline.scheduler.config)
except Exception as e:
    print(f"Warning: Could not load diffusers pipeline {IMAGE_GENERATION_MODEL}. Image generation might fail: {e}")
    image_pipeline = None

# 3D Generator Client
try:
    hunyuan_client = GradioClient(HUNYUAN_SPACE_ID)
except Exception as e:
    print(f"Error initializing GradioClient for {HUNYUAN_SPACE_ID}: {e}")
    hunyuan_client = None

# --- Helper Functions ---

def generate_list_local(theme, count):
    if not list_generator:
        return ["Error: List generator model not loaded."]
    prompt = f"Generate a comma-separated list of {count} distinct types of {theme}."
    try:
        result = list_generator(prompt, max_length=200)[0]['generated_text']
        items = [item.strip() for item in result.split(',') if item.strip()]
        return items[:count]  # Ensure we don't exceed the requested count
    except Exception as e:
        print(f"Error generating list: {e}")
        return [f"Error: {e}"]

def refine_prompt_api(item_name):
    if not prompt_refiner:
        return f"A 3D model of a {item_name}"  # Fallback basic prompt
    prompt = f"Create a detailed, descriptive prompt for generating a highly realistic image of a single '{item_name}'. Focus on visual details suitable for a text-to-image AI. Only output the prompt itself."
    try:
        refined = prompt_refiner.text_generation(prompt, max_new_tokens=100)
        # Clean up potential API artifacts if necessary
        refined = refined.strip().strip('"')
        return refined
    except Exception as e:
        print(f"Error refining prompt for '{item_name}': {e}")
        # Fall back to a simpler prompt for 3D generation if refinement fails
        return f"A high quality 3D model of a {item_name}"

def generate_image_local(refined_prompt, output_path):
    if not image_pipeline:
        print("Image generation pipeline not available.")
        # Create a placeholder image or return None, e.g.:
        # from PIL import Image; img = Image.new('RGB', (60, 30), color='red'); img.save(output_path); return output_path
        return None
    try:
        # Adjust inference steps/guidance as needed
        image = image_pipeline(refined_prompt, num_inference_steps=25, guidance_scale=7.5).images[0]
        image.save(output_path)
        return output_path
    except Exception as e:
        print(f"Error generating image for prompt '{refined_prompt}': {e}")
        return None

def generate_3d_model_hunyuan(refined_prompt_for_3d, output_dir, item_name_safe):
    if not hunyuan_client:
        print("Hunyuan 3D client not available.")
        return None, "Client not initialized"

    print(f"Requesting 3D model for: {refined_prompt_for_3d}")
    # Use defaults for most parameters initially
    try:
        result_tuple = hunyuan_client.predict(
            caption=refined_prompt_for_3d,
            # Leave image and mv_image inputs as None for text-to-3D
            image=None,
            mv_image_front=None,
            mv_image_back=None,
            mv_image_left=None,
            mv_image_right=None,
            # Default values from API docs (can be overridden)
            steps=30,
            guidance_scale=5,
            seed=1234,  # Ignored when randomize_seed=True
            octree_resolution=256,
            check_box_rembg=True,
            num_chunks=8000,
            randomize_seed=True,
            api_name="/generation_all"  # Crucial!
        )

        # --- VERIFICATION NEEDED ---
        # Check the actual return tuple structure. Assuming the file path is the
        # first or second element: try index 0, then fall back to index 1.
        raw_filepath = None
        if result_tuple and len(result_tuple) > 0 and isinstance(result_tuple[0], str):
            raw_filepath = result_tuple[0]
        elif result_tuple and len(result_tuple) > 1 and isinstance(result_tuple[1], str):
            print("Using second element from result tuple for filepath.")
            raw_filepath = result_tuple[1]
        # --- END VERIFICATION NEEDED ---

        if raw_filepath:
            print(f"Job completed. Raw result path: {raw_filepath}")
            os.makedirs(output_dir, exist_ok=True)

            # gradio_client has already downloaded the output to a local temp
            # path, so copy it into our output directory under a meaningful name.
            if os.path.exists(raw_filepath):
                file_ext = os.path.splitext(raw_filepath)[1]  # Get extension (.glb, .obj?)
                if not file_ext:
                    file_ext = ".glb"  # Assume glb if unknown
                final_path = os.path.join(output_dir, f"{item_name_safe}{file_ext}")
                shutil.copy(raw_filepath, final_path)
                print(f"Model saved to: {final_path}")
                return final_path, "Success"
            else:
                error_msg = f"Result path does not exist locally: {raw_filepath}"
                print(error_msg)
                return None, error_msg
        else:
            error_msg = f"Job for '{refined_prompt_for_3d}' did not return a valid filepath in expected tuple elements."
            print(error_msg)
            # Inspect the full result_tuple here for debugging
            print(f"Full result tuple: {result_tuple}")
            return None, error_msg

    except Exception as e:
        error_msg = f"Error calling Hunyuan3D API for '{refined_prompt_for_3d}': {e}"
        print(error_msg)
        return None, str(e)

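# The tuple-unpacking above is defensive because the /generation_all return
# structure hasn't been verified. One way to check it (an optional debugging
# aid, not required at runtime) is gradio_client's API inspector:
#
#   if hunyuan_client:
#       hunyuan_client.view_api()  # Prints endpoint names, parameters, and returns
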
def create_zip(files_to_zip, zip_filepath):
    with zipfile.ZipFile(zip_filepath, 'w') as zf:
        for file_path in files_to_zip:
            if file_path and os.path.exists(file_path):
                zf.write(file_path, os.path.basename(file_path))
    return zip_filepath


# --- Gradio Interface & Logic ---

with gr.Blocks() as demo:
    gr.Markdown("# 3D Asset Collection Generator")
    gr.Markdown("Generate a list based on a theme, refine prompts, preview images, and generate selected 3D models using Hunyuan3D-2.")
    # gr.Warning/gr.Error only display when raised inside event handlers, so at
    # build time we render static notices instead.
    if not HF_TOKEN:
        gr.Markdown("**Warning:** Hugging Face Token not found. Prompt refinement quality/rate limits may be affected. Consider adding HUGGINGFACE_TOKEN to Space secrets.")
    if not image_pipeline:
        gr.Markdown("**Warning:** Local image generation model failed to load. Image previews will be skipped. Check Space hardware/logs.")
    if not hunyuan_client:
        gr.Markdown("**Error:** Failed to connect to the Hunyuan3D-2 Space. 3D generation will not work.")

    # State to hold intermediate results. gr.State works well for simple values;
    # complex lists/dicts need careful handling (always update them via outputs).
    list_items_state = gr.State([])
    refined_prompts_state = gr.State({})     # Dict: {item_name: refined_prompt}
    image_paths_state = gr.State({})         # Dict: {item_name: image_path}
    selected_items_state = gr.State([])      # List of row indices selected by the user
    generated_3d_files_state = gr.State([])  # List of paths to successfully generated models

    with gr.Row():
        theme_input = gr.Textbox(label="Theme", placeholder="e.g., reptiles, kitchen appliances, medieval weapons")
        count_input = gr.Number(label="Number of Items", value=5, minimum=1, step=1)

    generate_list_button = gr.Button("1. Generate List & Refine Prompts")
    list_output_display = gr.Markdown("List will appear here...")  # Or use gr.DataFrame

    generate_images_button = gr.Button("2. Generate Image Previews", visible=False)  # Hidden initially
    # Use Gallery for display, Dataset for selection tracking
    image_gallery = gr.Gallery(label="Image Previews", visible=False, elem_id="image_gallery")
    # Dataset to hold data for selection (item_name, image_path, refined_prompt)
    selection_data = gr.Dataset(components=[gr.Textbox(visible=False), gr.Textbox(visible=False), gr.Textbox(visible=False)],  # item, img_path, prompt
                                headers=["Item Name", "Image", "Prompt"],
                                label="Select Items for 3D Generation (click rows to toggle)",
                                visible=False)

    generate_3d_button = gr.Button("3. Generate 3D Models for Selected Items", visible=False)  # Hidden initially
    status_output = gr.Markdown("")  # For progress updates
    final_zip_output = gr.File(label="Download 3D Model Collection (ZIP)", visible=False)

    # --- Event Logic ---

    def run_list_and_refine(theme, count):
        if not theme:
            return {list_output_display: "Please enter a theme.", generate_images_button: gr.Button(visible=False)}

        # Ensure output dirs exist
        os.makedirs(os.path.join(OUTPUT_DIR, IMAGES_SUBDIR), exist_ok=True)
        os.makedirs(os.path.join(OUTPUT_DIR, MODELS_SUBDIR), exist_ok=True)

        gr.Info("Generating list...")
        items = generate_list_local(theme, int(count))
        if not items or "Error:" in items[0]:
            return {list_output_display: f"Failed to generate list: {items[0] if items else 'Unknown error'}",
                    generate_images_button: gr.Button(visible=False)}

        gr.Info("Refining prompts via API...")
        refined_prompts = {}
        output_md = "### Generated List & Refined Prompts:\n\n"
        for item in items:
            refined = refine_prompt_api(item)
            refined_prompts[item] = refined
            output_md += f"* **{item}:** {refined}\n"

        # Enable the next step. State is updated by returning new values for the
        # state components, not by assigning to .value (which would not persist
        # per-session).
        return {
            list_output_display: output_md,
            generate_images_button: gr.Button(visible=True),  # Show image gen button
            list_items_state: items,
            refined_prompts_state: refined_prompts,
        }

    generate_list_button.click(
        fn=run_list_and_refine,
        inputs=[theme_input, count_input],
        outputs=[list_output_display, generate_images_button, list_items_state, refined_prompts_state]  # Update state too
    )

    def run_image_generation(items, refined_prompts_dict):
        if not image_pipeline:
            # Skip image generation if pipeline not loaded
            gr.Warning("Image pipeline not loaded. Skipping image previews.")
            # Prepare data for selection without images
            selection_samples = [[item, "N/A", refined_prompts_dict.get(item, "")] for item in items]
            return {
                image_gallery: gr.Gallery(visible=False),
                selection_data: gr.Dataset(samples=selection_samples, visible=True),
                generate_3d_button: gr.Button(visible=True),  # Allow proceeding without previews
                image_paths_state: {},  # Clear image paths
            }

        gr.Info("Generating image previews... (this may take a while)")
        image_paths = {}
        gallery_images = []
        selection_samples = []  # For the Dataset component

        img_dir = os.path.join(OUTPUT_DIR, IMAGES_SUBDIR)

        for item in items:
            refined_prompt = refined_prompts_dict.get(item, f"Image of {item}")  # Get refined prompt
            safe_item_name = "".join(c if c.isalnum() else "_" for c in item)
            img_filename = f"{safe_item_name}_{uuid.uuid4()}.png"
            img_path = os.path.join(img_dir, img_filename)

            generated_path = generate_image_local(refined_prompt, img_path)

            if generated_path:
                image_paths[item] = generated_path
                gallery_images.append(generated_path)
                selection_samples.append([item, generated_path, refined_prompt])
            else:
                # Handle image generation failure; keep one row per item so row
                # indices stay aligned with the items list
                selection_samples.append([item, "Failed", refined_prompt])

        # Show gallery and selection dataset
        return {
            image_gallery: gr.Gallery(value=gallery_images, visible=True),
            selection_data: gr.Dataset(samples=selection_samples, visible=True),
            generate_3d_button: gr.Button(visible=True),  # Show 3D gen button
            image_paths_state: image_paths,  # Save image paths
        }

    generate_images_button.click(
        fn=run_image_generation,
        inputs=[list_items_state, refined_prompts_state],
        outputs=[image_gallery, selection_data, generate_3d_button, image_paths_state]  # Update state
    )

    # Handler for when the user makes selections in the Dataset.
    # A gr.Dataset row click fires a 'select' event carrying gr.SelectData with
    # the clicked row's index. The Dataset itself is single-click, so we
    # accumulate clicks into a multi-selection list held in selected_items_state.
    def record_selection(selected_indices, evt: gr.SelectData):
        selected = list(selected_indices)
        if evt.index in selected:
            selected.remove(evt.index)  # Clicking a selected row deselects it
        else:
            selected.append(evt.index)
        return selected, f"Selected rows: {sorted(selected) if selected else 'none'}"

    selection_data.select(
        fn=record_selection,
        inputs=[selected_items_state],
        outputs=[selected_items_state, status_output]
    )

    # The 3D button reads the accumulated selection from state rather than
    # expecting gr.SelectData, which click events do not carry.
    def run_3d_generation(selected_indices, items, refined_prompts_dict):
        if not hunyuan_client:
            yield {status_output: "Hunyuan3D client not initialized. Cannot generate.", final_zip_output: gr.File(visible=False)}
            return

        if not selected_indices:
            yield {status_output: "Please select items from the table above before generating 3D models.", final_zip_output: gr.File(visible=False)}
            return

        # Row indices align with the items list (one row per item, even on failure)
        selected_items = [(items[i], refined_prompts_dict.get(items[i], f"A high quality 3D model of a {items[i]}"))
                          for i in sorted(selected_indices) if 0 <= i < len(items)]

        generated_files = []
        status_messages = ["### 3D Generation Status:\n"]

        model_dir = os.path.join(OUTPUT_DIR, MODELS_SUBDIR)

        total_selected = len(selected_items)
        for i, (item_name, refined_prompt) in enumerate(selected_items):
            current_status = f"({i+1}/{total_selected}) Generating model for: **{item_name}**..."
            print(current_status)
            status_messages.append(f"* {current_status}")
            # Update UI status progressively
            yield {status_output: "\n".join(status_messages), final_zip_output: gr.File(visible=False)}

            # Adapt the prompt slightly for 3D if desired, or use the image prompt directly
            prompt_for_3d = refined_prompt  # Or customize: f"A high quality 3D model of {item_name}, {refined_prompt}"

            item_name_safe = "".join(c if c.isalnum() else "_" for c in item_name)

            # --- Retry Logic ---
            max_retries = 1  # Allow one retry
            attempts = 0
            model_path = None
            last_error = "Unknown error"

            while attempts <= max_retries:
                attempts += 1
                if attempts > 1:
                    status_messages.append(f"  * Retrying ({attempts-1}/{max_retries})...")
                    yield {status_output: "\n".join(status_messages)}
                    time.sleep(2)  # Brief pause before retry

                model_path, msg = generate_3d_model_hunyuan(prompt_for_3d, model_dir, item_name_safe)
                last_error = msg
                if model_path:
                    generated_files.append(model_path)
                    status_messages.append("  * Success! Model saved.")
                    break  # Exit retry loop on success
                else:
                    status_messages.append(f"  * Attempt {attempts} failed: {msg}")

            if not model_path:
                status_messages.append(f"  * **Failed** after {attempts} attempt(s). Last error: {last_error}")
            # --- End Retry Logic ---

            # Update UI status after each item
            yield {status_output: "\n".join(status_messages)}

        # In a generator, the final update must also be yielded: a plain
        # `return {...}` would be swallowed as StopIteration and never reach the UI.
        if generated_files:
            status_messages.append("\nCreating ZIP archive...")
            yield {status_output: "\n".join(status_messages)}
            zip_path = os.path.join(OUTPUT_DIR, ZIP_FILENAME)
            final_zip = create_zip(generated_files, zip_path)
            status_messages.append(f"\n**Collection ready!** Download '{ZIP_FILENAME}' below.")
            yield {status_output: "\n".join(status_messages),
                   final_zip_output: gr.File(value=final_zip, visible=True),
                   generated_3d_files_state: generated_files}  # Store final paths
        else:
            status_messages.append("\nNo 3D models were successfully generated.")
            yield {status_output: "\n".join(status_messages), final_zip_output: gr.File(visible=False)}

    generate_3d_button.click(
        fn=run_3d_generation,
        inputs=[selected_items_state, list_items_state, refined_prompts_state],
        outputs=[status_output, final_zip_output, generated_3d_files_state]  # Update state
    )


# Launch the Gradio app
demo.queue().launch(debug=True)  # Enable queue for longer processes, debug for detailed errors