guoyb0 commited on
Commit
81e8304
·
verified ·
1 Parent(s): f6b2a4d

Update src/dataset/atlas_dataset.py: precomputed map tokens support + NCCL timeout

Browse files
Files changed (1) hide show
  1. src/dataset/atlas_dataset.py +34 -0
src/dataset/atlas_dataset.py CHANGED
@@ -103,6 +103,7 @@ class AtlasDataset(Dataset):
103
  image_path_remap: Optional[str] = None,
104
  precomputed_det_tokens: Optional[str] = None,
105
  require_precomputed_det_tokens: bool = False,
 
106
  ):
107
  self.json_file = json_file
108
  self.image_root = image_root
@@ -112,6 +113,7 @@ class AtlasDataset(Dataset):
112
  self.image_path_remap = (old, new)
113
  self.precomputed_det_dir = precomputed_det_tokens
114
  self.require_precomputed = require_precomputed_det_tokens
 
115
  self._precomputed_path_map: Optional[Dict[str, str]] = None
116
  if self.precomputed_det_dir and os.path.isdir(self.precomputed_det_dir):
117
  self._precomputed_path_map = {}
@@ -120,6 +122,14 @@ class AtlasDataset(Dataset):
120
  if fname.endswith(".pt"):
121
  self._precomputed_path_map[fname[:-3]] = os.path.join(root_dir, fname)
122
  print(f"Precomputed det tokens index: {len(self._precomputed_path_map)} files")
 
 
 
 
 
 
 
 
123
  self.tokenizer = tokenizer
124
  self.max_length = max_length
125
  self.is_training = is_training
@@ -406,6 +416,11 @@ class AtlasDataset(Dataset):
406
  result["precomputed_det"] = pt["detection"]
407
  result["precomputed_det_ref"] = pt["detection_ref_points"]
408
 
 
 
 
 
 
409
  if os.getenv("ATLAS_AUDIT", "0") not in ("", "0", "false", "False"):
410
  max_samples = int(os.getenv("ATLAS_AUDIT_MAX_SAMPLES", "1"))
411
  if idx < max_samples:
@@ -636,6 +651,18 @@ class AtlasDataset(Dataset):
636
  except Exception:
637
  return None
638
 
 
 
 
 
 
 
 
 
 
 
 
 
639
  def _get_scene_id(self, item: Dict) -> str:
640
  if "segment_id" in item and item["segment_id"]:
641
  return str(item["segment_id"])
@@ -1102,6 +1129,13 @@ def atlas_collate_fn(
1102
  result["precomputed_det"] = torch.stack([item["precomputed_det"] for item in batch])
1103
  result["precomputed_det_ref"] = torch.stack([item["precomputed_det_ref"] for item in batch])
1104
 
 
 
 
 
 
 
 
1105
  audit_keys = [
1106
  "audit_prompt_len",
1107
  "audit_answer_len",
 
103
  image_path_remap: Optional[str] = None,
104
  precomputed_det_tokens: Optional[str] = None,
105
  require_precomputed_det_tokens: bool = False,
106
+ precomputed_map_tokens: Optional[str] = None,
107
  ):
108
  self.json_file = json_file
109
  self.image_root = image_root
 
113
  self.image_path_remap = (old, new)
114
  self.precomputed_det_dir = precomputed_det_tokens
115
  self.require_precomputed = require_precomputed_det_tokens
116
+ self.precomputed_map_dir = precomputed_map_tokens
117
  self._precomputed_path_map: Optional[Dict[str, str]] = None
118
  if self.precomputed_det_dir and os.path.isdir(self.precomputed_det_dir):
119
  self._precomputed_path_map = {}
 
122
  if fname.endswith(".pt"):
123
  self._precomputed_path_map[fname[:-3]] = os.path.join(root_dir, fname)
124
  print(f"Precomputed det tokens index: {len(self._precomputed_path_map)} files")
125
+ self._precomputed_map_path_map: Optional[Dict[str, str]] = None
126
+ if self.precomputed_map_dir and os.path.isdir(self.precomputed_map_dir):
127
+ self._precomputed_map_path_map = {}
128
+ for root_dir, _, files in os.walk(self.precomputed_map_dir):
129
+ for fname in files:
130
+ if fname.endswith(".pt"):
131
+ self._precomputed_map_path_map[fname[:-3]] = os.path.join(root_dir, fname)
132
+ print(f"Precomputed map tokens index: {len(self._precomputed_map_path_map)} files")
133
  self.tokenizer = tokenizer
134
  self.max_length = max_length
135
  self.is_training = is_training
 
416
  result["precomputed_det"] = pt["detection"]
417
  result["precomputed_det_ref"] = pt["detection_ref_points"]
418
 
419
+ if self.precomputed_map_dir:
420
+ mpt = self._load_precomputed_map(item)
421
+ if mpt is not None:
422
+ result["precomputed_map"] = mpt
423
+
424
  if os.getenv("ATLAS_AUDIT", "0") not in ("", "0", "false", "False"):
425
  max_samples = int(os.getenv("ATLAS_AUDIT_MAX_SAMPLES", "1"))
426
  if idx < max_samples:
 
651
  except Exception:
652
  return None
653
 
654
def _load_precomputed_map(self, item: Dict) -> Optional[Dict]:
    """Return the cached map-token payload for *item*, or ``None``.

    Resolves the sample's ``id`` against the ``.pt`` index built from the
    precomputed-map directory and deserializes the matching file onto CPU.
    Best-effort by design (mirrors the detection-token loader): yields
    ``None`` when the feature is disabled, the id is not indexed, or the
    file fails to load — callers treat a missing payload as "no map tokens".
    """
    index = self._precomputed_map_path_map
    if not self.precomputed_map_dir or index is None:
        return None
    # Samples without an "id" key fall back to "" and simply miss the index.
    path = index.get(str(item.get("id", "")))
    if path is None:
        return None
    try:
        # map_location="cpu" keeps dataset workers off the GPU.
        return torch.load(path, map_location="cpu")
    except Exception:
        # Corrupt / unreadable file: degrade gracefully rather than crash
        # the loader worker (same convention as _load_precomputed).
        return None
666
  def _get_scene_id(self, item: Dict) -> str:
667
  if "segment_id" in item and item["segment_id"]:
668
  return str(item["segment_id"])
 
1129
  result["precomputed_det"] = torch.stack([item["precomputed_det"] for item in batch])
1130
  result["precomputed_det_ref"] = torch.stack([item["precomputed_det_ref"] for item in batch])
1131
 
1132
+ if all("precomputed_map" in item for item in batch):
1133
+ map_keys = list(batch[0]["precomputed_map"].keys())
1134
+ result["precomputed_map"] = {
1135
+ k: torch.stack([item["precomputed_map"][k] for item in batch])
1136
+ for k in map_keys
1137
+ }
1138
+
1139
  audit_keys = [
1140
  "audit_prompt_len",
1141
  "audit_answer_len",