import gradio as gr
import torch, torchvision
import torch.nn.functional as F
import numpy as np
from PIL import Image, ImageColor
from diffusers import DDPMPipeline
from diffusers import DDIMScheduler
# Pick the best available compute backend: Apple MPS, then CUDA, else CPU.
device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the pretrained pipeline (model id suggests wikiart fine-tuned from bedrooms).
pipeline_name = 'johnowhitaker/sd-class-wikiart-from-bedrooms'
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)
# Set up a DDIM scheduler so sampling needs only a small number of steps (20 here).
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=20)
# The guidance function
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Return the mean absolute distance of all pixels from a target color.

    The target is given as (R, G, B) in (0, 1) and is remapped to the
    (-1, 1) range the model's images live in. Defaults to a light teal.
    """
    # Remap the target to (-1, 1), then shape it (1, C, 1, 1) so it
    # broadcasts against (batch, channel, height, width) image tensors.
    color = torch.tensor(target_color).to(images.device) * 2 - 1
    color = color.reshape(1, -1, 1, 1)
    return torch.abs(images - color).mean()
def monochromatic_loss(images, threshold=0.5, target_value=0.01):
    """Loss that pushes every pixel towards pure black or pure white.

    Grayscale is the plain channel average of an (N, 3, H, W) batch. Each
    pixel is scored by its distance to the nearer of two anchors —
    ``target_value`` (near-black) and ``1 - target_value`` (near-white) —
    and the mean distance is returned, so minimizing it drives the image
    towards high contrast.

    NOTE(review): the anchors assume pixel values in (0, 1), but the
    denoised samples in this pipeline appear to live in (-1, 1) — confirm
    this is intended. ``threshold`` is currently unused.
    """
    gray = images.mean(dim=1)  # collapse RGB by simple averaging
    # Distance from each pixel to each anchor, then keep the closer one.
    to_black = (gray - target_value).abs()
    to_white = (gray - (1.0 - target_value)).abs()
    return torch.minimum(to_black, to_white).mean()
# And the core function to generate an image given the relevant inputs
def generate(color, guidance_loss_scale):
    """Sample one 256x256 image, nudging each denoising step with a loss.

    Args:
        color: hex color string from the gradio ColorPicker (e.g. "#55ffaa").
        guidance_loss_scale: strength of the gradient nudge per step.

    Returns:
        The final PIL Image (also saved to 'test.jpeg' as a side effect).
    """
    target_color = ImageColor.getcolor(color, "RGB")  # Target color as RGB
    target_color = [a/255 for a in target_color]  # Rescale from (0, 255) to (0, 1)
    x = torch.randn(1, 3, 256, 256).to(device)  # Start from pure noise
    for i, t in enumerate(scheduler.timesteps):
        model_input = scheduler.scale_model_input(x, t)
        # The UNet forward pass itself needs no gradients; we only
        # differentiate the guidance loss w.r.t. x below.
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]
        x = x.detach().requires_grad_()
        # Predicted fully-denoised image at this step — the guidance target.
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample
        # loss = color_loss(x0, target_color) * guidance_loss_scale
        # Fix: the UI slider value was previously ignored — scale the
        # active loss so guidance_loss_scale actually has an effect.
        loss = monochromatic_loss(x0) * guidance_loss_scale
        cond_grad = -torch.autograd.grad(loss, x)[0]
        x = x.detach() + cond_grad  # Nudge x down the loss gradient
        x = scheduler.step(noise_pred, t, x).prev_sample
    # Map from the model's (-1, 1) range to (0, 1) and convert to PIL.
    grid = torchvision.utils.make_grid(x, nrow=4)
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1)*0.5 + 0.5
    im = Image.fromarray(np.array(im*255).astype(np.uint8))
    im.save('test.jpeg')
    return im
# See the gradio docs for the types of inputs and outputs available.
# These widgets map 1:1 onto generate()'s positional parameters.
inputs = [
    gr.ColorPicker(label="color", value='#55ffaa'), # Add any inputs you need here
    gr.Slider(label="guidance_scale", minimum=0, maximum=30, value=3)
]
# The PIL image returned by generate() is rendered here.
outputs = gr.Image(label="result")
# Setting up a minimal interface to our function:
demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    examples=[
        ["#BB2266", 3],["#44CCAA", 5] # You can provide some example inputs to get people started
    ],
)
# And launching (only when run as a script, not when imported)
if __name__ == "__main__":
    demo.launch()