import gc
import math
import sys
#from IPython import display
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision import utils as tv_utils
from torchvision.transforms import functional as TF
import gradio as gr
from git.repo.base import Repo
from os.path import exists as path_exists

# Clone the v-diffusion and CLIP repos on first run so their modules can be imported.
if not path_exists("v-diffusion-pytorch"):
    Repo.clone_from("https://github.com/crowsonkb/v-diffusion-pytorch", "v-diffusion-pytorch")
if not path_exists("CLIP"):
    Repo.clone_from("https://github.com/openai/CLIP", "CLIP")
sys.path.append('v-diffusion-pytorch')

from huggingface_hub import hf_hub_download
from CLIP import clip
from diffusion import get_model, sampling, utils

class MakeCutouts(nn.Module):
    """Extract `cutn` random square crops from an image batch, resized to `cut_size` for CLIP."""

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutout = F.adaptive_avg_pool2d(cutout, self.cut_size)
            cutouts.append(cutout)
        return torch.cat(cutouts)
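
# MakeCutouts usage sketch (not executed; `cutter` and `crops` are hypothetical names):
# with CLIP's 224px input resolution, MakeCutouts(224, 16) turns a [1, 3, 256, 256] image
# into a [16, 3, 224, 224] batch of random square crops.
#   cutter = MakeCutouts(224, 16)
#   crops = cutter(torch.rand(1, 3, 256, 256))  # -> torch.Size([16, 3, 224, 224])
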
def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
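
# For unit vectors at angle theta, ||x - y|| = 2*sin(theta/2), so this loss equals
# theta**2 / 2, i.e. half the squared angular distance. Illustrative sketch
# (hypothetical variable `a`, not executed):
#   a = F.normalize(torch.randn(1, 512), dim=-1)
#   spherical_dist_loss(a, a)   # ~0 for identical embeddings
#   spherical_dist_loss(a, -a)  # ~math.pi**2 / 2 for opposite embeddings
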
# Download the cc12m_1_cfg checkpoint from the Hugging Face Hub and load it on the GPU.
cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
#cc12m_small_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1.pth")

model = get_model('cc12m_1_cfg')()
_, side_y, side_x = model.shape
model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
model = model.half().cuda().eval().requires_grad_(False)

#model_small = get_model('cc12m_1')()
#model_small.load_state_dict(torch.load(cc12m_small_model, map_location='cpu'))
#model_small = model_small.half().cuda().eval().requires_grad_(False)

# CLIP is used both to embed the text prompt and, optionally, to guide sampling.
clip_model = clip.load(model.clip_model, jit=False, device='cuda')[0]
clip_model.eval().requires_grad_(False)
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])
make_cutouts = MakeCutouts(clip_model.visual.input_resolution, 16, 1.)

gc.collect()
torch.cuda.empty_cache()

def run_all(prompt, steps, n_images, weight, clip_guided):
    gc.collect()
    torch.cuda.empty_cache()
    import random
    seed = int(random.randint(0, 2147483647))
    target_embed = clip_model.encode_text(clip.tokenize(prompt).to('cuda')).float()#.cuda()

    if clip_guided:
        # CLIP guidance is slower, so only one image is generated, with more steps.
        n_images = 1
        steps = steps * 5
        clip_guidance_scale = weight * 100
        prompts = [prompt]
        target_embeds, weights = [], []

        def parse_prompt(prompt):
            # Split an optional ':weight' suffix off the prompt (URLs keep their scheme colon).
            if prompt.startswith('http://') or prompt.startswith('https://'):
                vals = prompt.rsplit(':', 2)
                vals = [vals[0] + ':' + vals[1], *vals[2:]]
            else:
                vals = prompt.rsplit(':', 1)
            vals = vals + ['', '1'][len(vals):]
            return vals[0], float(vals[1])

        for prompt in prompts:
            txt, weight = parse_prompt(prompt)
            target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to('cuda')).float())
            weights.append(weight)
        target_embeds = torch.cat(target_embeds)
        weights = torch.tensor(weights, device='cuda')
        if weights.sum().abs() < 1e-3:
            raise RuntimeError('The weights must not sum to 0.')
        weights /= weights.sum().abs()
        clip_embed = F.normalize(target_embeds.mul(weights[:, None]).sum(0, keepdim=True), dim=-1)
        # Note: with a single prompt this overrides the weighted embedding above.
        clip_embed = target_embed.repeat([n_images, 1])

    def cfg_model_fn(x, t):
        """The CFG wrapper function."""
        n = x.shape[0]
        x_in = x.repeat([2, 1, 1, 1])
        t_in = t.repeat([2])
        clip_embed_repeat = target_embed.repeat([n, 1])
        clip_embed_in = torch.cat([torch.zeros_like(clip_embed_repeat), clip_embed_repeat])
        v_uncond, v_cond = model(x_in, t_in, clip_embed_in).chunk(2, dim=0)
        v = v_uncond + (v_cond - v_uncond) * weight
        return v
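
    # Classifier-free guidance: "weight" is the guidance scale. weight=0 gives the
    # unconditional prediction, weight=1 the conditional one, and larger values push
    # the sample further toward the prompt.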

    def make_cond_model_fn(model, cond_fn):
        def cond_model_fn(x, t, **extra_args):
            with torch.enable_grad():
                x = x.detach().requires_grad_()
                v = model(x, t, **extra_args)
                alphas, sigmas = utils.t_to_alpha_sigma(t)
                pred = x * alphas[:, None, None, None] - v * sigmas[:, None, None, None]
                cond_grad = cond_fn(x, t, pred, **extra_args).detach()
                v = v.detach() - cond_grad * (sigmas[:, None, None, None] / alphas[:, None, None, None])
            return v
        return cond_model_fn

    def cond_fn(x, t, pred, clip_embed):
        if min(pred.shape[2:4]) < 256:
            pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)
        clip_in = normalize(make_cutouts((pred + 1) / 2))
        image_embeds = clip_model.encode_image(clip_in).view([16, x.shape[0], -1])
        losses = spherical_dist_loss(image_embeds, clip_embed[None])
        loss = losses.mean(0).sum() * clip_guidance_scale
        grad = -torch.autograd.grad(loss, x)[0]
        return grad
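
    # CLIP guidance: cond_fn scores the current denoised prediction with CLIP and returns
    # the negative gradient of the spherical distance to the prompt embedding;
    # make_cond_model_fn folds that gradient back into v, scaled by sigma/alpha.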

    torch.manual_seed(seed)
    x = torch.randn([n_images, 3, side_y, side_x], device='cuda')
    t = torch.linspace(1, 0, steps + 1, device='cuda')[:-1]
    if model.min_t == 0:
        step_list = utils.get_spliced_ddpm_cosine_schedule(t)
    else:
        step_list = utils.get_ddpm_schedule(t)

    if not clip_guided:
        outs = sampling.plms_sample(cfg_model_fn, x, step_list, {})#, callback=display_callback)
    else:
        extra_args = {'clip_embed': clip_embed}
        cond_fn_ = cond_fn
        model_fn = make_cond_model_fn(model, cond_fn_)
        outs = sampling.plms_sample(model_fn, x, step_list, extra_args)

    images_out = []
    for out in outs:
        images_out.append(utils.to_pil_image(out))
    return images_out

##################### START GRADIO HERE ############################
gallery = gr.outputs.Carousel(label="Individual images", components=["image"])
iface = gr.Interface(
    fn=run_all,
    inputs=[
        gr.inputs.Textbox(label="Prompt - try adding increments to your prompt such as 'oil on canvas', 'a painting', 'a book cover'", default="an eerie alien forest"),
        gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate", default=40, maximum=80, minimum=1, step=1),
        gr.inputs.Slider(label="Number of images in parallel", default=2, maximum=4, minimum=1, step=1),
        gr.inputs.Slider(label="Weight - how closely the image should resemble the prompt", default=5, maximum=15, minimum=0, step=1),
        gr.inputs.Checkbox(label="CLIP Guided - improves coherence with complex prompts, but is slower (with CLIP guidance only one image is generated)"),
    ],
    outputs=gallery,
    title="Generate images from text with V-Diffusion",
    description="<div>By typing a prompt and pressing submit you can generate images based on this prompt. <a href='https://github.com/crowsonkb/v-diffusion-pytorch' target='_blank'>V-Diffusion</a> is a diffusion text-to-image model created by <a href='https://twitter.com/RiversHaveWings' target='_blank'>Katherine Crowson</a> and <a href='https://twitter.com/jd_pressman'>JDP</a>, trained on the <a href='https://github.com/google-research-datasets/conceptual-12m'>CC12M dataset</a>. The UI to the model was assembled by <a style='color: rgb(99, 102, 241);font-weight:bold' href='https://twitter.com/multimodalart' target='_blank'>@multimodalart</a>. Keep up with the <a style='color: rgb(99, 102, 241);' href='https://multimodal.art/news' target='_blank'>latest multimodal AI art news here</a> and consider <a style='color: rgb(99, 102, 241);' href='https://www.patreon.com/multimodalart' target='_blank'>supporting us on Patreon</a>.</div>",
)
iface.launch(enable_queue=True)
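
# Minimal local-testing sketch (hypothetical, left commented out so the Gradio app is unchanged):
#   images = run_all("an eerie alien forest", steps=40, n_images=2, weight=5, clip_guided=False)
#   images[0].save("sample.png")  # run_all returns a list of PIL images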