# Lotus-2_Normal/pipeline.py
from typing import Union, Optional, List, Dict, Any
import numpy as np
import torch
from diffusers import FluxPipeline
from diffusers.pipelines.flux import FluxPipelineOutput
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.utils import is_torch_xla_available
from utils.image_utils import resize_image, resize_image_first
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False


class Lotus2Pipeline(FluxPipeline):
@torch.no_grad()
def __call__(
self,
rgb_in: Optional[torch.FloatTensor] = None,
        prompt: Optional[Union[str, List[str]]] = None,
num_inference_steps: int = 10,
output_type: Optional[str] = "pil",
process_res: Optional[int] = None,
timestep_core_predictor: int = 1,
guidance_scale: float = 3.5,
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
rgb_in (`torch.FloatTensor`, *optional*):
The input image to be used for generation.
prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the prediction; an empty string (`""`) is typically passed for this task.
num_inference_steps (`int`, *optional*, defaults to 10):
The number of denoising steps. More denoising steps usually lead to a sharper prediction at the
expense of slower inference.
            process_res (`int`, *optional*):
                The resolution the input image is resized to for processing; the final prediction is resized
                back to the original input resolution before being returned.
            timestep_core_predictor (`int`, *optional*, defaults to 1):
                The raw timestep fed to the single-pass core predictor (divided by 1000 before being passed
                to the transformer).
            guidance_scale (`float`, *optional*, defaults to 3.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
Examples:
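            A minimal usage sketch (the checkpoint path is a placeholder and the `[-1, 1]` input range is an
            assumption; the "core_predictor" and "detail_sharpener" adapters must already be attached to the
            transformer):

            >>> import torch
            >>> pipe = Lotus2Pipeline.from_pretrained("path/to/lotus-2-model", torch_dtype=torch.bfloat16)
            >>> pipe = pipe.to("cuda")
            >>> rgb = torch.rand(1, 3, 768, 768, device="cuda") * 2 - 1  # RGB tensor in [-1, 1]
            >>> result = pipe(rgb_in=rgb, prompt="", num_inference_steps=10)
            >>> result.images[0].save("normal_map.png")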
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
        # 1. prepare
        if rgb_in is None:
            raise ValueError("`rgb_in` (the input image tensor) must be provided.")
        batch_size = rgb_in.shape[0]
        input_size = rgb_in.shape[2:]
rgb_in = resize_image_first(rgb_in, process_res)
height, width = rgb_in.shape[2:]
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
device = self._execution_device
# 2. encode prompt
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = self.encode_prompt(
prompt=prompt,
prompt_2=None,
device=device,
)
# 3. prepare latent variables
rgb_in = rgb_in.to(device=device, dtype=self.dtype)
rgb_latents = self.vae.encode(rgb_in).latent_dist.sample()
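        # Map VAE latents into the Flux latent space (subtract shift, then scale); the inverse is applied before decoding.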
rgb_latents = (rgb_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
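        # Flux expects a token sequence: pack each 2x2 patch of latent pixels into a single token of length 4 * C.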
packed_rgb_latents = self._pack_latents(
rgb_latents,
batch_size=rgb_latents.shape[0],
num_channels_latents=rgb_latents.shape[1],
height=rgb_latents.shape[2],
width=rgb_latents.shape[3],
)
        latent_image_ids = self._prepare_latent_image_ids(
            batch_size, rgb_latents.shape[2] // 2, rgb_latents.shape[3] // 2, device, rgb_latents.dtype
        )
        # Both stages share the same positional-id grid, so reuse the tensor.
        latent_image_ids_core_predictor = latent_image_ids
# 4. prepare timesteps
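        # The transformer expects timesteps in [0, 1]; the raw values are divided by 1000 at call time.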
timestep_core_predictor = torch.tensor(timestep_core_predictor).expand(batch_size).to(device=rgb_in.device, dtype=rgb_in.dtype)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = packed_rgb_latents.shape[1]
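        # Resolution-dependent timestep shift: longer token sequences get a larger shift (mu) of the sigma schedule.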
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)  # 0 for the order-1 flow-matching scheduler
self._num_timesteps = len(timesteps)
# 5. handle guidance
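        # Guidance-distilled Flux variants embed the guidance scale directly instead of running classifier-free guidance.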
if self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(packed_rgb_latents.shape[0])
else:
guidance = None
if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
# 6. core predictor
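        # Single forward pass with the "core_predictor" adapter: the clean RGB latents (not noise) are mapped
        # directly to an initial prediction at a fixed timestep.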
self.transformer.set_adapter("core_predictor")
latents = self.transformer(
hidden_states=packed_rgb_latents,
timestep=timestep_core_predictor / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids_core_predictor,
            joint_attention_kwargs=self.joint_attention_kwargs,  # {} unless provided by the caller
return_dict=False,
)[0]
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
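        # local_continuity_module (attached to the pipeline elsewhere) operates on unpacked (B, C, H', W') latents,
        # so unpack before applying it and re-pack afterwards.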
latents = self.local_continuity_module(latents)
# 7. Denoising loop for detail sharpener
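        # Starting from the core prediction rather than pure noise, the "detail_sharpener" adapter refines it
        # over the full flow-matching schedule.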
self.transformer.set_adapter("detail_sharpener")
latents = self._pack_latents(
latents,
batch_size=latents.shape[0],
num_channels_latents=latents.shape[1],
height=latents.shape[2],
width=latents.shape[3],
)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
                        # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
                # update the progress bar (callbacks are not supported in this pipeline)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
latents = latents.to(dtype=self.dtype)
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Resize output image to match input size
image = resize_image(image, input_size)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)