rahul7star committed on
Commit
d18408d
·
verified ·
1 Parent(s): 342566d

Update app_quant_latent.py

Browse files
Files changed (1) hide show
  1. app_quant_latent.py +14 -6
app_quant_latent.py CHANGED
@@ -705,15 +705,23 @@ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
705
  latents = latents.float()
706
 
707
  num_previews = min(10, steps)
708
- preview_steps = torch.linspace(0, 1, num_previews)
709
 
710
- for alpha in preview_steps:
 
 
 
711
  try:
712
  with torch.no_grad():
713
- preview_latent = latents * alpha + latents * 0 # simple progression
714
- preview_latent = preview_latent.to(pipe.vae.device).to(pipe.vae.dtype)
 
715
 
716
- decoded = pipe.vae.decode(preview_latent, return_dict=False)[0]
 
 
 
 
717
  decoded = (decoded / 2 + 0.5).clamp(0, 1)
718
  decoded = decoded.cpu().permute(0, 2, 3, 1).float().numpy()
719
  decoded = (decoded * 255).round().astype("uint8")
@@ -726,7 +734,7 @@ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
726
  latent_gallery.append(latent_img)
727
 
728
  # Keep last 5 latents
729
- last_latents.append(preview_latent.cpu().clone())
730
  if len(last_latents) > 5:
731
  last_latents.pop(0)
732
 
 
705
  latents = latents.float()
706
 
707
  num_previews = min(10, steps)
708
+ preview_indices = torch.linspace(0, steps - 1, num_previews).long()
709
 
710
+ # clone latents for preview
711
+ preview_latents = latents.clone()
712
+
713
+ for i, step_idx in enumerate(preview_indices):
714
  try:
715
  with torch.no_grad():
716
+ # --- Denoising step simulation ---
717
+ noise_scale = 1.0 - (i / num_previews)
718
+ preview_latent_step = preview_latents + torch.randn_like(preview_latents) * noise_scale
719
 
720
+ # move to VAE device and match dtype
721
+ preview_latent_step = preview_latent_step.to(pipe.vae.device).to(pipe.vae.dtype)
722
+
723
+ # decode latent to image
724
+ decoded = pipe.vae.decode(preview_latent_step, return_dict=False)[0]
725
  decoded = (decoded / 2 + 0.5).clamp(0, 1)
726
  decoded = decoded.cpu().permute(0, 2, 3, 1).float().numpy()
727
  decoded = (decoded * 255).round().astype("uint8")
 
734
  latent_gallery.append(latent_img)
735
 
736
  # Keep last 5 latents
737
+ last_latents.append(preview_latent_step.cpu().clone())
738
  if len(last_latents) > 5:
739
  last_latents.pop(0)
740