import os
import sys
from io import BytesIO

import gradio as gr
import numpy as np
import torch
from PIL import Image

from diffusers import ControlNetModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import load_image

from pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
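
# Device and model setup: load the BLIP-Diffusion ControlNet pipeline once at import time
# and replace its ControlNet with the Stable Diffusion 1.5 inpainting ControlNet.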
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
    "Salesforce/blipdiffusion-controlnet"
)
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint")

# Swap the pipeline's default ControlNet for the inpainting ControlNet.
blip_diffusion_pipe.controlnet = controlnet
blip_diffusion_pipe.to(device)
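
# Helper following the diffusers ControlNet inpainting convention: pixels under the
# mask are set to -1.0 so the ControlNet treats them as regions to be filled in.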
def make_inpaint_condition(image, image_mask):
    image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0

    assert image.shape[:2] == image_mask.shape[:2], "image and image_mask must have the same image size"
    # Mark masked pixels with -1 so they are distinguishable from valid pixels in [0, 1].
    image[image_mask > 0.5] = -1.0

    # HWC -> NCHW float tensor, as expected by the ControlNet.
    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return image
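
# CSS tweaks so the upload canvases and the output gallery keep a usable fixed height.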
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
.image_upload{min-height:500px}
.image_upload [data-testid="image"], .image_upload [data-testid="image"] > div{min-height: 500px}
.image_upload [data-testid="target"], .image_upload [data-testid="target"] > div{min-height: 500px}
.image_upload .touch-none{display: flex}
#output_image{min-height:500px;max-height:500px;}
'''
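
# Build the Gradio UI: a source image with a hand-drawn inpainting mask, a target
# (subject) image, a text prompt, and a gallery for the generated results.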
def create_demo():
    HEIGHT, WIDTH = 512, 512

    with gr.Blocks(theme=gr.themes.Default(font=[gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace"],
                                           primary_hue="lime",
                                           secondary_hue="emerald",
                                           neutral_hue="slate",
                                           ),
                   css=css) as demo:
        gr.Markdown('# BLIP-Diffusion')
        with gr.Accordion('Instructions', open=False):
            gr.Markdown('1. Upload a source image and draw a mask over the region to replace')
            gr.Markdown('2. Upload a target image of the subject')
            gr.Markdown('3. Enter the name of the target object and a text prompt')
            gr.Markdown('4. Click `Generate` when everything is ready!')
        with gr.Group():
            with gr.Box():
                with gr.Column():
                    with gr.Row() as main_blocks:
                        # Left column: source image with a sketchable inpainting mask.
                        with gr.Column() as step_1:
                            gr.Markdown('### Source Input and Add Mask')
                            image = gr.Image(source='upload',
                                             shape=[HEIGHT, WIDTH],
                                             type='pil',
                                             elem_classes="image_upload",
                                             label='Source Image',
                                             tool='sketch',
                                             brush_radius=60).style(height=500)
                            src_input = image
                            text_prompt = gr.Textbox(label='Prompt')
                            run_button = gr.Button(value='Generate', variant="primary")
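
                        # Right column: target image that provides the subject to transfer.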
                        with gr.Column() as step_2:
                            gr.Markdown('### Target Input')
                            target = gr.Image(source='upload',
                                              shape=[HEIGHT, WIDTH],
                                              type='pil',
                                              elem_classes="image_upload",
                                              label='Target Image'
                                              ).style(height=500)
                            tgt_input = target
                            style_subject = gr.Textbox(label='Target Object')
                    with gr.Row() as output_blocks:
                        with gr.Column() as output_step:
                            gr.Markdown('### Output')
                            output_image = gr.Gallery(
                                label="Generated images",
                                show_label=False,
                                elem_id="output_image",
                            ).style(height=500, container=True)
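
        # Sampling controls: denoising steps, classifier-free guidance scale, and seed.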
        with gr.Accordion('Advanced options', open=False):
            num_inference_steps = gr.Slider(label='Steps',
                                            minimum=1,
                                            maximum=100,
                                            value=50,
                                            step=1)
            guidance_scale = gr.Slider(label='Text Guidance Scale',
                                       minimum=0.1,
                                       maximum=30.0,
                                       value=7.5,
                                       step=0.1)
            seed = gr.Slider(label='Seed',
                             minimum=-1,
                             maximum=2147483647,
                             step=1,
                             randomize=True)
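
        # Components passed to generate(); the order must match its positional parameters.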
        inputs = [
            src_input,
            tgt_input,
            text_prompt,
            style_subject,
            num_inference_steps,
            guidance_scale,
            seed,
        ]
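
        # generate() builds the inpainting condition from the sketched mask and runs the
        # pipeline. The `image`/`mask_image` keyword arguments follow the local
        # pipeline_controlnet_blip_diffusion module, whose signature is assumed to extend
        # the upstream diffusers BlipDiffusionControlNetPipeline.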
        def generate(src_input,
                     tgt_input,
                     text_prompt,
                     style_subject,
                     num_inference_steps,
                     guidance_scale,
                     seed,
                     ):
            if src_input is None or tgt_input is None:
                raise gr.Error("You must upload a source and a target image first.")

            tgt_subject = style_subject
            # Cast the seed in case the slider delivers a float; manual_seed() requires an int.
            generator = torch.Generator(device="cpu").manual_seed(int(seed))

            # The sketch tool returns a dict holding the uploaded image and the drawn mask.
            init_image = src_input['image']
            cldm_cond_image = src_input['mask']
            control_image = make_inpaint_condition(init_image, cldm_cond_image)
            style_image = tgt_input

            negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, ugly, bad anatomy, bad proportions, deformed, blurry"

            output = blip_diffusion_pipe(
                text_prompt,
                style_image,
                control_image,
                style_subject,
                tgt_subject,
                generator=generator,
                image=init_image,
                mask_image=cldm_cond_image,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                neg_prompt=negative_prompt,
                height=HEIGHT,
                width=WIDTH,
            ).images
            return {output_image: output}
        run_button.click(fn=generate, inputs=inputs, outputs=[output_image])
    return demo


if __name__ == '__main__':
    demo = create_demo()
    # queue() lets long-running generations finish without hitting request timeouts.
    demo.queue().launch()