Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import torch | |
| import torchvision | |
| from PIL import Image | |
| import numpy as np | |
| import imageio | |
| from einops import rearrange, repeat | |
def load_image(img, size):
    """Return an image as a channel-first float array scaled to [0, 1].

    Args:
        img: Either an ``np.ndarray`` (used as-is; it is NOT resized) or
            anything ``PIL.Image.open`` accepts. In the latter case the image
            is converted to RGB and resized to ``size`` x ``size``.
        size: Target edge length in pixels, applied only on the PIL path.

    Returns:
        The pixel array with its axes reordered as (2, 0, 1) and divided by
        255.0.
    """
    if isinstance(img, np.ndarray):
        source = img
    else:
        opened = Image.open(img).convert('RGB')
        source = opened.resize((size, size))

    pixels = np.asarray(source)
    # Move the last (channel) axis to the front: H x W x C -> C x H x W.
    channel_first = pixels.transpose(2, 0, 1)
    return channel_first / 255.0
def img_preprocessing(img_path, size):
    """Load an image and return it as a batched float tensor in [-1, 1].

    Args:
        img_path: Forwarded to ``load_image`` (a path/file or an ndarray).
        size: Forwarded to ``load_image`` as the target edge length.

    Returns:
        A float32 tensor with a leading batch axis of size 1, rescaled from
        the [0, 1] range produced by ``load_image`` to [-1, 1].
    """
    unit_range = load_image(img_path, size)  # values in [0, 1]
    batch = torch.from_numpy(unit_range).float().unsqueeze(0)
    # Shift and scale [0, 1] -> [-1, 1].
    return (batch - 0.5) * 2.0
def resize(img, size):
    """Resize ``img`` with antialiasing, then center-crop it to ``size``.

    Args:
        img: Input accepted by ``torchvision.transforms.Resize``.
        size: Passed unchanged to both ``Resize`` and ``CenterCrop``.

    Returns:
        The resized and center-cropped result.
    """
    transforms = torchvision.transforms
    scaled = transforms.Resize(size, antialias=True)(img)
    return transforms.CenterCrop(size)(scaled)
def vid_preprocessing(vid_path, size):
    """Read a video and return its resized frames in [-1, 1] plus its fps.

    Args:
        vid_path: Video location handed to ``torchvision.io.read_video``.
        size: Target size forwarded to ``resize`` for every frame.

    Returns:
        A ``(frames, fps)`` tuple. ``frames`` is laid out B x T x C x H x W
        with a batch axis of size 1; ``fps`` is the ``'video_fps'`` entry of
        the metadata returned by ``read_video``.
    """
    frames, _audio, info = torchvision.io.read_video(vid_path, pts_unit='sec')

    # Channels moved in front of H/W, then a batch axis is added: btchw.
    clip = frames.permute(0, 3, 1, 2).unsqueeze(0)
    fps = info['video_fps']

    # Map 0..255 pixel values to [-1, 1].
    clip = (clip / 255.0 - 0.5) * 2.0

    # Resize frame by frame, then re-assemble along the time axis.
    resized_frames = [
        resize(clip[:, t, :, :, :], size) for t in range(clip.size(1))
    ]
    return torch.stack(resized_frames, dim=1), fps
def img_denorm(img):
    """Clamp a tensor to [-1, 1] and rescale it to [0, 1] on the CPU.

    The rescaling is a min-max normalisation over the whole tensor, so the
    smallest clamped value maps to 0 and the largest to 1.

    Args:
        img: Tensor whose values are nominally in [-1, 1].

    Returns:
        A CPU tensor with values in [0, 1].
    """
    img = img.clamp(-1, 1).cpu()
    low = img.min()
    span = img.max() - low
    if span == 0:
        # Constant input: min-max scaling would divide by zero and yield NaN
        # (which turns into garbage once cast to uint8). Fall back to the
        # fixed [-1, 1] -> [0, 1] mapping so the intensity is preserved.
        return (img + 1) / 2
    return (img - low) / span
def vid_denorm(vid):
    """Clamp a video tensor to [-1, 1] and rescale it to [0, 1] on the CPU.

    The rescaling is a min-max normalisation over the whole tensor (all
    frames together), so the smallest clamped value maps to 0 and the
    largest to 1.

    Args:
        vid: Tensor whose values are nominally in [-1, 1].

    Returns:
        A CPU tensor with values in [0, 1].
    """
    vid = vid.clamp(-1, 1).cpu()
    low = vid.min()
    span = vid.max() - low
    if span == 0:
        # Constant input: min-max scaling would divide by zero and yield NaN
        # (which turns into garbage once cast to uint8). Fall back to the
        # fixed [-1, 1] -> [0, 1] mapping so the intensity is preserved.
        return (vid + 1) / 2
    return (vid - low) / span
def save_img_edit(save_dir, img, img_e):
    """Write an image and a two-image comparison strip as PNG files.

    Only the first item of each batch is written.

    Args:
        save_dir: Directory receiving ``img_edit.png`` and ``img_all.png``.
        img: Image batch laid out BCHW; appears on the left of ``img_all.png``.
        img_e: Image batch laid out BCHW; written alone to ``img_edit.png``
            and on the right of ``img_all.png``.
    """
    left = rearrange(img, 'b c h w -> b h w c')
    right = rearrange(img_e, 'b c h w -> b h w c')
    # Axis 2 is width after the rearrange, so this puts the images side by side.
    strip = torch.cat([left, right], dim=2)

    right_u8 = (img_denorm(right[0]).numpy() * 255).astype('uint8')
    strip_u8 = (img_denorm(strip[0]).numpy() * 255).astype('uint8')

    imageio.imwrite(os.path.join(save_dir, "img_edit.png"), right_u8, quality=8)
    imageio.imwrite(os.path.join(save_dir, "img_all.png"), strip_u8, quality=8)
def save_vid_edit(save_dir, vid_d, vid_a, fps):
    """Write a video and a two-video side-by-side comparison as MP4 files.

    Only the first item of each batch is written.

    Args:
        save_dir: Directory receiving ``vid_animation.mp4`` and
            ``vid_all.mp4``.
        vid_d: Video batch laid out BTCHW; left half of ``vid_all.mp4``.
        vid_a: Video batch laid out BCTHW; written alone to
            ``vid_animation.mp4`` and as the right half of ``vid_all.mp4``.
        fps: Frame rate passed to the video writer.
    """
    left = rearrange(vid_d, 'b t c h w -> b t h w c')
    right = rearrange(vid_a, 'b c t h w -> b t h w c')
    # Axis 3 is width in the b t h w c layout.
    side_by_side = torch.cat([left, right], dim=3)

    right_u8 = (vid_denorm(right[0]).numpy() * 255).astype('uint8')
    both_u8 = (vid_denorm(side_by_side[0]).numpy() * 255).astype('uint8')

    imageio.mimwrite(
        os.path.join(save_dir, "vid_animation.mp4"),
        right_u8,
        fps=fps,
        codec='libx264',
        quality=8,
    )
    imageio.mimwrite(
        os.path.join(save_dir, "vid_all.mp4"),
        both_u8,
        fps=fps,
        codec='libx264',
        quality=8,
    )
def save_animation(save_dir, img_s, vid_d, vid_a, fps):
    """Write an animation, a three-panel comparison video and a still frame.

    Only the first item of each batch is written.

    Args:
        save_dir: Directory receiving ``vid_animation.mp4``, ``vid_all.mp4``
            and ``img_edit.png``.
        img_s: Image batch laid out BCHW; repeated for every frame and shown
            as the left panel of ``vid_all.mp4``.
        vid_d: Video batch laid out BTCHW; middle panel of ``vid_all.mp4``.
        vid_a: Video batch laid out BCTHW; written alone to
            ``vid_animation.mp4`` and as the right panel of ``vid_all.mp4``.
            Its first frame is saved as ``img_edit.png``.
        fps: Frame rate passed to the video writer.
    """
    middle = rearrange(vid_d, 'b t c h w -> b t h w c')
    right = rearrange(vid_a, 'b c t h w -> b t h w c')
    # Channel-last layout plus one copy of the still image per video frame.
    left = repeat(img_s, 'b c h w -> b t h w c', t=middle.size(1))
    # Axis 3 is width in the b t h w c layout.
    panels = torch.cat([left, middle, right], dim=3)

    right_u8 = (vid_denorm(right[0]).numpy() * 255).astype('uint8')
    first_frame_u8 = right_u8[0]
    panels_u8 = (vid_denorm(panels[0]).numpy() * 255).astype('uint8')

    video_opts = {'fps': fps, 'codec': 'libx264', 'quality': 8}
    imageio.mimwrite(
        os.path.join(save_dir, "vid_animation.mp4"), right_u8, **video_opts
    )
    imageio.mimwrite(
        os.path.join(save_dir, "vid_all.mp4"), panels_u8, **video_opts
    )
    imageio.imwrite(
        os.path.join(save_dir, "img_edit.png"), first_frame_u8, quality=8
    )
def save_linear_manipulation(save_dir, vid, fps):
    """Write the first video of a batch to ``vid_interpolation.mp4``.

    Args:
        save_dir: Directory receiving ``vid_interpolation.mp4``.
        vid: Video batch laid out BCTHW.
        fps: Frame rate passed to the video writer.
    """
    frames = rearrange(vid, 'b c t h w -> b t h w c')[0]
    frames_u8 = (vid_denorm(frames).numpy() * 255).astype('uint8')
    target = os.path.join(save_dir, "vid_interpolation.mp4")
    imageio.mimwrite(target, frames_u8, fps=fps, codec='libx264', quality=8)