| | import torch |
| | import torchvision |
| | from PIL import Image |
| | from pathlib import Path |
| | import os |
| | import numpy as np |
| | from carvekit.api.high import HiInterface |
| | import gradio as gr |
| | import torch |
| |
|
| |
|
class PlatonicDistanceModel(torch.nn.Module):
    """Foreground Feature Averaging (FFA) model for intrinsic object similarity.

    Pipeline: remove the background with CarveKit, embed the image with a
    DINOv2 ViT-B/14 encoder, then average the 24x24 grid of patch tokens over
    the foreground mask. Two images are compared via cosine similarity of
    these averaged features.
    """

    def __init__(self, device, carvekit_object_type="object"):
        """
        :param device: string or torch.device object to run the model on.
        :param carvekit_object_type: object type for foreground segmentation. Can be "object" or "hairs-like".
            We find that "object" works well for most images in the CUTE dataset as well as vehicle ReID.
        """
        super().__init__()
        self.device = device
        # DINOv2 ViT-B/14 backbone from torch hub (downloads weights on first use).
        self.encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
        self.encoder.to(self.device)

        # CarveKit segmentation/matting interface used for background removal.
        self.interface = HiInterface(object_type=carvekit_object_type,
                                     batch_size_seg=5,
                                     batch_size_matting=1,
                                     device=str(self.device),
                                     seg_mask_size=640,
                                     matting_mask_size=2048,
                                     trimap_prob_threshold=231,
                                     trimap_dilation=30,
                                     trimap_erosion_iters=5,
                                     fp16=False)

    def preprocess(self, x_list):
        """Convert a list of PIL images to a normalized (N, 3, 336, 336) batch.

        336 = 24 * 14, so the ViT-B/14 encoder produces a 24x24 patch grid.

        :param x_list: list of PIL images (any mode; converted to RGB).
        :return: float tensor of shape (N, 3, 336, 336) on self.device.
        """
        def _to_rgb(img):
            if img.mode != "RGB":
                img = img.convert("RGB")
            return img

        # The transform is identical for every image in the batch; build it
        # once instead of once per image inside the loop.
        transform = torchvision.transforms.Compose([
            _to_rgb,
            torchvision.transforms.Resize((336, 336), interpolation=Image.BICUBIC),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225]),
        ])
        preprocessed_images = [transform(x) for x in x_list]
        return torch.stack(preprocessed_images, dim=0).to(self.device)

    def get_foreground_mask(self, tensor_imgs):
        """Derive a (N, 1, 24, 24) binary foreground mask from preprocessed images.

        Pixels whose channel-sum equals the per-image minimum are treated as
        background (the segmenter zeroes background pixels, which map to the
        minimum value after normalization). Masks are downsampled to the 24x24
        DINOv2 patch grid.

        :param tensor_imgs: (N, 3, H, W) batch as produced by preprocess().
        :return: long tensor of shape (N, 1, 24, 24) with values in {0, 1}.
        """
        masks = []
        for tensor_img in tensor_imgs:
            tensor_img = tensor_img.detach().cpu()
            numpy_img_sum = tensor_img.sum(dim=0).numpy()
            min_value = np.min(numpy_img_sum)
            # Background = pixels exactly at the minimum channel-sum value.
            mask = ~(numpy_img_sum == min_value)
            mask = mask.astype(np.uint8)
            mask = Image.fromarray(mask * 255)
            # Downsample to the patch-token grid resolution.
            resized_mask = mask.resize((24, 24), Image.BILINEAR)
            resized_mask_numpy = np.array(resized_mask) / 255.0
            tensor_mask = torch.from_numpy(resized_mask_numpy.astype(np.float32))
            # Binarize: bilinear resizing produces fractional edge values;
            # > 0.5 becomes 1.0 and the rest truncate to 0 via .long().
            tensor_mask[tensor_mask > 0.5] = 1.0
            tensor_mask = tensor_mask.unsqueeze(0).long().to(self.device)
            # Guard against an empty mask (segmentation removed everything):
            # fall back to averaging over the full grid.
            if tensor_mask.sum() == 0:
                tensor_mask = torch.ones_like(tensor_mask)
            masks.append(tensor_mask)
        return torch.stack(masks, dim=0)

    def forward(self, variant, *x):
        """Dispatch on the number of image arguments.

        One argument (a list of PIL images, or a single image): return the
        (N, D) FFA embeddings. Two arguments (each a list with one image):
        return the scalar cosine similarity between the two embeddings.

        NOTE(review): the torch.Tensor branch forwards tensors into
        forward_single, which reads PIL-only attributes (``.size``) and runs
        CarveKit on PIL inputs — only PIL images are exercised by this demo;
        confirm before passing tensors.

        :param variant: "Crop-Feat" or "Crop-Img" (see forward_single).
        :raises ValueError: if more than two image arguments are given.
        """
        if len(x) == 1 and (isinstance(x[0], list) or isinstance(x[0], torch.Tensor)):
            return self.forward_single(x[0], variant)
        elif len(x) == 1:
            return self.forward_single([x[0]], variant)
        elif len(x) == 2:
            return torch.cosine_similarity(self.forward_single(x[0], variant)[0],
                                           self.forward_single(x[1], variant)[0],
                                           dim=0).cpu().item()
        else:
            raise ValueError("Invalid number of inputs, only 1 or 2 inputs are supported.")

    def forward_single(self, x_list, variant):
        """Return one FFA embedding per image in x_list as an (N, D) tensor.

        :param x_list: list of PIL images.
        :param variant: "Crop-Feat" encodes the background-removed image;
            "Crop-Img" encodes the original image. In both cases the patch
            features are averaged over the foreground mask.
        :raises ValueError: on any other variant string.
        """
        with torch.no_grad():
            # CarveKit returns RGBA output with alpha == 0 on the background.
            img_list = [np.array(self.interface([x])[0]) for x in x_list]
            for img in img_list:
                # Zero all channels of fully transparent (background) pixels.
                img[img[..., 3] == 0] = [0, 0, 0, 0]
            img_list = [Image.fromarray(img) for img in img_list]
            preprocessed_imgs = self.preprocess(img_list)
            masks = self.get_foreground_mask(preprocessed_imgs)
            if variant == "Crop-Feat":
                # Encode the background-removed images.
                emb = self.encoder.forward_features(preprocessed_imgs)
            elif variant == "Crop-Img":
                # Encode the original (uncropped) images; masking still
                # happens on the patch features below.
                emb = self.encoder.forward_features(self.preprocess(x_list))
            else:
                raise ValueError("Invalid variant, only Crop-Feat and Crop-Img are supported.")

            # Reshape patch tokens to the spatial 24x24 grid.
            grid = emb["x_norm_patchtokens"].view(len(x_list), 24, 24, -1)

            # Average patch features over foreground positions only.
            return (grid * masks.permute(0, 2, 3, 1)).sum(dim=(1, 2)) / masks.sum(dim=(1, 2, 3)).unsqueeze(-1)
| |
|
| |
|
def compare(image_1, image_2, variant):
    """Score two PIL images with the module-level FFA model and format the result.

    :param image_1: first PIL image.
    :param image_2: second PIL image.
    :param variant: "Crop-Feat" or "Crop-Img", forwarded to the model.
    :return: human-readable similarity string for the Gradio text output.
    """
    score = model(variant, [image_1], [image_2])
    return "The similarity score is: {:.2f}".format(score)
| |
|
| |
|
# Prefer the GPU when one is available; both DINOv2 and CarveKit run on it.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Built once at import time so every Gradio request reuses the same model
# (avoids re-downloading/re-loading weights per call).
model = PlatonicDistanceModel(device)
| |
|
# Gradio UI: two PIL image inputs plus a variant selector; `compare` returns
# the formatted similarity string shown in the text output.
demo = gr.Interface(title="Foreground Feature Averaging (FFA) Intrinsic Object Similarity Demo",
                    description="Compare two images using the foreground feature averaging metric, a strong baseline for intrinsic object similarity. Please see our project website at https://s-tian.github.io/projects/cute/ for more information.",
                    fn=compare,
                    inputs=[gr.Image(type="pil", label="Image 1"),
                            gr.Image(type="pil", label="Image 2"),
                            gr.Radio(choices=["Crop-Feat", "Crop-Img"], value="Crop-Feat", label="Variant (use Crop-Feat if not sure)")],
                    outputs="text")
| |
|
# Start the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()