Spaces:
Paused
Paused
Update app_quant_latent.py
Browse files
- app_quant_latent.py +14 -6
app_quant_latent.py
CHANGED
|
@@ -705,15 +705,23 @@ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
|
|
| 705 |
latents = latents.float()
|
| 706 |
|
| 707 |
num_previews = min(10, steps)
|
| 708 |
-
|
| 709 |
|
| 710 |
-
|
|
|
|
|
|
|
|
|
|
| 711 |
try:
|
| 712 |
with torch.no_grad():
|
| 713 |
-
|
| 714 |
-
|
|
|
|
| 715 |
|
| 716 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 717 |
decoded = (decoded / 2 + 0.5).clamp(0, 1)
|
| 718 |
decoded = decoded.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 719 |
decoded = (decoded * 255).round().astype("uint8")
|
|
@@ -726,7 +734,7 @@ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
|
|
| 726 |
latent_gallery.append(latent_img)
|
| 727 |
|
| 728 |
# Keep last 5 latents
|
| 729 |
-
last_latents.append(
|
| 730 |
if len(last_latents) > 5:
|
| 731 |
last_latents.pop(0)
|
| 732 |
|
|
|
|
| 705 |
latents = latents.float()
|
| 706 |
|
| 707 |
num_previews = min(10, steps)
|
| 708 |
+
preview_indices = torch.linspace(0, steps - 1, num_previews).long()
|
| 709 |
|
| 710 |
+
# clone latents for preview
|
| 711 |
+
preview_latents = latents.clone()
|
| 712 |
+
|
| 713 |
+
for i, step_idx in enumerate(preview_indices):
|
| 714 |
try:
|
| 715 |
with torch.no_grad():
|
| 716 |
+
# --- Denoising step simulation ---
|
| 717 |
+
noise_scale = 1.0 - (i / num_previews)
|
| 718 |
+
preview_latent_step = preview_latents + torch.randn_like(preview_latents) * noise_scale
|
| 719 |
|
| 720 |
+
# move to VAE device and match dtype
|
| 721 |
+
preview_latent_step = preview_latent_step.to(pipe.vae.device).to(pipe.vae.dtype)
|
| 722 |
+
|
| 723 |
+
# decode latent to image
|
| 724 |
+
decoded = pipe.vae.decode(preview_latent_step, return_dict=False)[0]
|
| 725 |
decoded = (decoded / 2 + 0.5).clamp(0, 1)
|
| 726 |
decoded = decoded.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 727 |
decoded = (decoded * 255).round().astype("uint8")
|
|
|
|
| 734 |
latent_gallery.append(latent_img)
|
| 735 |
|
| 736 |
# Keep last 5 latents
|
| 737 |
+
last_latents.append(preview_latent_step.cpu().clone())
|
| 738 |
if len(last_latents) > 5:
|
| 739 |
last_latents.pop(0)
|
| 740 |
|