FLUX.1-dev-base / app.py
cbensimon's picture
cbensimon HF Staff
Disable optimizations
c549a9b verified
raw
history blame contribute delete
869 Bytes
import time
from datetime import datetime

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline

from optimization import optimize_pipeline_
# Load FLUX.1-dev once at module import, in bfloat16, and move it to the GPU;
# every call to generate_image reuses this single module-level pipeline.
pipeline = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    torch_dtype=torch.bfloat16,
).to('cuda')

# The optimization pass from the local `optimization` module is currently
# switched off; uncomment the line below to apply it to `pipeline` again.
# optimize_pipeline_(pipeline, "prompt")
@spaces.GPU
def generate_image(prompt: str, progress=gr.Progress(track_tqdm=True)):
    """Generate one image for *prompt* with the FLUX pipeline and time it.

    The random seed is fixed at 42, so a given prompt always produces the same
    image (presumably so that outputs and timings stay comparable across runs).

    Args:
        prompt: Text prompt passed to the pipeline.
        progress: Gradio progress tracker. With ``track_tqdm=True`` Gradio
            mirrors tqdm progress bars in the UI; it is not referenced
            directly in this body.

    Returns:
        A one-element list ``[(image, caption)]`` for the ``gr.Gallery``
        output, where *caption* is the elapsed time of the pipeline call
        formatted like ``"12.34s"``.
    """
    # Fixed seed -> deterministic result for the same prompt.
    generator = torch.Generator(device='cuda').manual_seed(42)

    # perf_counter is monotonic; datetime.now() is wall-clock time and can
    # jump (NTP sync, DST change), which would corrupt the measured duration.
    start = time.perf_counter()
    output = pipeline(
        prompt=prompt,
        num_inference_steps=28,
        generator=generator,
    )
    elapsed = time.perf_counter() - start

    return [(output.images[0], f'{elapsed:.2f}s')]
# Build the UI: one prompt textbox in, a gallery out (each entry is the
# (image, caption) pair returned by generate_image). The example prompt is
# not run ahead of time because cache_examples=False.
demo = gr.Interface(
    fn=generate_image,
    inputs=gr.Text(label="Prompt"),
    outputs=gr.Gallery(),
    examples=["A cat playing with a ball of yarn"],
    cache_examples=False,
)

# Start the Gradio server.
demo.launch()