import gradio as gr
from huggingface_hub import snapshot_download
import os
import sys
import cv2
import numpy as np
from PIL import Image
import spaces
import torch

# Add the parent directory to the sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


# Download entire repository
def download_repo(repo_id=os.getenv("MODEL_REPO_ID"), local_dir="./ckpt"):
    """
    Download the entire Hugging Face repository

    Args:
        repo_id (str): The repository ID on Hugging Face
        local_dir (str): Local directory to save the repository

    Returns:
        str: Path to the downloaded repository
    """
    # Create output directory if it doesn't exist
    os.makedirs(local_dir, exist_ok=True)

    api_key = os.getenv("HF_TOKEN")
    repo_path = snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        local_dir_use_symlinks=False,  # Use actual files, not symlinks
        token=api_key
    )
    return repo_path


# Download model repository
repo_path = download_repo()

from ckpt.pipeline import BlueberryPipeline


# Simplified model cache class for loading and storing only the pipeline
class ModelCache:
    def __init__(self):
        self.pipe = None

    def load_pipeline(self):
        if self.pipe is None:
            self.pipe = BlueberryPipeline.from_pretrained(
                "./ckpt",
                low_cpu_mem_usage=True,
                torch_dtype=torch.bfloat16
            ).to("cuda")
        return self.pipe


# Initialize model cache
model_cache = ModelCache()


@spaces.GPU(duration=60)
def virtual_tryon(garment_img, person_img, garment_type, sleeve_length, garment_length):
    """
    Perform virtual try-on with the given garment on the target person

    Args:
        garment_img: Image of the garment (numpy array)
        person_img: Image of the person (numpy array)
        garment_type: Garment category ("upper", "lower" or "full-body")
        sleeve_length: Sleeve-length caption (or "ignore")
        garment_length: Garment-length caption (or "ignore")

    Returns:
        Image showing the person wearing the garment
    """
    try:
        # Convert numpy arrays to PIL images
        garment_pil = Image.fromarray(garment_img).convert('RGB')
        person_pil = Image.fromarray(person_img).convert('RGB')

        # Load the pipeline
        model_cache.load_pipeline()

        with torch.inference_mode():
            inference_kwargs = {
                "captions": {
                    'garment_type': garment_type,
                    'sleeve_length': sleeve_length,
                    'garment_length': garment_length
                },
                "human": person_pil,
                "cloth": garment_pil
            }
            result_image = model_cache.pipe(**inference_kwargs)[0][0]

        # Convert PIL image back to numpy array for Gradio
        return np.array(result_image)
    except Exception as e:
        raise RuntimeError(f"Virtual try-on failed: {e}") from e


css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
#logo {
    max-width: 300px;
    margin: 0 auto;
    background-color: #ffffff00;
    border-color: #ffffff00;
}
#title {
    text-align: center;
}
#button {
    background-color: #FF5E00;
}
#button:hover {
    background-color: #9500FF;
}
a {
    color: #9500FF;
}
a:hover {
    color: #9500FFC8;
}
"""

# Create Gradio interface with Blocks
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Add company logo
        gr.Image(
            elem_id="logo",
            value="https://framerusercontent.com/images/SE8ziToAc9ZZbRn4vi4WVybvt0.svg?scale-down-to=512",
            show_label=False,
            show_download_button=False,
            show_fullscreen_button=False,
            show_share_button=False
        )
        gr.Markdown("# AYNA 1.0", elem_id="title")
        gr.Markdown(
            'Introducing our first groundbreaking virtual try-on model series, kicking off with **AYNA 1.0**. '
            '[Discover more about Ayna here](https://www.getayna.com/) and step into the future of personalized fashion.',
            elem_id="title"
        )

        with gr.Row():
            garment_input = gr.Image(label="Garment Image", type="numpy")
            person_input = gr.Image(label="Person Image", type="numpy")
            output_image = gr.Image(show_label=False)

        with gr.Row():
            with gr.Column(scale=2.5):
                with gr.Column():
                    with gr.Row():
                        # Add example images for garments and humans
                        example_garments = [
                            "examples/garments/garment-1.png",
                            "examples/garments/garment-2.jpg",
                            "examples/garments/garment-3.jpg"
                        ]
                        example_humans = [
                            "examples/humans/human-1.jpg",
                            "examples/humans/human-2.jpg",
                            "examples/humans/human-3.jpg"
                        ]
                        gr.Examples(
                            examples=example_garments,
                            inputs=garment_input,
                            label="Garment Examples",
                            examples_per_page=3
                        )
                        gr.Examples(
                            examples=example_humans,
                            inputs=person_input,
                            label="Person Examples",
                            examples_per_page=3
                        )

            # Add radio buttons for different garment types
            with gr.Column():
                garment_type = gr.Radio(
                    choices=["upper", "lower", "full-body"],
                    label="Garment Type",
                    value="upper"
                )
                sleeve_length = gr.Radio(
                    choices=["3/4 sleeve", "cap sleeve", "short sleeve", "long sleeve", "sleeveless", "ignore"],
                    label="Sleeve Length",
                    visible=True,
                    interactive=True,
                    info="Choose 'ignore' if you are not sure"
                )
                garment_length = gr.Radio(
                    choices=["crop length", "hip length", "waist length", "tunic length", "thigh length", "knee length", "ignore"],
                    label="Garment Length",
                    interactive=True,
                    info="Choose 'ignore' if you are not sure"
                )

                # Update sleeve_length visibility based on garment_type
                def update_sleeve_visibility(garment):
                    is_visible = (garment != "lower")
                    return gr.update(visible=is_visible, interactive=is_visible)

                # Swap the garment_length choices to match the selected garment type
                def update_garment_length_visibility(garment):
                    if garment == "lower":
                        return gr.update(
                            choices=["full length", "knee length", "maxi length", "short length", "midi length", "mini length", "ignore"],
                            value=None
                        )
                    elif garment == "full-body":
                        return gr.update(
                            choices=["full length", "asymmetrical length", "knee length", "maxi length", "midi length", "mini length", "tunic length", "ignore"],
                            value=None
                        )
                    else:
                        return gr.update(
                            choices=["crop length", "hip length", "waist length", "tunic length", "thigh length", "knee length", "ignore"],
                            value=None
                        )

                garment_type.change(
                    fn=update_sleeve_visibility,
                    inputs=garment_type,
                    outputs=sleeve_length
                )
                garment_type.change(
                    fn=update_garment_length_visibility,
                    inputs=garment_type,
                    outputs=garment_length
                )

        try_on_button = gr.Button("Try On", elem_id="button")

        # Add validation function
        def validate_inputs(garment_img, person_img, garment_type, sleeve_length, garment_length):
            if garment_img is None:
                raise gr.Error("Please upload a garment image")
            if person_img is None:
                raise gr.Error("Please upload a person image")
            if garment_type is None:
                raise gr.Error("Please select a garment type")
            if garment_type != "lower" and sleeve_length is None:
                raise gr.Error("Please select a sleeve length")
            if garment_length is None:
                raise gr.Error("Please select a garment length")

            # If all validations pass, proceed with try-on
            try:
                result = virtual_tryon(garment_img, person_img, garment_type, sleeve_length, garment_length)
                return result
            except Exception as e:
                raise gr.Error(f"Error: {str(e)}")

        try_on_button.click(
            fn=validate_inputs,
            inputs=[garment_input, person_input, garment_type, sleeve_length, garment_length],
            outputs=output_image
        )

if __name__ == "__main__":
    demo.launch()
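
# Optional: a minimal, commented-out sketch for smoke-testing the try-on call without
# launching the UI. It assumes the example assets listed above exist locally and that a
# GPU is available; the caption values below are illustrative, not required defaults.
#
# if os.getenv("SMOKE_TEST"):
#     garment = np.array(Image.open("examples/garments/garment-1.png").convert("RGB"))
#     person = np.array(Image.open("examples/humans/human-1.jpg").convert("RGB"))
#     result = virtual_tryon(garment, person, "upper", "short sleeve", "hip length")
#     Image.fromarray(result).save("tryon_result.png")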