DZRobo committed
Commit · e69f3b7 · 1 Parent(s): 7af46cf

Add Z_image support and improve latent/channel handling

Adds functions to harmonize latent channel counts and condition token lengths to prevent mismatches, especially for models like FLUX/Z_image. Enhances error reporting with debug output and traceback printing. Updates mg_combinode to better validate VAE/CLIP presence for checkpoint and input selection. Fixes hybrid sigma schedule alignment in mg_zesmart_sampler_v1_1.
- mod/easy/mg_cade25_easy.py +146 -0
- mod/hard/mg_cade25.py +1 -0
- mod/hard/mg_zesmart_sampler_v1_1.py +9 -0
- mod/mg_combinode.py +21 -2
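
A minimal standalone sketch (not part of the commit) of the variance-preservation step used by the new _match_latent_channels helper in mod/easy/mg_cade25_easy.py below. It assumes a 4-channel latent replicated to a 16-channel target; the seed and shapes are illustrative only.

import torch

torch.manual_seed(0)
z = torch.randn(1, 4, 64, 64)                 # 4-channel latent (e.g., SD/SDXL style)
rep = 16 // z.shape[1]                        # 4 copies to reach a 16-channel target
z16 = z.repeat(1, rep, 1, 1) / (rep ** 0.5)   # same repeat-then-scale recipe as the helper

# The 4 copies of source channel 0 sit at indices 0, 4, 8, 12 after repeat().
# Their sum has std ~2.0 (= sqrt(rep)), matching the sum of 4 independent
# unit-variance channels, which is what a 16-channel model implicitly expects.
summed = z16[:, 0::4, ...].sum(dim=1)
print(round(z.std().item(), 3), round(summed.std().item(), 3))
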
mod/easy/mg_cade25_easy.py
CHANGED
@@ -7,6 +7,7 @@ import torch
 import os
 import numpy as np
 import torch.nn.functional as F
+import traceback
 
 import nodes
 import comfy.model_management as model_management
@@ -1115,6 +1116,133 @@ def safe_decode(vae, lat, tile=512, ovlp=128, to_fp32: bool = False):
     return out
 
 
+def _match_latent_channels(vae, latent: dict, model=None):
+    """Align latent channel count to model/VAE expectations (e.g., FLUX/Z_image 16ch) with variance preservation."""
+    if not isinstance(latent, dict) or ("samples" not in latent):
+        return latent
+    z = latent.get("samples", None)
+    if z is None:
+        return latent
+    try:
+        target_c = None
+        # Prefer model latent_format if available (more reliable than VAE decoder)
+        if model is not None:
+            try:
+                lf = model.get_model_object("latent_format")
+                target_c = int(getattr(lf, "latent_channels", None) or 0) or None
+            except Exception:
+                target_c = None
+        fs = getattr(vae, "first_stage_model", None)
+        dec = getattr(fs, "decoder", None)
+        if dec is not None and hasattr(dec, "conv_in"):
+            target_c = target_c or int(dec.conv_in.in_channels)
+        if target_c is None and hasattr(fs, "latent_channels"):
+            target_c = int(getattr(fs, "latent_channels"))
+        if target_c is None and hasattr(vae, "latent_channels"):
+            target_c = int(getattr(vae, "latent_channels"))
+        if target_c is None:
+            return latent
+        cur_c = int(z.shape[1])
+        if cur_c == target_c:
+            return latent
+        # Repeat channels when divisible (common case: 4 -> 16)
+        if target_c % cur_c == 0 and cur_c > 0:
+            rep = target_c // cur_c
+            reps = [1, rep] + [1] * (z.ndim - 2)
+            z_fixed = z.repeat(*reps)
+            # Preserve variance after channel replication
+            z_fixed = z_fixed / (rep ** 0.5)
+        else:
+            # Fallback: pad zeros or slice to match
+            if target_c > cur_c:
+                pad = target_c - cur_c
+                pad_tensor = torch.zeros_like(z[:, :1, ...]).repeat(1, pad, *([1] * (z.ndim - 2)))
+                z_fixed = torch.cat([z, pad_tensor], dim=1)
+            else:
+                z_fixed = z[:, :target_c, ...]
+        latent = {**latent, "samples": z_fixed}
+    except Exception:
+        pass
+    return latent
+
+
+def _harmonize_cond_tokens(cond_list):
+    """Pad/truncate cond tokens + masks to a common length to avoid mismatches (e.g., 499 vs 528 or 981 vs 1286)."""
+    if not isinstance(cond_list, list):
+        return cond_list
+    # pass 1: find max token length across cross_attn
+    max_len = 0
+    for c in cond_list:
+        if isinstance(c, dict):
+            ca = c.get("cross_attn", None)
+            if ca is not None:
+                try:
+                    max_len = max(max_len, int(ca.shape[1]))
+                except Exception:
+                    pass
+    if max_len <= 0:
+        return cond_list
+    fixed = []
+    for c in cond_list:
+        if not isinstance(c, dict):
+            fixed.append(c)
+            continue
+        d = c.copy()
+        ca = d.get("cross_attn", None)
+        am = d.get("attention_mask", None)
+        # Harmonize cross_attn length
+        if ca is not None:
+            try:
+                ca_len = int(ca.shape[1])
+                if ca_len < max_len:
+                    pad_shape = list(ca.shape)
+                    pad_shape[1] = max_len - ca_len
+                    ca_pad = torch.zeros(pad_shape, device=ca.device, dtype=ca.dtype)
+                    ca = torch.cat([ca, ca_pad], dim=1)
+                elif ca_len > max_len:
+                    ca = ca[:, :max_len, ...]
+                d["cross_attn"] = ca
+            except Exception:
+                pass
+        # Harmonize mask length to cross_attn length
+        if ca is not None:
+            ca_len = int(ca.shape[1])
+            if am is None:
+                am = torch.ones((ca.shape[0], ca_len), device=ca.device, dtype=ca.dtype)
+            try:
+                am_len = int(am.shape[-1] if am.dim() == 2 else am.shape[1])
+                if am_len < ca_len:
+                    pad = ca_len - am_len
+                    pad_shape = list(am.shape)
+                    pad_shape[-1] = pad
+                    pad_tensor = torch.zeros(pad_shape, device=am.device, dtype=am.dtype)
+                    am = torch.cat([am, pad_tensor], dim=-1)
+                elif am_len > ca_len:
+                    am = am[..., :ca_len]
+                d["attention_mask"] = am
+                try:
+                    d["num_tokens"] = int(torch.count_nonzero(am, dim=-1).max().item())
+                except Exception:
+                    d["num_tokens"] = ca_len
+            except Exception:
+                pass
+        fixed.append(d)
+    return fixed
+
+
+def _summarize_conds(label, conds):
+    out = []
+    if isinstance(conds, list):
+        for idx, c in enumerate(conds):
+            try:
+                ca = c.get("cross_attn", None) if isinstance(c, dict) else None
+                am = c.get("attention_mask", None) if isinstance(c, dict) else None
+                out.append(f"{label}[{idx}]: ca={None if ca is None else list(ca.shape)}, am={None if am is None else list(am.shape)}")
+            except Exception:
+                pass
+    return "; ".join(out)
+
+
 def safe_encode(vae, img, tile=512, ovlp=64):
     import math, torch.nn.functional as F
     h, w = img.shape[1:3]
@@ -2309,6 +2437,13 @@ class ComfyAdaptiveDetailEnhancer25:
         except Exception:
             pass
 
+        # Align latent channels to VAE/model (e.g., Z_image/FLUX use 16ch latents)
+        latent = _match_latent_channels(vae, latent, model)
+
+        # Harmonize cond token lengths to prevent rare MGHybrid size mismatches
+        positive = _harmonize_cond_tokens(positive)
+        negative = _harmonize_cond_tokens(negative)
+
         image = safe_decode(vae, latent, to_fp32=bool(vae_decode_fp32))
         # allow user cancel right after initial decode
         model_management.throw_exception_if_processing_interrupted()
@@ -2830,6 +2965,7 @@ class ComfyAdaptiveDetailEnhancer25:
             )
             # Prepare latent + noise like in MG_ZeSmartSampler
            lat_img = current_latent["samples"]
+            lat_img = _match_latent_channels(vae, {"samples": lat_img}, sampler_model)["samples"]
             lat_img = _sample.fix_empty_latent_channels(sampler_model, lat_img)
             batch_inds = current_latent.get("batch_index", None)
             noise = _sample.prepare_noise(lat_img, int(iter_seed), batch_inds)
@@ -2848,6 +2984,16 @@ class ComfyAdaptiveDetailEnhancer25:
                 current_latent = {**current_latent}
                 current_latent["samples"] = samples
             except Exception as e:
+                try:
+                    print(f"[CADE2.5][MGHybrid][debug] sigmas={list(sigmas.shape)} lat={list(current_latent['samples'].shape)}")
+                    print(_summarize_conds("pos", positive))
+                    print(_summarize_conds("neg", negative))
+                except Exception:
+                    pass
+                try:
+                    traceback.print_exc()
+                except Exception:
+                    pass
                 # Before any fallback, propagate user cancel if set
                 try:
                     model_management.throw_exception_if_processing_interrupted()
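
As a quick aside, the padding done by _harmonize_cond_tokens above can be illustrated with dummy tensors for the 499-vs-528 case mentioned in its docstring (the shapes and the 4096 feature width are placeholders, not the node's real cond format): the shorter cross_attn is zero-padded along the token axis and the padded positions are zeroed in the mask.

import torch

# Dummy conds standing in for {"cross_attn": ..., "attention_mask": ...} entries.
ca_short = torch.randn(1, 499, 4096)
ca_long = torch.randn(1, 528, 4096)
max_len = max(ca_short.shape[1], ca_long.shape[1])

pad = max_len - ca_short.shape[1]
ca_short = torch.cat([ca_short, torch.zeros(1, pad, ca_short.shape[-1])], dim=1)
mask_short = torch.cat([torch.ones(1, 499), torch.zeros(1, pad)], dim=-1)  # padded tokens masked out

print(ca_short.shape[1], ca_long.shape[1], int(mask_short.sum().item()))  # 528 528 499
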
mod/hard/mg_cade25.py
CHANGED
@@ -11,6 +11,7 @@ import torch
 import os
 import numpy as np
 import torch.nn.functional as F
+import traceback
 
 import nodes
 import comfy.model_management as model_management
mod/hard/mg_zesmart_sampler_v1_1.py
CHANGED
@@ -33,7 +33,15 @@ def _build_hybrid_sigmas(model, steps: int, base_sampler: str, mode: str,
     sig_k = _samplers.calculate_sigmas(ms, "karras", steps)
     sig_b = _samplers.calculate_sigmas(ms, "beta", steps)
 
+    def _align_len(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Align two sigma schedules to the same length (use tail of longer)."""
+        if a.shape[0] == b.shape[0]:
+            return a, b
+        m = min(a.shape[0], b.shape[0])
+        return a[-m:], b[-m:]
+
     mode = str(mode).lower()
+    sig_k, sig_b = _align_len(sig_k, sig_b)
     if mode == "karras":
         sig = sig_k
     elif mode == "beta":
@@ -54,6 +62,7 @@ def _build_hybrid_sigmas(model, steps: int, base_sampler: str, mode: str,
         new_steps = max(1, int(steps / max(1e-6, float(denoise))))
         sk = _samplers.calculate_sigmas(ms, "karras", new_steps)
         sb = _samplers.calculate_sigmas(ms, "beta", new_steps)
+        sk, sb = _align_len(sk, sb)
         if mode == "karras":
             sig_full = sk
         elif mode == "beta":
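
A toy check of the tail-alignment idea in _align_len: when two schedules differ in length, keeping the tail of the longer one preserves the final low-sigma steps (ComfyUI schedules end at 0.0), so the two can be blended element by element. The numbers below are made up for illustration.

import torch

sig_k = torch.tensor([14.6, 8.1, 4.0, 1.9, 0.7, 0.0])  # 6 entries (illustrative values)
sig_b = torch.tensor([12.0, 5.5, 2.3, 0.9, 0.0])        # 5 entries

m = min(sig_k.shape[0], sig_b.shape[0])
sig_k, sig_b = sig_k[-m:], sig_b[-m:]                    # same tail slice as _align_len

print(sig_k.tolist(), sig_b.tolist())  # equal lengths, both still ending at 0.0
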
mod/mg_combinode.py
CHANGED
@@ -275,13 +275,30 @@ class MagicNodesCombiNode:
         pos_text_expanded = _norm_prompt(_expand_dynamic(positive_prompt, int(dyn_seed), bool(dynamic_break_freeze)) if bool(dynamic_pos) else positive_prompt)
         neg_text_expanded = _norm_prompt(_expand_dynamic(negative_prompt, int(dyn_seed), bool(dynamic_break_freeze)) if bool(dynamic_neg) else negative_prompt)
 
+        def _valid_vae(v):
+            try:
+                return (v is not None) and (getattr(v, "first_stage_model", None) is not None)
+            except Exception:
+                return False
+
         if use_checkpoint and checkpoint:
             checkpoint_path = folder_paths.get_full_path_or_raise("checkpoints", checkpoint)
             _unload_old_checkpoint(checkpoint_path)
             base_model, base_clip, vae = _load_checkpoint(checkpoint_path)
             model = base_model.clone()
-
-
+            # Some flow/DiT style checkpoints (e.g., Z_image) ship without CLIP/VAE.
+            clip_source = base_clip or clip_in
+            if clip_source is None:
+                raise Exception("Checkpoint has no CLIP. Connect a CLIP input node or use a checkpoint that bundles CLIP.")
+            clip = clip_source.clone()
+            clip_clean = clip_source.clone()  # keep pristine CLIP for standard pipeline path
+            # Prefer external VAE when provided; some FLOW/DiT checkpoints return an invalid stub VAE.
+            for candidate in (vae_in, vae):
+                if _valid_vae(candidate):
+                    vae = candidate
+                    break
+            else:
+                raise Exception("Checkpoint has no valid VAE. Connect a VAE input node or use a checkpoint that bundles VAE.")
 
         elif model_in and clip_in:
             _unload_old_checkpoint(None)
@@ -289,6 +306,8 @@ class MagicNodesCombiNode:
             clip = clip_in.clone()
             clip_clean = clip_in.clone()
             vae = vae_in
+            if not _valid_vae(vae):
+                raise Exception("VAE input is missing or invalid. Please connect a proper VAE node.")
         else:
             raise Exception("No model selected!")
 
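
For reference, the _valid_vae gate added above only requires a non-None first_stage_model attribute, which rejects both a missing input and the stub VAE that some flow/DiT checkpoints bundle. A tiny sketch with made-up stand-ins (not real ComfyUI objects):

class StubVAE:
    # Mimics an invalid checkpoint-bundled VAE: the attribute exists but is None.
    first_stage_model = None

class RealVAE:
    first_stage_model = object()

def _valid_vae(v):
    try:
        return (v is not None) and (getattr(v, "first_stage_model", None) is not None)
    except Exception:
        return False

# Mirrors the node's preference order: external vae_in first, then the checkpoint VAE.
for candidate in (None, StubVAE(), RealVAE()):
    print(type(candidate).__name__, _valid_vae(candidate))  # only RealVAE passes
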