Update src/dataset/atlas_dataset.py: precomputed map tokens support + NCCL timeout
Browse files
- src/dataset/atlas_dataset.py +34 -0
src/dataset/atlas_dataset.py
CHANGED
|
@@ -103,6 +103,7 @@ class AtlasDataset(Dataset):
|
|
| 103 |
image_path_remap: Optional[str] = None,
|
| 104 |
precomputed_det_tokens: Optional[str] = None,
|
| 105 |
require_precomputed_det_tokens: bool = False,
|
|
|
|
| 106 |
):
|
| 107 |
self.json_file = json_file
|
| 108 |
self.image_root = image_root
|
|
@@ -112,6 +113,7 @@ class AtlasDataset(Dataset):
|
|
| 112 |
self.image_path_remap = (old, new)
|
| 113 |
self.precomputed_det_dir = precomputed_det_tokens
|
| 114 |
self.require_precomputed = require_precomputed_det_tokens
|
|
|
|
| 115 |
self._precomputed_path_map: Optional[Dict[str, str]] = None
|
| 116 |
if self.precomputed_det_dir and os.path.isdir(self.precomputed_det_dir):
|
| 117 |
self._precomputed_path_map = {}
|
|
@@ -120,6 +122,14 @@ class AtlasDataset(Dataset):
|
|
| 120 |
if fname.endswith(".pt"):
|
| 121 |
self._precomputed_path_map[fname[:-3]] = os.path.join(root_dir, fname)
|
| 122 |
print(f"Precomputed det tokens index: {len(self._precomputed_path_map)} files")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
self.tokenizer = tokenizer
|
| 124 |
self.max_length = max_length
|
| 125 |
self.is_training = is_training
|
|
@@ -406,6 +416,11 @@ class AtlasDataset(Dataset):
|
|
| 406 |
result["precomputed_det"] = pt["detection"]
|
| 407 |
result["precomputed_det_ref"] = pt["detection_ref_points"]
|
| 408 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
if os.getenv("ATLAS_AUDIT", "0") not in ("", "0", "false", "False"):
|
| 410 |
max_samples = int(os.getenv("ATLAS_AUDIT_MAX_SAMPLES", "1"))
|
| 411 |
if idx < max_samples:
|
|
@@ -636,6 +651,18 @@ class AtlasDataset(Dataset):
|
|
| 636 |
except Exception:
|
| 637 |
return None
|
| 638 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 639 |
def _get_scene_id(self, item: Dict) -> str:
|
| 640 |
if "segment_id" in item and item["segment_id"]:
|
| 641 |
return str(item["segment_id"])
|
|
@@ -1102,6 +1129,13 @@ def atlas_collate_fn(
|
|
| 1102 |
result["precomputed_det"] = torch.stack([item["precomputed_det"] for item in batch])
|
| 1103 |
result["precomputed_det_ref"] = torch.stack([item["precomputed_det_ref"] for item in batch])
|
| 1104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1105 |
audit_keys = [
|
| 1106 |
"audit_prompt_len",
|
| 1107 |
"audit_answer_len",
|
|
|
|
| 103 |
image_path_remap: Optional[str] = None,
|
| 104 |
precomputed_det_tokens: Optional[str] = None,
|
| 105 |
require_precomputed_det_tokens: bool = False,
|
| 106 |
+
precomputed_map_tokens: Optional[str] = None,
|
| 107 |
):
|
| 108 |
self.json_file = json_file
|
| 109 |
self.image_root = image_root
|
|
|
|
| 113 |
self.image_path_remap = (old, new)
|
| 114 |
self.precomputed_det_dir = precomputed_det_tokens
|
| 115 |
self.require_precomputed = require_precomputed_det_tokens
|
| 116 |
+
self.precomputed_map_dir = precomputed_map_tokens
|
| 117 |
self._precomputed_path_map: Optional[Dict[str, str]] = None
|
| 118 |
if self.precomputed_det_dir and os.path.isdir(self.precomputed_det_dir):
|
| 119 |
self._precomputed_path_map = {}
|
|
|
|
| 122 |
if fname.endswith(".pt"):
|
| 123 |
self._precomputed_path_map[fname[:-3]] = os.path.join(root_dir, fname)
|
| 124 |
print(f"Precomputed det tokens index: {len(self._precomputed_path_map)} files")
|
| 125 |
+
self._precomputed_map_path_map: Optional[Dict[str, str]] = None
|
| 126 |
+
if self.precomputed_map_dir and os.path.isdir(self.precomputed_map_dir):
|
| 127 |
+
self._precomputed_map_path_map = {}
|
| 128 |
+
for root_dir, _, files in os.walk(self.precomputed_map_dir):
|
| 129 |
+
for fname in files:
|
| 130 |
+
if fname.endswith(".pt"):
|
| 131 |
+
self._precomputed_map_path_map[fname[:-3]] = os.path.join(root_dir, fname)
|
| 132 |
+
print(f"Precomputed map tokens index: {len(self._precomputed_map_path_map)} files")
|
| 133 |
self.tokenizer = tokenizer
|
| 134 |
self.max_length = max_length
|
| 135 |
self.is_training = is_training
|
|
|
|
| 416 |
result["precomputed_det"] = pt["detection"]
|
| 417 |
result["precomputed_det_ref"] = pt["detection_ref_points"]
|
| 418 |
|
| 419 |
+
if self.precomputed_map_dir:
|
| 420 |
+
mpt = self._load_precomputed_map(item)
|
| 421 |
+
if mpt is not None:
|
| 422 |
+
result["precomputed_map"] = mpt
|
| 423 |
+
|
| 424 |
if os.getenv("ATLAS_AUDIT", "0") not in ("", "0", "false", "False"):
|
| 425 |
max_samples = int(os.getenv("ATLAS_AUDIT_MAX_SAMPLES", "1"))
|
| 426 |
if idx < max_samples:
|
|
|
|
| 651 |
except Exception:
|
| 652 |
return None
|
| 653 |
|
| 654 |
+
def _load_precomputed_map(self, item: Dict) -> Optional[Dict]:
    """Load the precomputed map-token payload that belongs to ``item``.

    The ``.pt`` file is looked up in ``self._precomputed_map_path_map``
    (file stem -> absolute path) using ``str(item["id"])`` as the key.

    Args:
        item: Sample record; its ``"id"`` entry selects the file. A missing
            ``"id"`` is treated as the empty string.

    Returns:
        The dict stored in the matching ``.pt`` file, or ``None`` when
        precomputed map tokens are disabled, no file is indexed for this
        sample, the file cannot be loaded, or its payload is not a dict.
    """
    if not self.precomputed_map_dir or self._precomputed_map_path_map is None:
        return None

    pt_path = self._precomputed_map_path_map.get(str(item.get("id", "")))
    if pt_path is None:
        return None

    # Function-scope import so the method does not depend on a module-level
    # logger being defined elsewhere in the file.
    import logging

    logger = logging.getLogger(__name__)

    try:
        payload = torch.load(pt_path, map_location="cpu")
    except Exception:
        # Loading stays best-effort (a broken file must not abort training),
        # but the failure is reported instead of being silently swallowed so
        # that missing map tokens can be traced back to the offending file.
        logger.warning(
            "Failed to load precomputed map tokens from %s",
            pt_path,
            exc_info=True,
        )
        return None

    # Callers treat the payload as a mapping (the collate function iterates
    # over its keys), so anything else is rejected here rather than failing
    # later during batching.
    if not isinstance(payload, dict):
        logger.warning(
            "Ignoring precomputed map tokens in %s: expected a dict, got %s",
            pt_path,
            type(payload).__name__,
        )
        return None

    return payload
|
| 665 |
+
|
| 666 |
def _get_scene_id(self, item: Dict) -> str:
|
| 667 |
if "segment_id" in item and item["segment_id"]:
|
| 668 |
return str(item["segment_id"])
|
|
|
|
| 1129 |
result["precomputed_det"] = torch.stack([item["precomputed_det"] for item in batch])
|
| 1130 |
result["precomputed_det_ref"] = torch.stack([item["precomputed_det_ref"] for item in batch])
|
| 1131 |
|
| 1132 |
+
if all("precomputed_map" in item for item in batch):
|
| 1133 |
+
map_keys = list(batch[0]["precomputed_map"].keys())
|
| 1134 |
+
result["precomputed_map"] = {
|
| 1135 |
+
k: torch.stack([item["precomputed_map"][k] for item in batch])
|
| 1136 |
+
for k in map_keys
|
| 1137 |
+
}
|
| 1138 |
+
|
| 1139 |
audit_keys = [
|
| 1140 |
"audit_prompt_len",
|
| 1141 |
"audit_answer_len",
|