File size: 3,629 Bytes
b4e0431
 
 
 
 
 
 
 
a81b4f4
 
b4e0431
 
 
 
 
 
9633e6a
b4e0431
d7a843d
b4e0431
 
 
 
 
 
 
 
4a86ebc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7a843d
b4e0431
9633e6a
 
 
fb21251
9633e6a
 
 
 
 
4a86ebc
 
9633e6a
 
 
 
 
 
 
 
b4e0431
d7a843d
b4e0431
37b2b46
d6b8beb
b4e0431
 
 
d7a843d
b4e0431
 
 
 
 
d7a843d
b4e0431
9633e6a
b4e0431
d7a843d
b4e0431
b8d4769
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gradio as gr
import torch, torchvision
import torch.nn.functional as F
import numpy as np
from PIL import Image, ImageColor
from diffusers import DDPMPipeline
from diffusers import DDIMScheduler

# Choose where to run: Apple MPS first, then CUDA, falling back to the CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Pretrained DDPM pipeline, loaded by repo id and moved onto the chosen device.
pipeline_name = 'johnowhitaker/sd-class-wikiart-from-bedrooms'
image_pipe = DDPMPipeline.from_pretrained(pipeline_name)
image_pipe = image_pipe.to(device)

# DDIM scheduler loaded from the same repo id; sampling below walks 20 timesteps.
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=20)

# The guidance function
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Return the mean absolute distance between every pixel and a target colour.

    ``target_color`` is an (R, G, B) triple on a 0-1 scale. It is rescaled to
    -1..1 before the comparison, so ``images`` are expected in that range,
    laid out as (b, c, h, w). The default (0.1, 0.9, 0.5) is a light teal.
    """
    # 0..1 -> -1..1, placed on the same device as the images.
    rescaled = torch.tensor(target_color).to(images.device) * 2 - 1
    # (c,) -> (1, c, 1, 1) so it broadcasts against (b, c, h, w).
    broadcastable = rescaled[None, :, None, None]
    return (images - broadcastable).abs().mean()

def monochromatic_loss(images, threshold=0.5, target_value=0.01, value_range=(-1.0, 1.0)):
    """Return a loss that pushes every pixel's grey level towards black or white.

    The grey level is the mean of the first (up to) three channels. Each pixel
    is scored by its absolute distance to the nearer of two targets, a
    "black" and a "white" target. The loss is the mean of that score over all
    pixels, so minimising it favours high-contrast, monochrome images.

    Args:
        images: tensor laid out as (N, C, H, W).
        threshold: unused; kept only so existing callers passing it still work.
        target_value: how far the targets sit from the extremes, on a 0-1
            scale. Black is ``target_value`` and white is
            ``1 - target_value`` (0.01 / 0.99 by default).
        value_range: ``(low, high)`` pixel range of ``images``. The default
            (-1, 1) matches the samples produced during generation (and the
            range ``color_loss`` assumes). Pass ``(0.0, 1.0)`` for images
            already on a 0-1 scale.

    Returns:
        Scalar tensor, differentiable with respect to ``images``.
    """
    low, high = value_range
    span = high - low
    # The targets are expressed on a 0-1 scale; map them into the images' own
    # range. Without this, -1..1 images would be pulled towards 0.01 (mid-grey)
    # instead of black.
    black = low + target_value * span
    white = low + (1.0 - target_value) * span

    grey = images[:, :3].mean(dim=1)
    # Per pixel, keep the deviation from whichever extreme is closer.
    deviation = torch.minimum((grey - black).abs(), (grey - white).abs())
    return deviation.mean()

# And the core function to generate an image given the relevant inputs
def generate(color, guidance_loss_scale, guidance="monochromatic"):
    """Sample one 256x256 image with the DDIM scheduler, steered by a guidance loss.

    At every timestep the predicted clean image ``x0`` is scored by a guidance
    loss. The latent is then moved against the loss gradient before the
    normal scheduler step is taken.

    Args:
        color: colour string parsed with ``PIL.ImageColor.getcolor`` (the UI
            sends a hex string such as ``'#55ffaa'``). Only used when
            ``guidance == "color"``.
        guidance_loss_scale: multiplier on the guidance loss (the UI slider).
            Larger values steer harder; 0 disables guidance.
        guidance: ``"monochromatic"`` (default; pushes towards black/white via
            ``monochromatic_loss``) or ``"color"`` (pulls towards ``color`` via
            ``color_loss``).

    Returns:
        The final sample as a ``PIL.Image``. It is also saved to
        ``'test.jpeg'`` in the working directory, overwriting that file on
        every call.

    Raises:
        ValueError: if ``guidance`` is not one of the supported modes.
    """
    if guidance == "monochromatic":
        guidance_loss = monochromatic_loss
    elif guidance == "color":
        # 0-255 RGB ints -> the 0-1 scale that color_loss expects.
        target_color = [channel / 255 for channel in ImageColor.getcolor(color, "RGB")]

        def guidance_loss(images):
            return color_loss(images, target_color)
    else:
        raise ValueError(
            f"unknown guidance mode {guidance!r}; expected 'monochromatic' or 'color'"
        )

    x = torch.randn(1, 3, 256, 256).to(device)
    for t in scheduler.timesteps:
        model_input = scheduler.scale_model_input(x, t)
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]
        # Re-enable autograd on the latent so the loss on x0 can be
        # differentiated with respect to it.
        x = x.detach().requires_grad_()
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample
        # Previously the slider value was never applied; scale the loss here.
        loss = guidance_loss(x0) * guidance_loss_scale
        cond_grad = -torch.autograd.grad(loss, x)[0]
        # Nudge the latent downhill on the loss, then take the regular step.
        x = x.detach() + cond_grad
        x = scheduler.step(noise_pred, t, x).prev_sample

    grid = torchvision.utils.make_grid(x, nrow=4)
    # -1..1 -> 0..1 -> 0..255 uint8 in (H, W, C) order for PIL.
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    im = Image.fromarray(np.array(im * 255).astype(np.uint8))
    im.save('test.jpeg')
    return im

# Widgets feeding generate(color, guidance_loss_scale), in that order.
# See the gradio docs for the other input/output component types.
inputs = [
    gr.ColorPicker(label="color", value='#55ffaa'),
    gr.Slider(label="guidance_scale", minimum=0, maximum=30, value=3),
]
outputs = gr.Image(label="result")

# Starter rows shown under the form: [color, guidance_scale].
example_rows = [
    ["#BB2266", 3],
    ["#44CCAA", 5],
]

# Minimal interface wiring the widgets to the generator.
demo = gr.Interface(fn=generate, inputs=inputs, outputs=outputs, examples=example_rows)

# Only start the web server when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()