# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Zhenyu Li
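
# Tiled-inference / evaluation script for a patch-based high-resolution depth model
# (ZoeDepth-style backbone with a coarse whole-image branch plus a fine per-crop branch).
# The image is split into a regular grid of crops, each crop is predicted together with a
# low-resolution coarse pass, and the crop depths are fused into a full-resolution map with
# an (optionally Gaussian-feathered) running average; see RunningAverageMap, regular_tile
# and random_tile below. Modes handled in run(): 'p16' = base 4x4 tiling only, 'p49' = base
# tiling plus three half-tile-shifted passes, 'rN' = shifted passes plus N extra random crops.
#
# Example invocation (the script name and all paths below are placeholders):
#   python infer.py -m zoedepth \
#       --ckp_path <checkpoint.pt> --model_cfg_path <model_config.json> \
#       --rgb_dir <val_rgb_dir> --gt_dir <val_gt_dir> \
#       --dataset_name u4k --mode r128 --blur_mask
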
import os
import cv2
import argparse
from zoedepth.utils.config import get_config_user
from zoedepth.models.builder import build_model
from zoedepth.utils.arg_utils import parse_unknown
import numpy as np
from zoedepth.models.base_models.midas import Resize
from torchvision.transforms import Compose
from torchvision.transforms import Normalize
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import copy
from zoedepth.utils.misc import get_boundaries
from zoedepth.utils.misc import compute_metrics, RunningAverageDict
from tqdm import tqdm
import matplotlib
import torch.nn.functional as F
from zoedepth.data.middleburry import readPFM
import random
import imageio
from PIL import Image


def load_state_dict(model, state_dict):
    """Load state_dict into model, handling DataParallel and DistributedDataParallel. Also checks for "model" key in state_dict.

    DataParallel prefixes state_dict keys with 'module.' when saving.
    If the model is not a DataParallel model but the state_dict is, then prefixes are removed.
    If the model is a DataParallel model but the state_dict is not, then prefixes are added.
    """
    state_dict = state_dict.get('model', state_dict)

    # if model is a DataParallel model, then state_dict keys are prefixed with 'module.'
    do_prefix = isinstance(
        model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
    state = {}
    for k, v in state_dict.items():
        if k.startswith('module.') and not do_prefix:
            k = k[7:]
        if not k.startswith('module.') and do_prefix:
            k = 'module.' + k
        state[k] = v

    model.load_state_dict(state, strict=True)
    # model.load_state_dict(state, strict=False)
    print("Loaded successfully")
    return model


def load_wts(model, checkpoint_path):
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    return load_state_dict(model, ckpt)


def load_ckpt(model, checkpoint):
    model = load_wts(model, checkpoint)
    print("Loaded weights from {0}".format(checkpoint))
    return model


#### dataset helpers
def read_image(path, dataset_name):
    if dataset_name == 'u4k':
        # raw uint8 dump (2160x3840x3); reverse the channel order so all loaders return RGB
        img = np.fromfile(open(path, 'rb'), dtype=np.uint8).reshape(2160, 3840, 3) / 255.0
        img = img.astype(np.float32)[:, :, ::-1].copy()
    elif dataset_name == 'mid':
        img = cv2.imread(path)
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
        img = F.interpolate(torch.tensor(img).unsqueeze(dim=0).permute(0, 3, 1, 2), IMG_RESOLUTION, mode='bicubic', align_corners=True)
        img = img.squeeze().permute(1, 2, 0)
    elif dataset_name == 'nyu':
        img = Image.open(path)
        img = np.asarray(img, dtype=np.float32) / 255.0
    else:
        img = cv2.imread(path)
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
        print(img.shape)
        img = F.interpolate(torch.tensor(img).unsqueeze(dim=0).permute(0, 3, 1, 2), IMG_RESOLUTION, mode='bicubic', align_corners=True)
        img = img.squeeze().permute(1, 2, 0)
    return img


class Images:
    def __init__(self, root_dir, files, index, dataset_name=None):
        self.root_dir = root_dir
        name = files[index]
        self.dataset_name = dataset_name
        self.rgb_image = read_image(os.path.join(self.root_dir, name), dataset_name)
        name = name.replace(".jpg", "")
        name = name.replace(".png", "")
        name = name.replace(".jpeg", "")
        self.name = name


class DepthMap:
    def __init__(self, root_dir, files, index, dataset_name, pred=False):
        self.root_dir = root_dir
        name = files[index]
        gt_path = os.path.join(self.root_dir, name)

        if dataset_name == 'u4k':
            # per-image scale factor converting the stored disparity to metric depth
            depth_factor = gt_path.replace('val_gt', 'val_factor')
            depth_factor = depth_factor.replace('.npy', '.txt')
            with open(depth_factor, 'r') as f:
                df = f.readline()
            df = float(df)

            gt_disp = np.load(gt_path, mmap_mode='c')
            gt_disp = gt_disp.astype(np.float32)
            edges = get_boundaries(gt_disp, th=1, dilation=0)

            gt_depth = df / gt_disp
            self.gt = gt_depth
            self.edge = edges

        elif dataset_name == 'gta':
            gt_depth = imageio.imread(gt_path)
            gt_depth = np.array(gt_depth).astype(np.float32) / 256
            edges = get_boundaries(gt_depth, th=1, dilation=0)
            self.gt = gt_depth
            self.edge = edges

        elif dataset_name == 'mid':
            # Middlebury: depth = baseline * focal_length / (disparity + doffs)
            depth_factor = gt_path.replace('gts', 'calibs')
            depth_factor = depth_factor.replace('.pfm', '.txt')
            with open(depth_factor, 'r') as f:
                ext_l = f.readlines()
            cam_info = ext_l[0].strip()
            cam_info_f = float(cam_info.split(' ')[0].split('[')[1])
            base = float(ext_l[3].strip().split('=')[1])
            doffs = float(ext_l[2].strip().split('=')[1])
            depth_factor = base * cam_info_f

            height = 1840
            width = 2300

            disp_gt, scale = readPFM(gt_path)
            disp_gt = disp_gt.astype(np.float32)
            disp_gt_copy = disp_gt.copy()

            invalid_mask = disp_gt == np.inf

            depth_gt = depth_factor / (disp_gt + doffs)
            depth_gt = depth_gt / 1000
            depth_gt[invalid_mask] = 0  # mark as an invalid value
            disp_gt_copy[invalid_mask] = 0
            edges = get_boundaries(disp_gt_copy, th=1, dilation=0)

            self.gt = depth_gt
            self.edge = edges

        elif dataset_name == 'nyu':
            if pred:
                depth_gt = np.load(gt_path.replace('png', 'npy'))
                depth_gt = nn.functional.interpolate(
                    torch.tensor(depth_gt).unsqueeze(dim=0).unsqueeze(dim=0), (480, 640), mode='bilinear', align_corners=True).squeeze().numpy()
                edges = get_boundaries(depth_gt, th=1, dilation=0)
            else:
                depth_gt = np.asarray(Image.open(gt_path), dtype=np.float32) / 1000
                edges = get_boundaries(depth_gt, th=1, dilation=0)
            self.gt = depth_gt
            self.edge = edges

        else:
            raise NotImplementedError

        name = name.replace(".npy", "")  # u4k
        name = name.replace(".exr", "")  # gta
        self.name = name


class ImageDataset:
    def __init__(self, rgb_image_dir, gt_dir=None, dataset_name=''):
        self.rgb_image_dir = rgb_image_dir
        self.files = sorted(os.listdir(self.rgb_image_dir))
        self.gt_dir = gt_dir
        self.dataset_name = dataset_name
        if gt_dir is not None:
            self.gt_dir = gt_dir
            self.gt_files = sorted(os.listdir(self.gt_dir))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        if self.dataset_name == 'nyu':
            return Images(self.rgb_image_dir, self.files, index, self.dataset_name), DepthMap(self.gt_dir, self.gt_files, index, self.dataset_name), DepthMap('/ibex/ai/home/liz0l/codes/ZoeDepth/nfs/save/nyu', self.gt_files, index, self.dataset_name, pred=True)
        if self.gt_dir is not None:
            return Images(self.rgb_image_dir, self.files, index, self.dataset_name), DepthMap(self.gt_dir, self.gt_files, index, self.dataset_name)
        else:
            return Images(self.rgb_image_dir, self.files, index)


def crop(img, crop_bbox):
    """Crop an image tensor and return the crop plus a binary mask marking the crop area."""
    crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
    template = torch.zeros((1, 1, img.shape[-2], img.shape[-1]), dtype=torch.float)
    template[:, :, crop_y1:crop_y2, crop_x1:crop_x2] = 1.0
    img = img[:, :, crop_y1:crop_y2, crop_x1:crop_x2]
    return img, template


# def generatemask(size):
#     # Generates a Gaussian mask
#     mask = np.zeros(size, dtype=np.float32)
#     sigma = int(size[0]/16)
#     k_size = int(2 * np.ceil(2 * int(size[0]/16)) + 1)
#     mask[int(0.15*size[0]):size[0] - int(0.15*size[0]), int(0.15*size[1]): size[1] - int(0.15*size[1])] = 1
#     mask = cv2.GaussianBlur(mask, (int(k_size), int(k_size)), sigma)
#     mask = (mask - mask.min()) / (mask.max() - mask.min())
#     mask = mask.astype(np.float32)
#     return mask


def generatemask(size):
    # Generates a Gaussian feathering mask: 1 in the tile interior, falling smoothly to 0
    # at the borders; used to blend overlapping crop predictions.
    mask = np.zeros(size, dtype=np.float32)
    sigma = int(size[0]/16)
    k_size = int(2 * np.ceil(2 * int(size[0]/16)) + 1)
    mask[int(0.1*size[0]):size[0] - int(0.1*size[0]), int(0.1*size[1]): size[1] - int(0.1*size[1])] = 1
    mask = cv2.GaussianBlur(mask, (int(k_size), int(k_size)), sigma)
    mask = (mask - mask.min()) / (mask.max() - mask.min())
    mask = mask.astype(np.float32)
    return mask


def generatemask_coarse(size):
    # Same idea as generatemask, but with a much narrower border and a smaller blur kernel.
    mask = np.zeros(size, dtype=np.float32)
    sigma = int(size[0]/64)
    k_size = int(2 * np.ceil(2 * int(size[0]/64)) + 1)
    mask[int(0.001*size[0]):size[0] - int(0.001*size[0]), int(0.001*size[1]): size[1] - int(0.001*size[1])] = 1
    mask = cv2.GaussianBlur(mask, (int(k_size), int(k_size)), sigma)
    mask = (mask - mask.min()) / (mask.max() - mask.min())
    mask = mask.astype(np.float32)
    return mask


class RunningAverageMap:
    """Per-pixel running (weighted) average of depth predictions."""

    def __init__(self, average_map, count_map=None):
        if count_map is None:
            # allow initialization from a single full-resolution prediction
            count_map = torch.ones_like(average_map)
        self.count_map = count_map
        # average_map is passed in pre-multiplied by its per-pixel weights
        self.average_map = average_map / self.count_map

    def update(self, pred_map, ct_map):
        # pred_map is the new (weight-multiplied) prediction, ct_map its per-pixel weight
        self.average_map = (pred_map + self.count_map * self.average_map) / (self.count_map + ct_map)
        self.count_map = self.count_map + ct_map
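

# The update rule above is a standard weighted running mean. A minimal illustrative
# sketch with synthetic 2x2 tensors (not part of the evaluation pipeline):
#
#   m = RunningAverageMap(torch.full((2, 2), 4.0), torch.ones(2, 2))   # mean = 4, weight = 1
#   m.update(torch.full((2, 2), 2.0), torch.ones(2, 2))                # add value 2 with weight 1
#   # m.average_map is now (2 + 1 * 4) / (1 + 1) = 3 everywhere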


# default crop size (for 4K input): [540, 960]
# x_start, y_start = [0, 540, 1080, 1620], [0, 960, 1920, 2880]
def regular_tile(model, image, offset_x=0, offset_y=0, img_lr=None, iter_pred=None, boundary=0, update=False, avg_depth_map=None, blr_mask=False):
    """Run the model on a regular grid of crops and fuse the results.

    On the first pass (offset_x == offset_y == 0) a new RunningAverageMap is created and
    returned; on shifted passes the given avg_depth_map is updated in place.
    """
    # crop size
    height = CROP_SIZE[0]
    width = CROP_SIZE[1]

    assert offset_x >= 0 and offset_y >= 0

    tile_num_x = (IMG_RESOLUTION[1] - offset_x) // width
    tile_num_y = (IMG_RESOLUTION[0] - offset_y) // height
    x_start = [width * x + offset_x for x in range(tile_num_x)]
    y_start = [height * y + offset_y for y in range(tile_num_y)]

    imgs_crop = []
    crop_areas = []
    bboxs_roi = []
    bboxs_raw = []

    if iter_pred is not None:
        iter_pred = iter_pred.unsqueeze(dim=0).unsqueeze(dim=0)
        iter_priors = []

    # collect crops, crop-area masks and (normalized) bounding boxes
    for x in x_start:  # w
        for y in y_start:  # h
            bbox = (int(y), int(y+height), int(x), int(x+width))
            img_crop, crop_area = crop(image, bbox)
            imgs_crop.append(img_crop)
            crop_areas.append(crop_area)
            crop_y1, crop_y2, crop_x1, crop_x2 = bbox
            bbox_roi = torch.tensor([crop_x1 / IMG_RESOLUTION[1] * 512, crop_y1 / IMG_RESOLUTION[0] * 384, crop_x2 / IMG_RESOLUTION[1] * 512, crop_y2 / IMG_RESOLUTION[0] * 384])
            bboxs_roi.append(bbox_roi)
            bbox_raw = torch.tensor([crop_x1, crop_y1, crop_x2, crop_y2])
            bboxs_raw.append(bbox_raw)
            if iter_pred is not None:
                iter_prior, _ = crop(iter_pred, bbox)
                iter_priors.append(iter_prior)

    crop_areas = torch.cat(crop_areas, dim=0)
    imgs_crop = torch.cat(imgs_crop, dim=0)
    bboxs_roi = torch.stack(bboxs_roi, dim=0)
    bboxs_raw = torch.stack(bboxs_raw, dim=0)

    if iter_pred is not None:
        iter_priors = torch.cat(iter_priors, dim=0)
        iter_priors = TRANSFORM(iter_priors)
        iter_priors = iter_priors.cuda().float()

    crop_areas = TRANSFORM(crop_areas)
    imgs_crop = TRANSFORM(imgs_crop)

    imgs_crop = imgs_crop.cuda().float()
    bboxs_roi = bboxs_roi.cuda().float()
    crop_areas = crop_areas.cuda().float()
    img_lr = img_lr.cuda().float()

    pred_depth_crops = []
    with torch.no_grad():
        for i, (img, bbox, crop_area) in enumerate(zip(imgs_crop, bboxs_roi, crop_areas)):
            if iter_pred is not None:
                iter_prior = iter_priors[i].unsqueeze(dim=0)
            else:
                iter_prior = None

            if i == 0:
                out_dict = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)
                whole_depth_pred = out_dict['coarse_depth_pred']
                # return whole_depth_pred.squeeze()
                # pred_depth_crop = out_dict['fine_depth_pred']
                pred_depth_crop = out_dict['metric_depth']
            else:
                pred_depth_crop = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)['metric_depth']
                # alternatively: ['fine_depth_pred']

            pred_depth_crop = nn.functional.interpolate(
                pred_depth_crop, (height, width), mode='bilinear', align_corners=True)
            # pred_depth_crop = nn.functional.interpolate(
            #     pred_depth_crop, (height, width), mode='nearest')
            pred_depth_crops.append(pred_depth_crop.squeeze())

    whole_depth_pred = whole_depth_pred.squeeze()
    whole_depth_pred = nn.functional.interpolate(whole_depth_pred.unsqueeze(dim=0).unsqueeze(dim=0), IMG_RESOLUTION, mode='bilinear', align_corners=True).squeeze()

    ####### stitching part
    inner_idx = 0
    init_flag = False
    if offset_x == 0 and offset_y == 0:
        init_flag = True
        # pred_depth = whole_depth_pred
        pred_depth = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
    else:
        iter_pred = iter_pred.squeeze()
        pred_depth = iter_pred

    blur_mask = generatemask((height, width)) + 1e-3
    count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
    for ii, x in enumerate(x_start):
        for jj, y in enumerate(y_start):
            if init_flag:
                blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device)
                count_map[y: y+height, x: x+width] = blur_mask
                pred_depth[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask
            else:
                # ensemble with running mean
                if blr_mask:
                    blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device)
                    count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
                    count_map[y: y+height, x: x+width] = blur_mask
                    pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
                    pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask
                    avg_depth_map.update(pred_map, count_map)
                else:
                    if boundary != 0:
                        count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
                        count_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = 1
                        pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
                        pred_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = pred_depth_crops[inner_idx][boundary:-boundary, boundary:-boundary]
                        avg_depth_map.update(pred_map, count_map)
                    else:
                        count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
                        count_map[y: y+height, x: x+width] = 1
                        pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
                        pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx]
                        avg_depth_map.update(pred_map, count_map)
            inner_idx += 1

    if init_flag:
        avg_depth_map = RunningAverageMap(pred_depth, count_map)
        # blur_mask = generatemask_coarse(IMG_RESOLUTION)
        # blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device)
        # count_map = (1 - blur_mask)
        # pred_map = whole_depth_pred * (1 - blur_mask)
        # avg_depth_map.update(pred_map, count_map)
        return avg_depth_map


def regular_tile_param(model, image, offset_x=0, offset_y=0, img_lr=None, iter_pred=None, boundary=0, update=False, avg_depth_map=None, blr_mask=False, crop_size=None,
                       img_resolution=None, transform=None):
    """Same as regular_tile, but crop_size, img_resolution and transform are passed in
    explicitly instead of being read from the module-level globals."""
    # crop size
    height = crop_size[0]
    width = crop_size[1]

    assert offset_x >= 0 and offset_y >= 0

    tile_num_x = (img_resolution[1] - offset_x) // width
    tile_num_y = (img_resolution[0] - offset_y) // height
    x_start = [width * x + offset_x for x in range(tile_num_x)]
    y_start = [height * y + offset_y for y in range(tile_num_y)]

    imgs_crop = []
    crop_areas = []
    bboxs_roi = []
    bboxs_raw = []

    if iter_pred is not None:
        iter_pred = iter_pred.unsqueeze(dim=0).unsqueeze(dim=0)
        iter_priors = []

    for x in x_start:  # w
        for y in y_start:  # h
            bbox = (int(y), int(y+height), int(x), int(x+width))
            img_crop, crop_area = crop(image, bbox)
            imgs_crop.append(img_crop)
            crop_areas.append(crop_area)
            crop_y1, crop_y2, crop_x1, crop_x2 = bbox
            bbox_roi = torch.tensor([crop_x1 / img_resolution[1] * 512, crop_y1 / img_resolution[0] * 384, crop_x2 / img_resolution[1] * 512, crop_y2 / img_resolution[0] * 384])
            bboxs_roi.append(bbox_roi)
            bbox_raw = torch.tensor([crop_x1, crop_y1, crop_x2, crop_y2])
            bboxs_raw.append(bbox_raw)
            if iter_pred is not None:
                iter_prior, _ = crop(iter_pred, bbox)
                iter_priors.append(iter_prior)

    crop_areas = torch.cat(crop_areas, dim=0)
    imgs_crop = torch.cat(imgs_crop, dim=0)
    bboxs_roi = torch.stack(bboxs_roi, dim=0)
    bboxs_raw = torch.stack(bboxs_raw, dim=0)

    if iter_pred is not None:
        iter_priors = torch.cat(iter_priors, dim=0)
        iter_priors = transform(iter_priors)
        iter_priors = iter_priors.to(image.device).float()

    crop_areas = transform(crop_areas)
    imgs_crop = transform(imgs_crop)

    imgs_crop = imgs_crop.to(image.device).float()
    bboxs_roi = bboxs_roi.to(image.device).float()
    crop_areas = crop_areas.to(image.device).float()
    img_lr = img_lr.to(image.device).float()

    pred_depth_crops = []
    with torch.no_grad():
        for i, (img, bbox, crop_area) in enumerate(zip(imgs_crop, bboxs_roi, crop_areas)):
            if iter_pred is not None:
                iter_prior = iter_priors[i].unsqueeze(dim=0)
            else:
                iter_prior = None

            if i == 0:
                out_dict = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)
                whole_depth_pred = out_dict['coarse_depth_pred']
                # return whole_depth_pred.squeeze()
                # pred_depth_crop = out_dict['fine_depth_pred']
                pred_depth_crop = out_dict['metric_depth']
            else:
                pred_depth_crop = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)['metric_depth']
                # alternatively: ['fine_depth_pred']

            pred_depth_crop = nn.functional.interpolate(
                pred_depth_crop, (height, width), mode='bilinear', align_corners=True)
            # pred_depth_crop = nn.functional.interpolate(
            #     pred_depth_crop, (height, width), mode='nearest')
            pred_depth_crops.append(pred_depth_crop.squeeze())

    whole_depth_pred = whole_depth_pred.squeeze()
    whole_depth_pred = nn.functional.interpolate(whole_depth_pred.unsqueeze(dim=0).unsqueeze(dim=0), img_resolution, mode='bilinear', align_corners=True).squeeze()

    ####### stitching part
    inner_idx = 0
    init_flag = False
    if offset_x == 0 and offset_y == 0:
        init_flag = True
        # pred_depth = whole_depth_pred
        pred_depth = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
    else:
        iter_pred = iter_pred.squeeze()
        pred_depth = iter_pred

    blur_mask = generatemask((height, width)) + 1e-3
    count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
    for ii, x in enumerate(x_start):
        for jj, y in enumerate(y_start):
            if init_flag:
                blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device)
                count_map[y: y+height, x: x+width] = blur_mask
                pred_depth[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask
            else:
                # ensemble with running mean
                if blr_mask:
                    blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device)
                    count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
                    count_map[y: y+height, x: x+width] = blur_mask
                    pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
                    pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask
                    avg_depth_map.update(pred_map, count_map)
                else:
                    if boundary != 0:
                        count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
                        count_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = 1
                        pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
                        pred_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = pred_depth_crops[inner_idx][boundary:-boundary, boundary:-boundary]
                        avg_depth_map.update(pred_map, count_map)
                    else:
                        count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
                        count_map[y: y+height, x: x+width] = 1
                        pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
                        pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx]
                        avg_depth_map.update(pred_map, count_map)
            inner_idx += 1

    if init_flag:
        avg_depth_map = RunningAverageMap(pred_depth, count_map)
        # blur_mask = generatemask_coarse(img_resolution)
        # blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device)
        # count_map = (1 - blur_mask)
        # pred_map = whole_depth_pred * (1 - blur_mask)
        # avg_depth_map.update(pred_map, count_map)
        return avg_depth_map


def random_tile(model, image, img_lr=None, iter_pred=None, boundary=0, update=False, avg_depth_map=None, blr_mask=False):
    """Run the model on a single randomly placed crop and fold it into avg_depth_map."""
    height = CROP_SIZE[0]
    width = CROP_SIZE[1]

    x_start = [random.randint(0, IMG_RESOLUTION[1] - width - 1)]
    y_start = [random.randint(0, IMG_RESOLUTION[0] - height - 1)]

    imgs_crop = []
    crop_areas = []
    bboxs_roi = []
    bboxs_raw = []

    if iter_pred is not None:
        iter_pred = iter_pred.unsqueeze(dim=0).unsqueeze(dim=0)
        iter_priors = []

    for x in x_start:  # w
        for y in y_start:  # h
            bbox = (int(y), int(y+height), int(x), int(x+width))
            img_crop, crop_area = crop(image, bbox)
            imgs_crop.append(img_crop)
            crop_areas.append(crop_area)
            crop_y1, crop_y2, crop_x1, crop_x2 = bbox
            bbox_roi = torch.tensor([crop_x1 / IMG_RESOLUTION[1] * 512, crop_y1 / IMG_RESOLUTION[0] * 384, crop_x2 / IMG_RESOLUTION[1] * 512, crop_y2 / IMG_RESOLUTION[0] * 384])
            bboxs_roi.append(bbox_roi)
            bbox_raw = torch.tensor([crop_x1, crop_y1, crop_x2, crop_y2])
            bboxs_raw.append(bbox_raw)
            if iter_pred is not None:
                iter_prior, _ = crop(iter_pred, bbox)
                iter_priors.append(iter_prior)

    crop_areas = torch.cat(crop_areas, dim=0)
    imgs_crop = torch.cat(imgs_crop, dim=0)
    bboxs_roi = torch.stack(bboxs_roi, dim=0)
    bboxs_raw = torch.stack(bboxs_raw, dim=0)

    if iter_pred is not None:
        iter_priors = torch.cat(iter_priors, dim=0)
        iter_priors = TRANSFORM(iter_priors)
        iter_priors = iter_priors.cuda().float()

    crop_areas = TRANSFORM(crop_areas)
    imgs_crop = TRANSFORM(imgs_crop)

    imgs_crop = imgs_crop.cuda().float()
    bboxs_roi = bboxs_roi.cuda().float()
    crop_areas = crop_areas.cuda().float()
    img_lr = img_lr.cuda().float()

    pred_depth_crops = []
    with torch.no_grad():
        for i, (img, bbox, crop_area) in enumerate(zip(imgs_crop, bboxs_roi, crop_areas)):
            if iter_pred is not None:
                iter_prior = iter_priors[i].unsqueeze(dim=0)
            else:
                iter_prior = None

            if i == 0:
                out_dict = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)
                whole_depth_pred = out_dict['coarse_depth_pred']
                pred_depth_crop = out_dict['metric_depth']
                # return whole_depth_pred.squeeze()
            else:
                pred_depth_crop = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)['metric_depth']

            pred_depth_crop = nn.functional.interpolate(
                pred_depth_crop, (height, width), mode='bilinear', align_corners=True)
            # pred_depth_crop = nn.functional.interpolate(
            #     pred_depth_crop, (height, width), mode='nearest')
            pred_depth_crops.append(pred_depth_crop.squeeze())

    whole_depth_pred = whole_depth_pred.squeeze()

    ####### stitching part
    inner_idx = 0
    init_flag = False

    iter_pred = iter_pred.squeeze()
    pred_depth = iter_pred

    blur_mask = generatemask((height, width)) + 1e-3
    for ii, x in enumerate(x_start):
        for jj, y in enumerate(y_start):
            if init_flag:
                # won't be reached: random tiles are only used to refine an existing map
                crop_temp = copy.deepcopy(whole_depth_pred[y: y+height, x: x+width])
                blur_mask = torch.ones((height, width))
                blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device)
                pred_depth[y: y+height, x: x+width] = blur_mask * pred_depth_crops[inner_idx] + (1 - blur_mask) * crop_temp
            else:
                if blr_mask:
                    blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device)
                    count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
                    count_map[y: y+height, x: x+width] = blur_mask
                    pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
                    pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask
                    avg_depth_map.update(pred_map, count_map)
                else:
                    # ensemble with running mean
                    if boundary != 0:
                        count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
                        count_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = 1
                        pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
                        pred_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = pred_depth_crops[inner_idx][boundary:-boundary, boundary:-boundary]
                        avg_depth_map.update(pred_map, count_map)
                    else:
                        count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
                        count_map[y: y+height, x: x+width] = 1
                        pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
                        pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx]
                        avg_depth_map.update(pred_map, count_map)
            inner_idx += 1

    if avg_depth_map is None:
        return pred_depth


def random_tile_param(model, image, img_lr=None, iter_pred=None, boundary=0, update=False, avg_depth_map=None, blr_mask=False, crop_size=None,
                      img_resolution=None, transform=None):
    """Same as random_tile, but crop_size, img_resolution and transform are passed in
    explicitly instead of being read from the module-level globals."""
    height = crop_size[0]
    width = crop_size[1]

    x_start = [random.randint(0, img_resolution[1] - width - 1)]
    y_start = [random.randint(0, img_resolution[0] - height - 1)]

    imgs_crop = []
    crop_areas = []
    bboxs_roi = []
    bboxs_raw = []

    if iter_pred is not None:
        iter_pred = iter_pred.unsqueeze(dim=0).unsqueeze(dim=0)
        iter_priors = []

    for x in x_start:  # w
        for y in y_start:  # h
            bbox = (int(y), int(y+height), int(x), int(x+width))
            img_crop, crop_area = crop(image, bbox)
            imgs_crop.append(img_crop)
            crop_areas.append(crop_area)
            crop_y1, crop_y2, crop_x1, crop_x2 = bbox
            bbox_roi = torch.tensor([crop_x1 / img_resolution[1] * 512, crop_y1 / img_resolution[0] * 384, crop_x2 / img_resolution[1] * 512, crop_y2 / img_resolution[0] * 384])
            bboxs_roi.append(bbox_roi)
            bbox_raw = torch.tensor([crop_x1, crop_y1, crop_x2, crop_y2])
            bboxs_raw.append(bbox_raw)
            if iter_pred is not None:
                iter_prior, _ = crop(iter_pred, bbox)
                iter_priors.append(iter_prior)

    crop_areas = torch.cat(crop_areas, dim=0)
    imgs_crop = torch.cat(imgs_crop, dim=0)
    bboxs_roi = torch.stack(bboxs_roi, dim=0)
    bboxs_raw = torch.stack(bboxs_raw, dim=0)

    if iter_pred is not None:
        iter_priors = torch.cat(iter_priors, dim=0)
        iter_priors = transform(iter_priors)
        iter_priors = iter_priors.cuda().float()

    crop_areas = transform(crop_areas)
    imgs_crop = transform(imgs_crop)

    imgs_crop = imgs_crop.cuda().float()
    bboxs_roi = bboxs_roi.cuda().float()
    crop_areas = crop_areas.cuda().float()
    img_lr = img_lr.cuda().float()

    pred_depth_crops = []
    with torch.no_grad():
        for i, (img, bbox, crop_area) in enumerate(zip(imgs_crop, bboxs_roi, crop_areas)):
            if iter_pred is not None:
                iter_prior = iter_priors[i].unsqueeze(dim=0)
            else:
                iter_prior = None

            if i == 0:
                out_dict = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)
                whole_depth_pred = out_dict['coarse_depth_pred']
                pred_depth_crop = out_dict['metric_depth']
                # return whole_depth_pred.squeeze()
            else:
                pred_depth_crop = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)['metric_depth']

            pred_depth_crop = nn.functional.interpolate(
                pred_depth_crop, (height, width), mode='bilinear', align_corners=True)
            # pred_depth_crop = nn.functional.interpolate(
            #     pred_depth_crop, (height, width), mode='nearest')
            pred_depth_crops.append(pred_depth_crop.squeeze())

    whole_depth_pred = whole_depth_pred.squeeze()

    ####### stitching part
    inner_idx = 0
    init_flag = False

    iter_pred = iter_pred.squeeze()
    pred_depth = iter_pred

    blur_mask = generatemask((height, width)) + 1e-3
    for ii, x in enumerate(x_start):
        for jj, y in enumerate(y_start):
            if init_flag:
                # won't be reached: random tiles are only used to refine an existing map
                crop_temp = copy.deepcopy(whole_depth_pred[y: y+height, x: x+width])
                blur_mask = torch.ones((height, width))
                blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device)
                pred_depth[y: y+height, x: x+width] = blur_mask * pred_depth_crops[inner_idx] + (1 - blur_mask) * crop_temp
            else:
                if blr_mask:
                    blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device)
                    count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
                    count_map[y: y+height, x: x+width] = blur_mask
                    pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
                    pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask
                    avg_depth_map.update(pred_map, count_map)
                else:
                    # ensemble with running mean
                    if boundary != 0:
                        count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
                        count_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = 1
                        pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
                        pred_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = pred_depth_crops[inner_idx][boundary:-boundary, boundary:-boundary]
                        avg_depth_map.update(pred_map, count_map)
                    else:
                        count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
                        count_map[y: y+height, x: x+width] = 1
                        pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
                        pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx]
                        avg_depth_map.update(pred_map, count_map)
            inner_idx += 1

    if avg_depth_map is None:
        return pred_depth


def colorize_infer(value, cmap='magma_r', vmin=None, vmax=None):
    # normalize
    vmin = value.min() if vmin is None else vmin
    # vmax = value.max() if vmax is None else vmax
    vmax = np.percentile(value, 95) if vmax is None else vmax
    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)  # vmin..vmax
    else:
        value = value * 0.

    cmapper = matplotlib.cm.get_cmap(cmap)
    value = cmapper(value, bytes=True)  # ((1)xhxwx4)
    value = value[:, :, :3]  # drop alpha
    rgb_value = value[..., ::-1]  # rgb -> bgr for cv2.imwrite
    return rgb_value


def colorize(value, vmin=None, vmax=None, cmap='turbo_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None, dataset_name=None):
    """Converts a depth map to a color image.

    Args:
        value (torch.Tensor, numpy.ndarray): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed.
        vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None.
        vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, a high percentile of the valid values is used. Defaults to None.
        cmap (str, optional): matplotlib colormap to use. Defaults to 'turbo_r'.
        invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
        invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
        background_color (tuple[int], optional): 4-tuple RGBA color to give to invalid pixels. Defaults to (128, 128, 128, 255).
        gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
        value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.

    Returns:
        numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()

    value = value.squeeze()
    if invalid_mask is None:
        invalid_mask = value == invalid_val
    mask = np.logical_not(invalid_mask)

    # normalize
    # vmin = np.percentile(value[mask], 2) if vmin is None else vmin
    # vmin = value.min() if vmin is None else vmin
    # vmax = np.percentile(value[mask], 95) if vmax is None else vmax
    if dataset_name == 'mid':
        # Middlebury ground truth: clip the range a bit harder for visualization
        vmin = np.percentile(value[mask], 2) if vmin is None else vmin
        vmax = np.percentile(value[mask], 85) if vmax is None else vmax
    else:
        vmin = value.min() if vmin is None else vmin
        vmax = np.percentile(value[mask], 95) if vmax is None else vmax

    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)  # vmin..vmax
    else:
        # Avoid 0-division
        value = value * 0.

    # grey out the invalid values
    value[invalid_mask] = np.nan
    cmapper = matplotlib.cm.get_cmap(cmap)
    if value_transform:
        value = value_transform(value)
        # value = value / value.max()
    value = cmapper(value, bytes=True)  # (nxmx4)

    img = value[...]
    img[invalid_mask] = background_color
    # return img.transpose((2, 0, 1))

    if gamma_corrected:
        # gamma correction
        img = img / 255
        img = np.power(img, 2.2)
        img = img * 255
        img = img.astype(np.uint8)
    return img


def rescale(A, lbound=0, ubound=1):
    """
    Rescale an array to [lbound, ubound].

    Parameters:
    - A: Input data as numpy array
    - lbound: Lower bound of the scale, default is 0.
    - ubound: Upper bound of the scale, default is 1.

    Returns:
    - Rescaled array
    """
    A_min = np.min(A)
    A_max = np.max(A)
    return (ubound - lbound) * (A - A_min) / (A_max - A_min) + lbound


def run(model, dataset, gt_dir=None, show_path=None, show=False, save_flag=False, save_path=None, mode=None, dataset_name=None, base_zoed=False, blr_mask=False):
    data_len = len(dataset)
    if gt_dir is not None:
        metrics_avg = RunningAverageDict()

    for image_ind in tqdm(range(data_len)):
        if dataset_name == 'nyu':
            images, depths, pred_depths = dataset[image_ind]
        else:
            if gt_dir is None:
                images = dataset[image_ind]
            else:
                images, depths = dataset[image_ind]

        # Load image from dataset
        img = torch.tensor(images.rgb_image).unsqueeze(dim=0).permute(0, 3, 1, 2)  # shape: 1, 3, h, w
        img_lr = TRANSFORM(img)

        if base_zoed:
            # plain whole-image inference, no tiling
            with torch.no_grad():
                pred_depth = model(img.cuda())['metric_depth'].squeeze()
            avg_depth_map = RunningAverageMap(pred_depth)  # count_map defaults to ones
        else:
            # pred_depth, count_map = regular_tile(model, img, offset_x=0, offset_y=0, img_lr=img_lr)
            # avg_depth_map = RunningAverageMap(pred_depth, count_map)
            avg_depth_map = regular_tile(model, img, offset_x=0, offset_y=0, img_lr=img_lr)

            if mode == 'p16':
                pass
            elif mode == 'p49':
                regular_tile(model, img, offset_x=CROP_SIZE[1]//2, offset_y=0, img_lr=img_lr, iter_pred=avg_depth_map.average_map, boundary=BOUNDARY, update=True, avg_depth_map=avg_depth_map, blr_mask=blr_mask)
                regular_tile(model, img, offset_x=0, offset_y=CROP_SIZE[0]//2, img_lr=img_lr, iter_pred=avg_depth_map.average_map, boundary=BOUNDARY, update=True, avg_depth_map=avg_depth_map, blr_mask=blr_mask)
                regular_tile(model, img, offset_x=CROP_SIZE[1]//2, offset_y=CROP_SIZE[0]//2, img_lr=img_lr, iter_pred=avg_depth_map.average_map, boundary=BOUNDARY, update=True, avg_depth_map=avg_depth_map, blr_mask=blr_mask)
            elif mode[0] == 'r':
                regular_tile(model, img, offset_x=CROP_SIZE[1]//2, offset_y=0, img_lr=img_lr, iter_pred=avg_depth_map.average_map, boundary=BOUNDARY, update=True, avg_depth_map=avg_depth_map, blr_mask=blr_mask)
                regular_tile(model, img, offset_x=0, offset_y=CROP_SIZE[0]//2, img_lr=img_lr, iter_pred=avg_depth_map.average_map, boundary=BOUNDARY, update=True, avg_depth_map=avg_depth_map, blr_mask=blr_mask)
                regular_tile(model, img, offset_x=CROP_SIZE[1]//2, offset_y=CROP_SIZE[0]//2, img_lr=img_lr, iter_pred=avg_depth_map.average_map, boundary=BOUNDARY, update=True, avg_depth_map=avg_depth_map, blr_mask=blr_mask)
                for i in tqdm(range(int(mode[1:]))):
                    random_tile(model, img, img_lr=img_lr, iter_pred=avg_depth_map.average_map, boundary=BOUNDARY, update=True, avg_depth_map=avg_depth_map, blr_mask=blr_mask)

        if show:
            color_map = copy.deepcopy(avg_depth_map.average_map)
            color_map = colorize_infer(color_map.detach().cpu().numpy())
            cv2.imwrite(os.path.join(show_path, '{}.png'.format(images.name)), color_map)

        if save_flag:
            np.save(os.path.join(save_path, '{}.npy'.format(images.name)), avg_depth_map.average_map.squeeze().detach().cpu().numpy())
            # np.save(os.path.join(save_path, '{}.npy'.format(images.name)), depths.gt)

        if gt_dir is not None:
            if dataset_name == 'nyu':
                metrics = compute_metrics(torch.tensor(depths.gt), avg_depth_map.average_map, disp_gt_edges=depths.edge, min_depth_eval=1e-3, max_depth_eval=10, garg_crop=False, eigen_crop=True, dataset='nyu', pred_depths=torch.tensor(pred_depths.gt))
            else:
                metrics = compute_metrics(torch.tensor(depths.gt), avg_depth_map.average_map, disp_gt_edges=depths.edge, min_depth_eval=1e-3, max_depth_eval=80, garg_crop=False, eigen_crop=False, dataset='')
            metrics_avg.update(metrics)
            print(metrics)

    if gt_dir is not None:
        print(metrics_avg.get_value())
    else:
        print("Inference finished successfully!")

    return avg_depth_map


####
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--rgb_dir', type=str, required=True)
    parser.add_argument('--show_path', type=str, default=None)
    parser.add_argument("--ckp_path", type=str, required=True)
    parser.add_argument("-m", "--model", type=str, default="zoedepth")
    parser.add_argument("--model_cfg_path", type=str, default="")
    parser.add_argument("--gt_dir", type=str, default=None)
    parser.add_argument("--dataset_name", type=str, default=None)
    parser.add_argument("--show", action='store_true')
    parser.add_argument("--save", action='store_true')
    parser.add_argument("--save_path", type=str, default=None)
    parser.add_argument("--img_resolution", type=str, default=None)
    parser.add_argument("--crop_size", type=str, default=None)
    parser.add_argument("--mode", type=str, default=None)
    parser.add_argument("--base_zoed", action='store_true')
    parser.add_argument("--boundary", type=int, default=0)
    parser.add_argument("--blur_mask", action='store_true')
    args, unknown_args = parser.parse_known_args()

    # prepare some global args
    global IMG_RESOLUTION
    if args.dataset_name == 'u4k':
        IMG_RESOLUTION = (2160, 3840)
    elif args.dataset_name == 'gta':
        IMG_RESOLUTION = (1080, 1920)
    elif args.dataset_name == 'nyu':
        IMG_RESOLUTION = (480, 640)
    else:
        IMG_RESOLUTION = (2160, 3840)

    global TRANSFORM
    TRANSFORM = Compose([Resize(512, 384, keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal")])
    global BOUNDARY
    BOUNDARY = args.boundary

    if args.img_resolution is not None:
        IMG_RESOLUTION = (int(args.img_resolution.split('x')[0]), int(args.img_resolution.split('x')[1]))

    global CROP_SIZE
    CROP_SIZE = (int(IMG_RESOLUTION[0] // 4), int(IMG_RESOLUTION[1] // 4))
    if args.crop_size is not None:
        CROP_SIZE = (int(args.crop_size.split('x')[0]), int(args.crop_size.split('x')[1]))
    print("\nCurrent image resolution: {}\nCurrent crop size: {}".format(IMG_RESOLUTION, CROP_SIZE))

    overwrite_kwargs = parse_unknown(unknown_args)
    overwrite_kwargs['model_cfg_path'] = args.model_cfg_path
    overwrite_kwargs["model"] = args.model

    # blur_mask_crop = generatemask(CROP_SIZE)
    # plt.imshow(blur_mask_crop)
    # plt.savefig('./nfs/results_show/crop_mask.png')
    # blur_mask_crop = generatemask_coarse(IMG_RESOLUTION)
    # plt.imshow(blur_mask_crop)
    # plt.savefig('./nfs/results_show/whole_mask.png')

    config = get_config_user(args.model, **overwrite_kwargs)
    config["pretrained_resource"] = ''
    model = build_model(config)
    model = load_ckpt(model, args.ckp_path)
    model.eval()
    model.cuda()

    # Create dataset from input images
    dataset_custom = ImageDataset(args.rgb_dir, args.gt_dir, args.dataset_name)

    # start running
    if args.show:
        os.makedirs(args.show_path, exist_ok=True)
    if args.save:
        os.makedirs(args.save_path, exist_ok=True)
    run(model, dataset_custom, args.gt_dir, args.show_path, args.show, args.save, args.save_path, mode=args.mode, dataset_name=args.dataset_name, base_zoed=args.base_zoed, blr_mask=args.blur_mask)
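
    # Notes on outputs (see run() above):
    # - With --gt_dir set, per-image metrics from compute_metrics are printed and the
    #   averaged metrics (RunningAverageDict) are reported at the end of the run.
    # - With --show, a colorized depth map '<image name>.png' is written to --show_path.
    # - With --save, the raw predicted depth '<image name>.npy' is written to --save_path.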