import numpy as np
import os
import torch
import einops
import traceback
import cv2

import modules.async_worker as worker
from modules.util import generate_temp_filename
from PIL import Image

from comfy.model_base import WAN21

import shared
from shared import path_manager, settings
from pathlib import Path
import random

from modules.pipleline_utils import (
    clean_prompt_cond_caches,
    get_previewer,
)

import comfy.utils
import comfy.model_management
import comfy.sample       # used by process() for noise preparation and sampling
import comfy.clip_vision  # used by load_base_model() to load the CLIP Vision model
from comfy.sd import load_checkpoint_guess_config
from calcuis_gguf.pig import load_gguf_sd, GGMLOps, GGUFModelPatcher

from nodes import (
    CLIPTextEncode,
    VAEDecodeTiled,
)
from comfy_extras.nodes_hunyuan import EmptyHunyuanLatentVideo
from comfy_extras.nodes_wan import WanImageToVideo
from comfy_extras.nodes_model_advanced import ModelSamplingSD3


class pipeline:
    pipeline_type = ["wan_video"]

    class StableDiffusionModel:
        def __init__(self, unet, vae, clip, clip_vision):
            self.unet = unet
            self.vae = vae
            self.clip = clip
            self.clip_vision = clip_vision

        def to_meta(self):
            if self.unet is not None:
                self.unet.model.to("meta")
            if self.clip is not None:
                self.clip.cond_stage_model.to("meta")
            if self.vae is not None:
                self.vae.first_stage_model.to("meta")
    model_hash = ""
    model_base = None
    model_hash_patched = ""
    model_base_patched = None
    conditions = None

    ggml_ops = GGMLOps()

    # Optional function
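    # parse_gen_data() turns a batch request into a single video job: image_number is
    # forced to 1 and the requested count becomes the frame count, expanded to the form
    # 4n+1 (e.g. image_number=16 -> original_image_number=21), presumably because the
    # Wan VAE compresses the time axis by a factor of 4.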
    def parse_gen_data(self, gen_data):
        gen_data["original_image_number"] = 1 + ((int(gen_data["image_number"] / 4.0) + 1) * 4)
        gen_data["image_number"] = 1
        return gen_data
    def load_base_model(self, name, unet_only=True):  # Wan_Video never has the clip and vae models?
        # Check if model is already loaded
        if self.model_hash == name:
            return

        self.model_base = None
        self.model_hash = ""
        self.model_base_patched = None
        self.model_hash_patched = ""
        self.conditions = None

        filename = str(shared.models.get_file("checkpoints", name))

        print(f"Loading WAN video {'unet' if unet_only else 'model'}: {name}")
        if filename.endswith(".gguf") or unet_only:
            with torch.inference_mode():
                try:
                    if filename.endswith(".gguf"):
                        sd = load_gguf_sd(filename)
                        unet = comfy.sd.load_diffusion_model_state_dict(
                            sd, model_options={"custom_operations": self.ggml_ops}
                        )
                        unet = GGUFModelPatcher.clone(unet)
                        unet.patch_on_device = True
                    else:
                        model_options = {}
                        model_options["dtype"] = torch.float8_e4m3fn  # FIXME should be a setting
                        unet = comfy.sd.load_diffusion_model(filename, model_options=model_options)

                    clip_paths = []
                    clip_names = []
                    if isinstance(unet.model, WAN21):
                        clip_name = settings.default_settings.get("clip_umt5", "umt5_xxl_fp8_e4m3fn_scaled.safetensors")
                        clip_names.append(str(clip_name))
                        clip_path = path_manager.get_folder_file_path(
                            "clip",
                            clip_name,
                            default=os.path.join(path_manager.model_paths["clip_path"], clip_name),
                        )
                        clip_paths.append(str(clip_path))
                        clip_type = comfy.sd.CLIPType.WAN
                        # https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged
                        vae_name = settings.default_settings.get("vae_wan", "wan_2.1_vae.safetensors")
                    else:
                        print("ERROR: Not a Wan Video model?")
                        unet = None
                        return
| print(f"Loading CLIP: {clip_names}") | |
| clip = comfy.sd.load_clip(ckpt_paths=clip_paths, clip_type=clip_type, model_options={}) | |
| vae_path = path_manager.get_folder_file_path( | |
| "vae", | |
| vae_name, | |
| default = os.path.join(path_manager.model_paths["vae_path"], vae_name) | |
| ) | |
| print(f"Loading VAE: {vae_name}") | |
| sd = comfy.utils.load_torch_file(str(vae_path)) | |
| vae = comfy.sd.VAE(sd=sd) | |
| clip_vision_name = settings.default_settings.get("clip_vision", "clip_vision_h_fp8_e4m3fn.safetensors") | |
| clip_vision_path = path_manager.get_folder_file_path( | |
| "clip_vision", | |
| clip_vision_name, | |
| default = os.path.join(path_manager.model_paths["clip_vision_path"], clip_vision_name) | |
| ) | |
| print(f"Loading CLIP Vision: {clip_vision_name}") | |
| sd = comfy.utils.load_torch_file(str(clip_vision_path)) | |
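                    # The key check below appears to distinguish CLIP Vision weights stored
                    # under an OpenAI-style "visual." prefix (which need their keys converted)
                    # from state dicts already saved in ComfyUI's CLIP Vision layout.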
| if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd: | |
| clip_vision = comfy.clip_vision.load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True) | |
| else: | |
| clip_vision = comfy.clip_vision.load_clipvision_from_sd(sd=sd) | |
| except Exception as e: | |
| unet = None | |
| traceback.print_exc() | |
        else:
            try:
                with torch.inference_mode():
                    unet, clip, vae, clip_vision = load_checkpoint_guess_config(filename)
                    if clip is None or vae is None:
                        raise ValueError("Checkpoint is missing CLIP or VAE")
            except Exception:
                print("Failed. Trying to load as unet.")
                self.load_base_model(
                    filename,
                    unet_only=True,
                )
                return
        if unet is None:
            print(f"Failed to load {name}")
            self.model_base = None
            self.model_hash = ""
        else:
            self.model_base = self.StableDiffusionModel(
                unet=unet, clip=clip, vae=vae, clip_vision=clip_vision
            )
            if not isinstance(self.model_base.unet.model, WAN21):
                print(
                    f"Model {type(self.model_base.unet.model)} not supported. Expected Wan Video model."
                )
                self.model_base = None

        if self.model_base is not None:
            self.model_hash = name
            print(f"Base model loaded: {self.model_hash}")

        return
    def load_keywords(self, lora):
        filename = lora.replace(".safetensors", ".txt")
        try:
            with open(filename, "r") as file:
                data = file.read()
            return data
        except FileNotFoundError:
            return " "
    def load_loras(self, loras):
        loaded_loras = []

        model = self.model_base
        for name, weight in loras:
            if name == "None" or weight == 0:
                continue
            filename = str(shared.models.get_file("loras", name))
            print(f"Loading LoRAs: {name}")
            try:
                lora = comfy.utils.load_torch_file(filename, safe_load=True)
                unet, clip = comfy.sd.load_lora_for_models(
                    model.unet, model.clip, lora, weight, weight
                )
                model = self.StableDiffusionModel(
                    unet=unet,
                    clip=clip,
                    vae=model.vae,
                    clip_vision=model.clip_vision,
                )
                loaded_loras += [(name, weight)]
            except Exception:
                # Skip LoRAs that fail to load, but leave a trace for debugging
                traceback.print_exc()
        self.model_base_patched = model
        self.model_hash_patched = str(loras)
        print(f"LoRAs loaded: {loaded_loras}")

        return
    def refresh_controlnet(self, name=None):
        return

    def clean_prompt_cond_caches(self):
        return

    conditions = None
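    # textencode() caches prompt conditioning per slot ("+" / "-"); the prompt text and
    # clip_skip value form the cache key, so unchanged prompts skip re-running the text
    # encoder on the next generation.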
    def textencode(self, id, text, clip_skip):
        update = False
        hash = f"{text} {clip_skip}"
        if hash != self.conditions[id]["text"]:
            self.conditions[id]["cache"] = CLIPTextEncode().encode(
                clip=self.model_base_patched.clip, text=text
            )[0]
            self.conditions[id]["text"] = hash
            update = True
        return update
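    # vae_decode_fake() is a cheap preview decoder: instead of running the real VAE it
    # projects the 16-channel video latent straight to RGB with a fixed 16x3 linear map,
    # applied as a 1x1x1 conv3d, which is fast enough to call on every sampler step.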
    def vae_decode_fake(self, latents):
        # FIXME: This should probably just be imported from comfyui
        latent_rgb_factors = [
            [-0.1299, -0.1692, 0.2932],
            [ 0.0671, 0.0406, 0.0442],
            [ 0.3568, 0.2548, 0.1747],
            [ 0.0372, 0.2344, 0.1420],
            [ 0.0313, 0.0189, -0.0328],
            [ 0.0296, -0.0956, -0.0665],
            [-0.3477, -0.4059, -0.2925],
            [ 0.0166, 0.1902, 0.1975],
            [-0.0412, 0.0267, -0.1364],
            [-0.1293, 0.0740, 0.1636],
            [ 0.0680, 0.3019, 0.1128],
            [ 0.0032, 0.0581, 0.0639],
            [-0.1251, 0.0927, 0.1699],
            [ 0.0060, -0.0633, 0.0005],
            [ 0.3477, 0.2275, 0.2950],
            [ 0.1984, 0.0913, 0.1861],
        ]
        latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360]

        weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
        bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)

        images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
        images = images.clamp(0.0, 1.0)

        return images
    def process(
        self,
        gen_data=None,
        callback=None,
    ):
        shared.state["preview_total"] = 1
        seed = gen_data["seed"] if isinstance(gen_data["seed"], int) else random.randint(1, 2**32)

        if callback is not None:
            worker.add_result(
                gen_data["task_id"],
                "preview",
                (-1, "Processing text encoding ...", "html/generate_video.jpeg"),
            )

        if self.conditions is None:
            self.conditions = clean_prompt_cond_caches()

        positive_prompt = gen_data["positive_prompt"]
        negative_prompt = gen_data["negative_prompt"]
        clip_skip = 1

        self.textencode("+", positive_prompt, clip_skip)
        self.textencode("-", negative_prompt, clip_skip)

        pbar = comfy.utils.ProgressBar(gen_data["steps"])
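        # The sampler callback renders a live preview: the current latent estimate (x0)
        # is projected to RGB with vae_decode_fake, the frames are tiled into a single
        # (batch*height) x (frames*width) strip, downscaled to fit 1920x1080 and pushed
        # to the UI via worker.add_result.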
        def callback_function(step, x0, x, total_steps):
            y = self.vae_decode_fake(x0)
            y = (y * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
            y = einops.rearrange(y, 'b c t h w -> (b h) (t w) c')
            # Skip callback() since it would just confuse the preview grid; push updates ourselves
            status = "Generating video"
            maxw = 1920
            maxh = 1080
            image = Image.fromarray(y)
            ow, oh = image.size
            scale = min(maxh / oh, maxw / ow)
            image = image.resize((int(ow * scale), int(oh * scale)), Image.LANCZOS)

            worker.add_result(
                gen_data["task_id"],
                "preview",
                (
                    int(100 * (step / total_steps)),
                    f"{status} - {step}/{total_steps}",
                    image,
                ),
            )
            # pbar.update_absolute(step + 1, total_steps, None)
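        # Patch the flow-matching sigma shift on the unet; shift=8.0 is the value
        # commonly used with Wan 2.1 workflows (hard-coded here).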
        # ModelSamplingSD3
        model_sampling = ModelSamplingSD3().patch(
            model=self.model_base_patched.unet,
            shift=8.0,
        )[0]

        # t2v or i2v?
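        # With an input image, WanImageToVideo encodes the start frame through the VAE and
        # CLIP Vision and folds it into the conditioning (image-to-video); otherwise an
        # empty video latent is generated for text-to-video. The Hunyuan empty-latent node
        # is reused here, presumably because Wan shares the same video latent layout.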
| if gen_data["input_image"]: | |
| image = np.array(gen_data["input_image"]).astype(np.float32) / 255.0 | |
| image = torch.from_numpy(image)[None,] | |
| clip_vision_output = self.model_base_patched.clip_vision.encode_image(image) | |
| (positive, negative, latent_image) = WanImageToVideo().encode( | |
| positive = self.conditions["+"]["cache"], | |
| negative = self.conditions["-"]["cache"], | |
| vae = self.model_base_patched.vae, | |
| width = gen_data["width"], | |
| height = gen_data["height"], | |
| length = gen_data["original_image_number"], | |
| batch_size = 1, | |
| start_image = image, | |
| clip_vision_output = clip_vision_output, | |
| ) | |
| else: | |
| # latent_image | |
| latent_image = EmptyHunyuanLatentVideo().generate( | |
| width = gen_data["width"], | |
| height = gen_data["height"], | |
| length = gen_data["original_image_number"], | |
| batch_size = 1, | |
| )[0] | |
| positive = self.conditions["+"]["cache"] | |
| negative = self.conditions["-"]["cache"] | |
        worker.add_result(
            gen_data["task_id"],
            "preview",
            (-1, "Generating ...", "html/generate_video.jpeg"),
        )
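        # prepare_noise builds seed-deterministic noise matching the latent shape; the
        # sampler then runs the requested sampler/scheduler on the patched model and
        # reports progress through callback_function above.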
        noise = comfy.sample.prepare_noise(latent_image["samples"], seed)
        sampled = comfy.sample.sample(
            model=model_sampling,
            noise=noise,
            steps=gen_data["steps"],
            cfg=gen_data["cfg"],
            sampler_name=gen_data["sampler_name"],
            scheduler=gen_data["scheduler"],
            positive=positive,
            negative=negative,
            latent_image=latent_image["samples"],
            denoise=1,
            callback=callback_function,
        )
        if callback is not None:
            worker.add_result(
                gen_data["task_id"],
                "preview",
                (-1, "VAE Decoding ...", None),
            )

        latent_image["samples"] = sampled
        decoded_latent = VAEDecodeTiled().decode(
            samples=latent_image,
            tile_size=128,
            overlap=64,
            vae=self.model_base_patched.vae,
        )[0]

        pil_images = []
        for image in decoded_latent:
            i = 255. * image.cpu().numpy()
            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
            pil_images.append(img)
        if callback is not None:
            worker.add_result(
                gen_data["task_id"],
                "preview",
                (-1, "Saving ...", None),
            )

        file = generate_temp_filename(
            folder=path_manager.model_paths["temp_outputs_path"], extension="gif"
        )
        os.makedirs(os.path.dirname(file), exist_ok=True)

        fps = 12.0
        compress_level = 4  # Min = 0, Max = 9
        # Save GIF
        pil_images[0].save(
            file,
            compress_level=compress_level,
            save_all=True,
            duration=int(1000.0 / fps),
            append_images=pil_images[1:],
            optimize=True,
            loop=0,
        )
        # Save mp4
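        # OpenCV's VideoWriter wants a plain string path, BGR frames, and frames that
        # match the (width, height) given here, hence the conversions below.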
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        mp4_file = file.with_suffix(".mp4")
        out = cv2.VideoWriter(str(mp4_file), fourcc, fps, (gen_data["width"], gen_data["height"]))
        for frame in pil_images:
            out.write(cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR))
        out.release()

        return [file]