Spaces:
Running
on
Zero
Running
on
Zero
File size: 9,017 Bytes
0e84104 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
from typing import Union, Optional, List, Dict, Any
import numpy as np
import torch
from diffusers import FluxPipeline
from diffusers.pipelines.flux import FluxPipelineOutput
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.utils import is_torch_xla_available
from utils.image_utils import resize_image, resize_image_first
# Pull in the XLA runtime only when diffusers reports it as available;
# XLA_AVAILABLE records whether `xm` may be used further down the module.
if not is_torch_xla_available():
    XLA_AVAILABLE = False
else:
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
class Lotus2Pipeline(FluxPipeline):
    """Flux-based image-to-image prediction pipeline.

    A call runs one transformer pass with the ``"core_predictor"`` adapter,
    feeds the result through ``self.local_continuity_module`` and then refines
    it with a denoising loop that uses the ``"detail_sharpener"`` adapter.
    """

    @torch.no_grad()
    def __call__(
        self,
        rgb_in: Optional[torch.FloatTensor] = None,
        prompt: Union[str, List[str]] = None,
        num_inference_steps: int = 10,
        output_type: Optional[str] = "pil",
        process_res: Optional[int] = None,
        timestep_core_predictor: int = 1,
        guidance_scale: float = 3.5,
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            rgb_in (`torch.FloatTensor`):
                The input image batch to predict from. Required; indexed as
                `(batch, channels, height, width)` by this method.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the prediction. `None` is treated as `''`. A single string is
                repeated once per image in `rgb_in`.
            num_inference_steps (`int`, *optional*, defaults to 10):
                The number of denoising steps of the detail-sharpener loop. More denoising steps usually lead
                to a sharper prediction at the expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. `"latent"` returns the packed latents without
                decoding; any other value is forwarded to `self.image_processor.postprocess`.
            process_res (`int`, *optional*):
                Forwarded together with `rgb_in` to `resize_image_first` before encoding.
                NOTE(review): presumably the working resolution -- confirm against `utils.image_utils`.
            timestep_core_predictor (`int`, *optional*, defaults to 1):
                Timestep used for the single core-predictor pass (divided by 1000 before it is handed to the
                transformer).
            guidance_scale (`float`, *optional*, defaults to 3.5):
                Value passed to the transformer as `guidance`. Only used when
                `self.transformer.config.guidance_embeds` is truthy.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

        Returns:
            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.

        Raises:
            ValueError: If `rgb_in` is `None`.
        """
        # 1. prepare
        if rgb_in is None:
            raise ValueError("`rgb_in` must be provided: the pipeline predicts from an input image.")

        batch_size = rgb_in.shape[0]
        input_size = rgb_in.shape[2:]  # remembered so the output can be resized back
        rgb_in = resize_image_first(rgb_in, process_res)
        height, width = rgb_in.shape[2:]

        # The docstring promises '' as the default prompt; a lone string is
        # repeated so the text embeddings have the same batch size as the image.
        if prompt is None:
            prompt = ""
        if isinstance(prompt, str):
            prompt = [prompt] * batch_size

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = {} if joint_attention_kwargs is None else joint_attention_kwargs
        self._interrupt = False
        device = self._execution_device

        # 2. encode prompt
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=None,
            device=device,
        )

        # 3. prepare latent variables
        rgb_in = rgb_in.to(device=device, dtype=self.dtype)
        rgb_latents = self.vae.encode(rgb_in).latent_dist.sample()
        rgb_latents = (rgb_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        packed_rgb_latents = self._pack_latents(
            rgb_latents,
            batch_size=rgb_latents.shape[0],
            num_channels_latents=rgb_latents.shape[1],
            height=rgb_latents.shape[2],
            width=rgb_latents.shape[3],
        )
        # Both transformer stages use the same latent grid, so the position ids
        # are built once and shared.
        latent_image_ids = self._prepare_latent_image_ids(
            batch_size,
            rgb_latents.shape[2] // 2,
            rgb_latents.shape[3] // 2,
            device,
            rgb_latents.dtype,
        )
        latent_image_ids_core_predictor = latent_image_ids

        # 4. prepare timesteps
        timestep_core_predictor = (
            torch.tensor(timestep_core_predictor)
            .expand(batch_size)
            .to(device=rgb_in.device, dtype=rgb_in.dtype)
        )
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = packed_rgb_latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # 5. handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(packed_rgb_latents.shape[0])
        else:
            guidance = None

        # 6. core predictor: a single transformer pass on the encoded image
        self.transformer.set_adapter("core_predictor")
        latents = self.transformer(
            hidden_states=packed_rgb_latents,
            timestep=timestep_core_predictor / 1000,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=latent_image_ids_core_predictor,
            joint_attention_kwargs=self.joint_attention_kwargs,
            return_dict=False,
        )[0]

        # The continuity module is applied to unpacked latents, so unpack here
        # and re-pack before the sharpening loop.
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = self.local_continuity_module(latents)

        # 7. Denoising loop for detail sharpener
        self.transformer.set_adapter("detail_sharpener")
        latents = self._pack_latents(
            latents,
            batch_size=latents.shape[0],
            num_channels_latents=latents.shape[1],
            height=latents.shape[2],
            width=latents.shape[3],
        )

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype and torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

                # advance the progress bar once per completed scheduler step
                is_last_step = i == len(timesteps) - 1
                if is_last_step or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        latents = latents.to(dtype=self.dtype)

        if output_type == "latent":
            # Packed latents are returned as-is (no decode, no resize).
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)
            # Resize output image to match input size
            image = resize_image(image, input_size)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)