Commit
·
1f61707
1
Parent(s):
dff784d
Initial commit
Browse files- .gitattributes +3 -0
- .python-version +1 -0
- inference_flux_model.py +82 -0
- inference_pixart_custom_redux.py +89 -0
- inference_pixart_flux_redux.py +92 -0
- pyproject.toml +16 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
flux-image-variations-model/** filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
pixart-custom-redux/** filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
pixart-flux-redux/** filter=lfs diff=lfs merge=lfs -text
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.9
|
inference_flux_model.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import time
|
| 3 |
+
import torch
|
| 4 |
+
from diffusers import FluxTransformer2DModel
|
| 5 |
+
from transformers import CLIPModel
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from open_flux_pipeline import FluxWithCFGPipeline
|
| 9 |
+
|
| 10 |
+
pipe = None
|
| 11 |
+
|
| 12 |
+
def generate(prompt, image_prompt=None, guidance_scale=2, num_images=4, resolution=512):
|
| 13 |
+
# Create blank image prompt backgrounds
|
| 14 |
+
image_prompt_kwargs = {
|
| 15 |
+
"image_prompt": Image.new("RGB", (resolution, resolution)),
|
| 16 |
+
"negative_image_prompt": Image.new("RGB", (resolution, resolution)),
|
| 17 |
+
}
|
| 18 |
+
if image_prompt is not None:
|
| 19 |
+
image_prompt_kwargs["image_prompt"] = image_prompt
|
| 20 |
+
|
| 21 |
+
with torch.no_grad():
|
| 22 |
+
images = pipe(
|
| 23 |
+
prompt=prompt,
|
| 24 |
+
negative_prompt="",
|
| 25 |
+
height=resolution,
|
| 26 |
+
width=resolution,
|
| 27 |
+
max_sequence_length=256,
|
| 28 |
+
guidance_scale=guidance_scale,
|
| 29 |
+
num_images_per_prompt=num_images,
|
| 30 |
+
**image_prompt_kwargs
|
| 31 |
+
).images
|
| 32 |
+
|
| 33 |
+
# Concatenate all images horizontally
|
| 34 |
+
widths, heights = zip(*[img.size for img in images])
|
| 35 |
+
total_width = sum(widths) + len(images) - 1
|
| 36 |
+
max_height = max(heights)
|
| 37 |
+
out = Image.new('RGB', (total_width, max_height))
|
| 38 |
+
x_offset = 0
|
| 39 |
+
for img in images:
|
| 40 |
+
out.paste(img, (x_offset, 0))
|
| 41 |
+
x_offset += img.width + 1
|
| 42 |
+
|
| 43 |
+
# If an image prompt was provided, stack it above the generated images
|
| 44 |
+
if image_prompt is not None:
|
| 45 |
+
out_with_image_prompt = Image.new('RGB', (out.width, out.height + 1 + resolution))
|
| 46 |
+
resized_prompt = image_prompt.resize((resolution, resolution), Image.Resampling.BILINEAR)
|
| 47 |
+
out_with_image_prompt.paste(resized_prompt, (0, 0))
|
| 48 |
+
out_with_image_prompt.paste(out, (0, resolution + 1))
|
| 49 |
+
out = out_with_image_prompt
|
| 50 |
+
|
| 51 |
+
# Ensure the output directory exists and save the final image
|
| 52 |
+
Path("image-outputs").mkdir(parents=True, exist_ok=True)
|
| 53 |
+
output_filename = f"image-outputs/{prompt[:40].replace(' ', '_')}.{int(time.time())}.png"
|
| 54 |
+
out.save(output_filename)
|
| 55 |
+
print(f"Saved output to {output_filename}")
|
| 56 |
+
|
| 57 |
+
def main():
|
| 58 |
+
parser = argparse.ArgumentParser(description="Generate images using an image and a text prompt (Flux Image Variations).")
|
| 59 |
+
parser.add_argument("--prompt", type=str, default="", help='The text prompt for image generation (default "")')
|
| 60 |
+
parser.add_argument("--image_prompt", type=str, default=None,
|
| 61 |
+
help="Path to an optional image to use as a prompt")
|
| 62 |
+
parser.add_argument("--guidance_scale", type=float, default=2,
|
| 63 |
+
help="Guidance scale for image generation (default: 2)")
|
| 64 |
+
parser.add_argument("--num_images", type=int, default=4,
|
| 65 |
+
help="Number of images to generate (default: 4)")
|
| 66 |
+
parser.add_argument("--resolution", type=int, default=512,
|
| 67 |
+
help="Resolution for generated images (default: 512)")
|
| 68 |
+
args = parser.parse_args()
|
| 69 |
+
|
| 70 |
+
# Load models and pipelines
|
| 71 |
+
global pipe
|
| 72 |
+
clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16)
|
| 73 |
+
pipe = FluxWithCFGPipeline.from_pretrained("ostris/OpenFLUX.1", text_encoder=clip, transformer=None, torch_dtype=torch.bfloat16)
|
| 74 |
+
pipe.transformer = FluxTransformer2DModel.from_pretrained("flux-image-variations-model", torch_dtype=torch.bfloat16)
|
| 75 |
+
pipe.to("cuda")
|
| 76 |
+
|
| 77 |
+
img_prompt = Image.open(args.image_prompt) if args.image_prompt else None
|
| 78 |
+
generate(args.prompt, image_prompt=img_prompt, guidance_scale=args.guidance_scale,
|
| 79 |
+
num_images=args.num_images, resolution=args.resolution)
|
| 80 |
+
|
| 81 |
+
if __name__ == "__main__":
|
| 82 |
+
main()
|
inference_pixart_custom_redux.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import time
|
| 3 |
+
import torch
|
| 4 |
+
from diffusers import PixArtAlphaPipeline
|
| 5 |
+
from diffusers.pipelines.flux import FluxPriorReduxPipeline
|
| 6 |
+
from diffusers.pipelines.flux.modeling_flux import ReduxImageEncoder
|
| 7 |
+
from transformers import SiglipImageProcessor
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
pipe = None
|
| 12 |
+
redux = None
|
| 13 |
+
redux_embedder = None
|
| 14 |
+
|
| 15 |
+
def generate(prompt, image_prompt=None, guidance_scale=2, num_images=4, resolution=512):
|
| 16 |
+
with torch.no_grad():
|
| 17 |
+
clip_image_processor = SiglipImageProcessor(size={"height": 384, "width": 384})
|
| 18 |
+
clip_pixel_values = clip_image_processor.preprocess(
|
| 19 |
+
image_prompt.convert("RGB"), return_tensors="pt"
|
| 20 |
+
).pixel_values.to("cuda", dtype=torch.bfloat16)
|
| 21 |
+
|
| 22 |
+
image_prompt_latents = redux.image_encoder(clip_pixel_values).last_hidden_state
|
| 23 |
+
image_prompt_embeds = redux_embedder(image_prompt_latents).image_embeds
|
| 24 |
+
prompt_embeds = image_prompt_embeds[:, :120, :]
|
| 25 |
+
attention_mask = torch.ones(prompt_embeds.shape[0], prompt_embeds.shape[1]).to("cuda")
|
| 26 |
+
|
| 27 |
+
images = pipe(
|
| 28 |
+
prompt_embeds=prompt_embeds,
|
| 29 |
+
prompt_attention_mask=attention_mask,
|
| 30 |
+
negative_prompt="",
|
| 31 |
+
height=resolution,
|
| 32 |
+
width=resolution,
|
| 33 |
+
guidance_scale=guidance_scale,
|
| 34 |
+
num_images_per_prompt=num_images,
|
| 35 |
+
).images
|
| 36 |
+
|
| 37 |
+
# Concatenate all images horizontally
|
| 38 |
+
widths, heights = zip(*[img.size for img in images])
|
| 39 |
+
total_width = sum(widths) + len(images) - 1
|
| 40 |
+
max_height = max(heights)
|
| 41 |
+
out = Image.new('RGB', (total_width, max_height))
|
| 42 |
+
x_offset = 0
|
| 43 |
+
for img in images:
|
| 44 |
+
out.paste(img, (x_offset, 0))
|
| 45 |
+
x_offset += img.width + 1
|
| 46 |
+
|
| 47 |
+
# If an image prompt was provided, stack it above the generated images
|
| 48 |
+
if image_prompt is not None:
|
| 49 |
+
out_with_image_prompt = Image.new('RGB', (out.width, out.height + 1 + resolution))
|
| 50 |
+
resized_prompt = image_prompt.resize((resolution, resolution), Image.Resampling.BILINEAR)
|
| 51 |
+
out_with_image_prompt.paste(resized_prompt, (0, 0))
|
| 52 |
+
out_with_image_prompt.paste(out, (0, resolution + 1))
|
| 53 |
+
out = out_with_image_prompt
|
| 54 |
+
|
| 55 |
+
Path("image-outputs").mkdir(parents=True, exist_ok=True)
|
| 56 |
+
output_filename = f"image-outputs/{prompt[:40].replace(' ', '_')}.{int(time.time())}.png"
|
| 57 |
+
out.save(output_filename)
|
| 58 |
+
print(f"Saved output to {output_filename}")
|
| 59 |
+
|
| 60 |
+
def main():
|
| 61 |
+
parser = argparse.ArgumentParser(
|
| 62 |
+
description="Generate images using an image and a text prompt (PixArt Custom Redux)."
|
| 63 |
+
)
|
| 64 |
+
parser.add_argument("--prompt", type=str, default="",
|
| 65 |
+
help='The text prompt for image generation (default: "")')
|
| 66 |
+
parser.add_argument("--image_prompt", type=str, default=None,
|
| 67 |
+
help="Path to an optional image to use as a prompt")
|
| 68 |
+
parser.add_argument("--guidance_scale", type=float, default=2,
|
| 69 |
+
help="Guidance scale for image generation (default: 2)")
|
| 70 |
+
parser.add_argument("--num_images", type=int, default=4,
|
| 71 |
+
help="Number of images to generate (default: 4)")
|
| 72 |
+
parser.add_argument("--resolution", type=int, default=512,
|
| 73 |
+
help="Resolution for generated images (default: 512)")
|
| 74 |
+
args = parser.parse_args()
|
| 75 |
+
|
| 76 |
+
global pipe, redux, redux_embedder
|
| 77 |
+
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-512x512", torch_dtype=torch.bfloat16)
|
| 78 |
+
redux_embedder = ReduxImageEncoder.from_pretrained("pixart-custom-redux", torch_dtype=torch.bfloat16)
|
| 79 |
+
redux = FluxPriorReduxPipeline.from_pretrained("FLUX.1-Redux-dev", image_embedder=redux_embedder, torch_dtype=torch.bfloat16)
|
| 80 |
+
|
| 81 |
+
pipe.to("cuda")
|
| 82 |
+
redux.to("cuda")
|
| 83 |
+
|
| 84 |
+
img_prompt = Image.open(args.image_prompt) if args.image_prompt else None
|
| 85 |
+
generate(args.prompt, image_prompt=img_prompt, guidance_scale=args.guidance_scale,
|
| 86 |
+
num_images=args.num_images, resolution=args.resolution)
|
| 87 |
+
|
| 88 |
+
if __name__ == "__main__":
|
| 89 |
+
main()
|
inference_pixart_flux_redux.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import time
|
| 3 |
+
import torch
|
| 4 |
+
from diffusers import PixArtAlphaPipeline, PixArtTransformer2DModel
|
| 5 |
+
from diffusers.pipelines.flux import FluxPriorReduxPipeline
|
| 6 |
+
from transformers import SiglipImageProcessor
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
pipe = None
|
| 11 |
+
redux = None
|
| 12 |
+
redux_embedder = None
|
| 13 |
+
|
| 14 |
+
def generate(prompt, image_prompt=None, guidance_scale=2, num_images=4, resolution=512):
|
| 15 |
+
with torch.no_grad():
|
| 16 |
+
clip_image_processor = SiglipImageProcessor(size={"height": 384, "width": 384})
|
| 17 |
+
clip_pixel_values = clip_image_processor.preprocess(
|
| 18 |
+
image_prompt.convert("RGB"), return_tensors="pt"
|
| 19 |
+
).pixel_values.to("cuda", dtype=torch.bfloat16)
|
| 20 |
+
|
| 21 |
+
image_prompt_latents = redux.image_encoder(clip_pixel_values).last_hidden_state
|
| 22 |
+
image_prompt_embeds = redux_embedder(image_prompt_latents).image_embeds
|
| 23 |
+
prompt_embeds = image_prompt_embeds[:, :120, :] * 0.04
|
| 24 |
+
attention_mask = torch.ones(prompt_embeds.shape[0], prompt_embeds.shape[1]).to("cuda")
|
| 25 |
+
|
| 26 |
+
images = pipe(
|
| 27 |
+
prompt_embeds=prompt_embeds,
|
| 28 |
+
prompt_attention_mask=attention_mask,
|
| 29 |
+
negative_prompt="",
|
| 30 |
+
height=resolution,
|
| 31 |
+
width=resolution,
|
| 32 |
+
guidance_scale=guidance_scale,
|
| 33 |
+
num_images_per_prompt=num_images,
|
| 34 |
+
).images
|
| 35 |
+
|
| 36 |
+
# Concatenate all images horizontally
|
| 37 |
+
widths, heights = zip(*[img.size for img in images])
|
| 38 |
+
total_width = sum(widths) + len(images) - 1
|
| 39 |
+
max_height = max(heights)
|
| 40 |
+
out = Image.new('RGB', (total_width, max_height))
|
| 41 |
+
x_offset = 0
|
| 42 |
+
for img in images:
|
| 43 |
+
out.paste(img, (x_offset, 0))
|
| 44 |
+
x_offset += img.width + 1
|
| 45 |
+
|
| 46 |
+
# If an image prompt was provided, stack it above the generated images
|
| 47 |
+
if image_prompt is not None:
|
| 48 |
+
out_with_image_prompt = Image.new('RGB', (out.width, out.height + 1 + resolution))
|
| 49 |
+
resized_prompt = image_prompt.resize((resolution, resolution), Image.Resampling.BILINEAR)
|
| 50 |
+
out_with_image_prompt.paste(resized_prompt, (0, 0))
|
| 51 |
+
out_with_image_prompt.paste(out, (0, resolution + 1))
|
| 52 |
+
out = out_with_image_prompt
|
| 53 |
+
|
| 54 |
+
Path("image-outputs").mkdir(parents=True, exist_ok=True)
|
| 55 |
+
output_filename = f"image-outputs/{prompt[:40].replace(' ', '_')}.{int(time.time())}.png"
|
| 56 |
+
out.save(output_filename)
|
| 57 |
+
print(f"Saved output to {output_filename}")
|
| 58 |
+
|
| 59 |
+
def main():
|
| 60 |
+
parser = argparse.ArgumentParser(
|
| 61 |
+
description="Generate images using an image and a text prompt (PixArt Flux Redux)."
|
| 62 |
+
)
|
| 63 |
+
parser.add_argument("--prompt", type=str, default="",
|
| 64 |
+
help='The text prompt for image generation (default: "")')
|
| 65 |
+
parser.add_argument("--image_prompt", type=str, default=None,
|
| 66 |
+
help="Path to an optional image to use as a prompt")
|
| 67 |
+
parser.add_argument("--guidance_scale", type=float, default=2,
|
| 68 |
+
help="Guidance scale for image generation (default: 2)")
|
| 69 |
+
parser.add_argument("--num_images", type=int, default=4,
|
| 70 |
+
help="Number of images to generate (default: 4)")
|
| 71 |
+
parser.add_argument("--resolution", type=int, default=512,
|
| 72 |
+
help="Resolution for generated images (default: 512)")
|
| 73 |
+
args = parser.parse_args()
|
| 74 |
+
|
| 75 |
+
global pipe, redux, redux_embedder
|
| 76 |
+
pipe = PixArtAlphaPipeline.from_pretrained(
|
| 77 |
+
"PixArt-alpha/PixArt-XL-2-512x512", transformer=None, torch_dtype=torch.bfloat16
|
| 78 |
+
)
|
| 79 |
+
transformer = PixArtTransformer2DModel.from_pretrained("pixart-flux-redux", torch_dtype=torch.bfloat16)
|
| 80 |
+
pipe.transformer = transformer
|
| 81 |
+
redux = FluxPriorReduxPipeline.from_pretrained("FLUX.1-Redux-dev", torch_dtype=torch.bfloat16)
|
| 82 |
+
redux_embedder = redux.image_embedder
|
| 83 |
+
|
| 84 |
+
redux.to("cuda")
|
| 85 |
+
pipe.to("cuda")
|
| 86 |
+
|
| 87 |
+
img_prompt = Image.open(args.image_prompt) if args.image_prompt else None
|
| 88 |
+
generate(args.prompt, image_prompt=img_prompt, guidance_scale=args.guidance_scale,
|
| 89 |
+
num_images=args.num_images, resolution=args.resolution)
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
|
| 92 |
+
main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "image-variations-experiment"
|
| 3 |
+
description = "Experimental Flux/PixArt finetunes for image variations"
|
| 4 |
+
version = "0.1.0"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.9"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"accelerate>=1.4.0",
|
| 9 |
+
"diffusers>=0.32.2",
|
| 10 |
+
"pillow>=11.1.0",
|
| 11 |
+
"protobuf>=5.29.3",
|
| 12 |
+
"sentencepiece>=0.2.0",
|
| 13 |
+
"torch==2.5.1",
|
| 14 |
+
"torchvision==0.20.1",
|
| 15 |
+
"transformers==4.46.1",
|
| 16 |
+
]
|