Spaces:
Runtime error
Runtime error
changed name
Browse files- .gitattributes +2 -0
- README.md +1 -1
- app.py +43 -16
- examples/garments/garment-1.png +3 -0
- examples/garments/garment-2.jpg +3 -0
- examples/garments/garment-3.jpg +3 -0
- examples/humans/human-1.jpg +3 -0
- examples/humans/human-2.jpg +3 -0
- examples/humans/human-3.jpg +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
emoji: π
|
| 4 |
colorFrom: purple
|
| 5 |
colorTo: yellow
|
|
|
|
| 1 |
---
|
| 2 |
+
title: AYNA 1.0
|
| 3 |
emoji: π
|
| 4 |
colorFrom: purple
|
| 5 |
colorTo: yellow
|
app.py
CHANGED
|
@@ -10,7 +10,6 @@ import torch
|
|
| 10 |
|
| 11 |
# Add the parent directory to the sys.path
|
| 12 |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 13 |
-
from ckpt.pipeline import BlueberryPipeline
|
| 14 |
|
| 15 |
# Download entire repository
|
| 16 |
def download_repo(repo_id=os.getenv("MODEL_REPO_ID"), local_dir="./ckpt"):
|
|
@@ -41,6 +40,8 @@ def download_repo(repo_id=os.getenv("MODEL_REPO_ID"), local_dir="./ckpt"):
|
|
| 41 |
# Download model repository
|
| 42 |
repo_path = download_repo()
|
| 43 |
|
|
|
|
|
|
|
| 44 |
# Simplified Model Cache class for loading and storing only the pipeline
|
| 45 |
class ModelCache:
|
| 46 |
def __init__(self):
|
|
@@ -59,7 +60,7 @@ class ModelCache:
|
|
| 59 |
# Initialize model cache
|
| 60 |
model_cache = ModelCache()
|
| 61 |
|
| 62 |
-
@spaces.GPU
|
| 63 |
def virtual_tryon(garment_img, person_img, garment_type, sleeve_length, garment_length):
|
| 64 |
"""
|
| 65 |
Perform virtual try-on with the given garment on the target person
|
|
@@ -140,21 +141,47 @@ with gr.Blocks(css=css) as demo:
|
|
| 140 |
|
| 141 |
with gr.Row():
|
| 142 |
with gr.Column(scale=2.5):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
# Add radio buttons for different garment types
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
|
|
|
| 158 |
|
| 159 |
# Update sleeve_length visibility based on garment_type
|
| 160 |
def update_sleeve_visibility(garment):
|
|
|
|
| 10 |
|
| 11 |
# Add the parent directory to the sys.path
|
| 12 |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
| 13 |
|
| 14 |
# Download entire repository
|
| 15 |
def download_repo(repo_id=os.getenv("MODEL_REPO_ID"), local_dir="./ckpt"):
|
|
|
|
| 40 |
# Download model repository
|
| 41 |
repo_path = download_repo()
|
| 42 |
|
| 43 |
+
from ckpt.pipeline import BlueberryPipeline
|
| 44 |
+
|
| 45 |
# Simplified Model Cache class for loading and storing only the pipeline
|
| 46 |
class ModelCache:
|
| 47 |
def __init__(self):
|
|
|
|
| 60 |
# Initialize model cache
|
| 61 |
model_cache = ModelCache()
|
| 62 |
|
| 63 |
+
@spaces.GPU(duration=60)
|
| 64 |
def virtual_tryon(garment_img, person_img, garment_type, sleeve_length, garment_length):
|
| 65 |
"""
|
| 66 |
Perform virtual try-on with the given garment on the target person
|
|
|
|
| 141 |
|
| 142 |
with gr.Row():
|
| 143 |
with gr.Column(scale=2.5):
|
| 144 |
+
with gr.Column():
|
| 145 |
+
with gr.Row():
|
| 146 |
+
# add example images for garment and humans
|
| 147 |
+
example_garments = [
|
| 148 |
+
"examples/garments/garment-1.png",
|
| 149 |
+
"examples/garments/garment-2.jpg",
|
| 150 |
+
"examples/garments/garment-3.jpg"
|
| 151 |
+
]
|
| 152 |
+
example_humans = [
|
| 153 |
+
"examples/humans/human-1.jpg",
|
| 154 |
+
"examples/humans/human-2.jpg",
|
| 155 |
+
"examples/humans/human-3.jpg"
|
| 156 |
+
]
|
| 157 |
+
gr.Examples(
|
| 158 |
+
examples=example_garments,
|
| 159 |
+
inputs=garment_input,
|
| 160 |
+
label="Garment Examples",
|
| 161 |
+
examples_per_page=3
|
| 162 |
+
)
|
| 163 |
+
gr.Examples(
|
| 164 |
+
examples=example_humans,
|
| 165 |
+
inputs=person_input,
|
| 166 |
+
label="Person Examples",
|
| 167 |
+
examples_per_page=3
|
| 168 |
+
)
|
| 169 |
# Add radio buttons for different garment types
|
| 170 |
+
with gr.Column():
|
| 171 |
+
garment_type = gr.Radio(choices=["upper", "lower", "full-body"], label="Garment Type", value="upper")
|
| 172 |
+
sleeve_length = gr.Radio(
|
| 173 |
+
choices=["3/4 sleeve", "cap sleeve", "short sleeve", "long sleeve", "sleeveless", "ignore"],
|
| 174 |
+
label="Sleeve Length",
|
| 175 |
+
visible=True,
|
| 176 |
+
interactive=True,
|
| 177 |
+
info="Choose 'ignore' if you are not sure"
|
| 178 |
+
)
|
| 179 |
+
garment_length = gr.Radio(
|
| 180 |
+
choices=["crop length", "hip length", "waist length", "tunic length", "thigh length", "knee length", "ignore"],
|
| 181 |
+
label="Garment Length",
|
| 182 |
+
interactive=True,
|
| 183 |
+
info="Choose 'ignore' if you are not sure"
|
| 184 |
+
)
|
| 185 |
|
| 186 |
# Update sleeve_length visibility based on garment_type
|
| 187 |
def update_sleeve_visibility(garment):
|
examples/garments/garment-1.png
ADDED
|
Git LFS Details
|
examples/garments/garment-2.jpg
ADDED
|
Git LFS Details
|
examples/garments/garment-3.jpg
ADDED
|
Git LFS Details
|
examples/humans/human-1.jpg
ADDED
|
Git LFS Details
|
examples/humans/human-2.jpg
ADDED
|
Git LFS Details
|
examples/humans/human-3.jpg
ADDED
|
Git LFS Details
|