#!/usr/bin/env python3
"""Evaluation driver for the Atlas driving LLM.

Loads an Atlas checkpoint (optionally with LoRA adapters), optional frozen
StreamPETR (3D detection) and TopoMLP (lane/map) encoders, runs greedy
generation over an evaluation dataset scene-by-scene, parses the generated
text per task (detection / lane / planning), and writes aggregated metrics
plus a sample of raw predictions to a JSON file.
"""
import argparse
import os
import sys
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional
from collections import defaultdict
import torch
import numpy as np

# Make the project root importable when the script is run directly.
sys.path.insert(0, str(Path(__file__).resolve().parent))
from src.model.modeling_atlas import AtlasForCausalLM
from src.model.topomlp_adapter import TopoMLPToAtlasMapTokens
from src.model.streampetr_adapter import extract_streampetr_topk_tokens
from src.dataset.atlas_dataset import AtlasDataset, make_atlas_collate_fn, load_tokenizer
from src.eval.metrics import (
    parse_atlas_output,
    parse_planning_output,
    normalize_ground_truths,
    calculate_multi_threshold_detection_f1,
    calculate_lane_detection_metrics,
    calculate_planning_metrics,
)

logger = logging.getLogger("eval_atlas")


def parse_args():
    """Parse command-line arguments for the evaluation run.

    Returns:
        argparse.Namespace with model, encoder, data, generation, and
        output options. Only --checkpoint and --data_json are required.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", required=True)
    p.add_argument("--llm_model", default="lmsys/vicuna-7b-v1.5")
    p.add_argument("--visual_hidden_size", type=int, default=256)
    p.add_argument("--num_det_queries", type=int, default=256)
    p.add_argument("--num_map_queries", type=int, default=256)
    # Frozen encoder configs/checkpoints; either pair may be omitted, in
    # which case that encoder is skipped (see load_frozen_encoder).
    p.add_argument("--streampetr_config", default=None)
    p.add_argument("--streampetr_ckpt", default=None)
    p.add_argument("--topomlp_config", default=None)
    p.add_argument("--topomlp_ckpt", default=None)
    p.add_argument("--data_json", required=True)
    p.add_argument("--data_root", default="/mnt/data/nuscenes")
    p.add_argument("--max_length", type=int, default=4096)
    p.add_argument("--max_new_tokens", type=int, default=2700)
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--use_lora", action="store_true")
    p.add_argument("--lora_r", type=int, default=64)
    p.add_argument("--lora_alpha", type=int, default=64)
    p.add_argument("--load_in_4bit", action="store_true")
    p.add_argument("--output_json", default=None)
    # 0 means "no cap" (see the sample_count checks in main()).
    p.add_argument("--max_samples", type=int, default=0)
    p.add_argument("--fp16", action="store_true")
    p.add_argument("--bf16", action="store_true")
    p.add_argument("--no_flash_attn", action="store_true")
    p.add_argument("--image_path_remap", default=None, help="old=new path remap, e.g. /home/guoyuanbo/autodl-tmp/OpenLane-V2=/mnt/OpenLane-V2")
    return p.parse_args()


def infer_task(item: Dict) -> str:
    """Determine the task type for a dataset item.

    Uses an explicit "task" field when present; otherwise keyword-matches
    the first human/user prompt in "conversations". Falls back to
    "detection" when nothing matches.
    """
    task = item.get("task", None)
    if task:
        return str(task)
    conv = item.get("conversations", [])
    prompt = ""
    for turn in conv:
        if turn.get("from") in ("human", "user"):
            prompt = turn.get("value", "").lower()
            break
    if "lane" in prompt:
        return "lane"
    if "waypoint" in prompt or "trajectory" in prompt or "planning" in prompt:
        return "planning"
    if "caption" in prompt or "describe" in prompt:
        return "caption"
    return "detection"


def load_frozen_encoder(config_path, ckpt_path, model_type, device):
    """Build and load a frozen mmdet3d encoder (StreamPETR or TopoMLP).

    Returns None (rather than raising) when config/ckpt is missing, when
    mmcv/mmdet3d is not installed, or when the external project plugin
    cannot be imported — callers treat a None encoder as "disabled".

    Args:
        config_path: mmcv config file path, or None to skip.
        ckpt_path: checkpoint path, or None to skip.
        model_type: "streampetr" or "topomlp"; selects which external
            repository is added to sys.path and which plugin is imported.
        device: torch device the frozen model is moved to.
    """
    if config_path is None or ckpt_path is None:
        return None
    try:
        from mmcv import Config
        from mmdet3d.models import build_model
        from mmcv.runner import load_checkpoint
    except ImportError:
        logger.warning("mmcv/mmdet3d not available, skipping %s", model_type)
        return None
    project_root = Path(__file__).resolve().parent
    if model_type == "streampetr":
        sp_root = str(project_root / "external" / "StreamPETR")
        if sp_root not in sys.path:
            sys.path.insert(0, sp_root)
        try:
            # Importing the plugin registers StreamPETR modules with mmcv.
            import projects.mmdet3d_plugin  # noqa: F401
        except ImportError:
            return None
    elif model_type == "topomlp":
        tp_root = str(project_root / "external" / "TopoMLP_Repo")
        if tp_root not in sys.path:
            sys.path.insert(0, tp_root)
        try:
            os.environ["ATLAS_TOPOMLP_MODELS_ONLY"] = "1"
            # Allow duplicate registry entries (shared components with StreamPETR):
            # temporarily force-register everything TopoMLP's plugin declares,
            # then restore the original registry behavior.
            from mmcv.utils import registry as _reg
            _orig = _reg.Registry._register_module

            def _tolerant_register(self, module, module_name=None, force=False):
                return _orig(self, module, module_name=module_name, force=True)

            _reg.Registry._register_module = _tolerant_register
            import projects.topomlp  # noqa: F401
            _reg.Registry._register_module = _orig
        except ImportError:
            return None
    cfg = Config.fromfile(config_path)
    model = build_model(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, ckpt_path, map_location="cpu")
    model.eval()
    model.to(device)
    # Freeze: evaluation-only encoder, no gradients needed.
    for param in model.parameters():
        param.requires_grad_(False)
    return model


def _run_streampetr_forward(model, imgs, img_metas, batch, device, prev_exists=None):
    """Run one StreamPETR forward pass, synthesizing any missing inputs.

    Builds the `data` dict StreamPETR expects (intrinsics, lidar2img,
    ego poses, timestamps), substituting identity matrices / zeros when a
    field is absent from the collated batch, then runs the location ->
    ROI head -> bbox head pipeline.

    Args:
        imgs: image tensor; first two dims are (batch, num_cameras).
        prev_exists: per-sample flag for temporal memory continuity;
            defaults to zeros (treat every sample as a new scene).

    Returns:
        The bbox-head outputs (`outs`); callers read the head's cached
        state afterwards via extract_streampetr_topk_tokens.
    """
    B, N = imgs.shape[:2]
    img_feats = model.extract_img_feat(imgs, 1)
    data = {
        "img": imgs,
        "img_feats": img_feats,
        "prev_exists": prev_exists if prev_exists is not None else imgs.new_zeros(B),
    }
    if "intrinsics_det" in batch:
        # Embed 3x3 intrinsics into homogeneous 4x4 matrices.
        K3 = batch["intrinsics_det"].to(device)
        K4 = torch.zeros(B, N, 4, 4, device=device, dtype=K3.dtype)
        K4[:, :, :3, :3] = K3
        K4[:, :, 3, 3] = 1.0
        data["intrinsics"] = K4
    else:
        data["intrinsics"] = torch.eye(4, device=device).unsqueeze(0).unsqueeze(0).expand(B, N, -1, -1).contiguous()
    if "lidar2img_det" in batch:
        data["lidar2img"] = batch["lidar2img_det"].to(device)
    else:
        data["lidar2img"] = torch.eye(4, device=device).unsqueeze(0).unsqueeze(0).expand(B, N, -1, -1).contiguous()
    if "ego_pose" in batch and batch["ego_pose"] is not None:
        data["ego_pose"] = batch["ego_pose"].to(device)
    else:
        data["ego_pose"] = torch.eye(4, device=device).unsqueeze(0).expand(B, -1, -1).contiguous()
    if "ego_pose_inv" in batch and batch["ego_pose_inv"] is not None:
        data["ego_pose_inv"] = batch["ego_pose_inv"].to(device)
    else:
        data["ego_pose_inv"] = torch.inverse(data["ego_pose"])
    if "timestamp" in batch and batch["timestamp"] is not None:
        data["timestamp"] = batch["timestamp"].to(device)
    else:
        data["timestamp"] = torch.zeros(B, device=device)
    location = model.prepare_location(img_metas, **data)
    outs_roi = model.forward_roi_head(location, **data)
    topk_indexes = outs_roi["topk_indexes"]
    outs = model.pts_bbox_head(location, img_metas, topk_indexes, **data)
    return outs


def extract_visual_tokens(
    streampetr_model,
    topomlp_model,
    topomlp_adapter,
    batch,
    device,
    num_det_queries,
    visual_hidden_size,
    prev_exists=None,
    query_token_id=None,
):
    """Produce the visual-token dict consumed by Atlas generation.

    Always fills "detection"/"detection_ref_points" (zeros when no
    StreamPETR model is available). Map tokens are computed only when the
    prompt contains more query-token slots than detection queries
    (i.e. the sample also reserves slots for map tokens) AND both the
    TopoMLP model and its adapter are available.

    Args:
        query_token_id: token id used to count query slots in input_ids;
            when None, map-token demand is never detected.
    """
    B = batch["pixel_values_det"].shape[0]
    N = batch["pixel_values_det"].shape[1]
    vis: Dict[str, torch.Tensor] = {}
    needs_map = False
    if query_token_id is not None and "input_ids" in batch:
        # Max query-token count over the batch; more slots than det
        # queries implies the prompt also expects map tokens.
        n_queries = int((batch["input_ids"] == query_token_id).sum(dim=-1).max().item())
        needs_map = n_queries > num_det_queries
    if streampetr_model is not None:
        imgs_det = batch["pixel_values_det"].to(device)
        # NOTE(review): image shape is hard-coded to 800x1600 here and in
        # the map branch below — confirm it matches the dataset transforms.
        fH, fW = 800, 1600
        scene_ids = batch.get("scene_id", ["__atlas_eval__"] * B)
        img_metas = [{
            "pad_shape": [(fH, fW, 3)] * N,
            "img_shape": [(fH, fW, 3)] * N,
            "scene_token": scene_ids[b] if b < len(scene_ids) else "__atlas_eval__",
        } for b in range(B)]
        if "lidar2img_det" in batch:
            for b in range(B):
                img_metas[b]["lidar2img"] = batch["lidar2img_det"][b].cpu().numpy()
        with torch.no_grad():
            # Forward pass populates the bbox head's internal state;
            # the return value is not used directly here.
            _run_streampetr_forward(streampetr_model, imgs_det, img_metas, batch, device, prev_exists=prev_exists)
        ego_pose_for_ref = batch.get("ego_pose")
        if ego_pose_for_ref is not None:
            ego_pose_for_ref = ego_pose_for_ref.to(device)
        det_out = extract_streampetr_topk_tokens(
            streampetr_model.pts_bbox_head,
            topk=num_det_queries,
            ego_pose=ego_pose_for_ref,
        )
        vis["detection"] = det_out["detection"]
        vis["detection_ref_points"] = det_out["detection_ref_points"]
    else:
        # No detector: feed zero tokens/reference points of the expected shape.
        vis["detection"] = torch.zeros(B, num_det_queries, visual_hidden_size, device=device)
        vis["detection_ref_points"] = torch.zeros(B, num_det_queries, 3, device=device)
    if needs_map and topomlp_model is not None and topomlp_adapter is not None:
        imgs_map = batch["pixel_values_map"].to(device)
        img_metas = []
        for b in range(B):
            meta = {"img_shape": [(800, 1600, 3)] * N, "pad_shape": [(800, 1600, 3)] * N}
            meta["scale_factor"] = 1.0
            meta["te_yolov8"] = None
            if "lidar2img_map" in batch:
                meta["lidar2img"] = batch["lidar2img_map"][b].cpu().numpy()
            img_metas.append(meta)
        with torch.no_grad():
            outs = topomlp_model.simple_forward(imgs_map, img_metas)
            map_out = topomlp_adapter(outs)
        vis["map"] = map_out["map"]
        vis["map_ref_points"] = map_out["map_ref_points"]
    return vis


def parse_gt_from_item(item: Dict, task: str) -> Dict:
    """Extract ground truth for a sample, keyed by task.

    detection -> {"detections": [...normalized {category, world_coords}...]}
    lane      -> {"lanes": parsed from the first gpt/assistant answer}
    planning  -> {"waypoints": ..., "gt_boxes": ...}
    Other tasks return an empty dict.
    """
    gt = {}
    if task == "detection":
        annotations = item.get("gt_boxes_3d", item.get("annotations", []))
        gt_dets = []
        for ann in annotations:
            if isinstance(ann, dict):
                cat = ann.get("category_name", ann.get("category", "unknown"))
                # Accept either a "box" or a "translation" field; only the
                # first three coordinates (center position) are kept.
                if "box" in ann:
                    coords = ann["box"][:3]
                elif "translation" in ann:
                    coords = ann["translation"][:3]
                else:
                    continue
                gt_dets.append({
                    "category": cat,
                    "world_coords": list(coords),
                })
        gt["detections"] = normalize_ground_truths(gt_dets)
    elif task == "lane":
        # Lane GT is recovered by parsing the reference assistant answer
        # with the same parser used on model output.
        conv = item.get("conversations", [])
        answer = ""
        for turn in conv:
            if turn.get("from") in ("gpt", "assistant"):
                answer = turn.get("value", "")
                break
        gt["lanes"] = parse_atlas_output(answer)
    elif task == "planning":
        ego = item.get("ego_motion", {})
        gt["waypoints"] = ego.get("waypoints", [])
        gt["gt_boxes"] = item.get("gt_boxes_3d", [])
    return gt


def main():
    """Run the full evaluation pipeline: load models, generate, score, save."""
    args = parse_args()
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = load_tokenizer(args.llm_model)
    # NOTE(review): the special token below is the empty string — this looks
    # like a placeholder (e.g. "<query>") whose angle-bracketed name was lost;
    # confirm against the training script. convert_tokens_to_ids("") is also
    # used for set_query_token_id and in the generation loop below.
    if "" not in tokenizer.get_vocab():
        tokenizer.add_tokens([""])
    # Dtype priority: bf16 > fp16 > fp32.
    dtype = torch.float32
    if args.bf16:
        dtype = torch.bfloat16
    elif args.fp16:
        dtype = torch.float16
    # 4-bit loading requires an accelerate device_map; otherwise place manually.
    dm = "auto" if args.load_in_4bit else None
    _use_fa = not getattr(args, 'no_flash_attn', False)
    atlas = AtlasForCausalLM(
        llm_model_name=args.llm_model,
        visual_hidden_size=args.visual_hidden_size,
        num_queries=args.num_det_queries,
        num_map_queries=args.num_map_queries,
        load_in_4bit=args.load_in_4bit,
        use_flash_attention=_use_fa,
        device_map=dm,
        torch_dtype=dtype,
        use_lora=args.use_lora,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
    )
    atlas.resize_token_embeddings(len(tokenizer))
    atlas.set_query_token_id(tokenizer.convert_tokens_to_ids(""))
    if dm is None:
        atlas = atlas.to(device)
    topomlp_adapter = None
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    if "atlas_state_dict" not in ckpt:
        raise RuntimeError(f"Checkpoint missing 'atlas_state_dict'. Keys: {list(ckpt.keys())}")
    # NOTE(review): this guard is always true after the raise above.
    if "atlas_state_dict" in ckpt:
        has_lora_keys = any("lora_" in k for k in ckpt["atlas_state_dict"])
        if has_lora_keys and not args.use_lora:
            logger.warning(
                "Checkpoint contains LoRA weights but --use_lora was not set. "
                "Auto-enabling LoRA to prevent silent degradation."
            )
            args.use_lora = True
            # Rebuild model with LoRA so the LoRA weights can be loaded.
            atlas = AtlasForCausalLM(
                llm_model_name=args.llm_model,
                visual_hidden_size=args.visual_hidden_size,
                num_queries=args.num_det_queries,
                num_map_queries=args.num_map_queries,
                load_in_4bit=args.load_in_4bit,
                use_flash_attention=_use_fa,
                device_map=dm,
                torch_dtype=dtype,
                use_lora=True,
                lora_r=args.lora_r,
                lora_alpha=args.lora_alpha,
            )
            atlas.resize_token_embeddings(len(tokenizer))
            atlas.set_query_token_id(tokenizer.convert_tokens_to_ids(""))
            if dm is None:
                atlas = atlas.to(device)
        # strict=False: tolerate missing keys (e.g. frozen base weights).
        missing, unexpected = atlas.load_state_dict(ckpt["atlas_state_dict"], strict=False)
        if unexpected:
            logger.warning("Unexpected keys in checkpoint (possibly ignored): %s", unexpected[:5])
        logger.info("Loaded Atlas weights from %s", args.checkpoint)
    if "adapter_state_dict" in ckpt:
        # Default BEV range used when the TopoMLP config cannot supply one.
        _tp_bev_range = (-51.2, -25.6, -8.0, 51.2, 25.6, 4.0)
        if args.topomlp_config:
            try:
                from mmcv import Config as _Cfg
                _tp_cfg = _Cfg.fromfile(args.topomlp_config)
                if hasattr(_tp_cfg, "point_cloud_range"):
                    _tp_bev_range = tuple(float(v) for v in _tp_cfg.point_cloud_range)
                    logger.info("TopoMLP bev_range from config: %s", _tp_bev_range)
            except Exception as e:
                # Best-effort: fall back to the default range above.
                logger.warning("Failed to read point_cloud_range from TopoMLP config: %s. Using default: %s", e, _tp_bev_range)
        topomlp_adapter = TopoMLPToAtlasMapTokens(
            num_map_tokens=args.num_map_queries,
            hidden_size=args.visual_hidden_size,
            bev_range=_tp_bev_range,
        ).to(device)
        topomlp_adapter.load_state_dict(ckpt["adapter_state_dict"])
        topomlp_adapter.eval()
    atlas.eval()
    # Merge LoRA into the base weights for faster inference; best-effort —
    # on failure generation still works with LoRA modules active.
    if args.use_lora or any("lora_" in k for k in (ckpt.get("atlas_state_dict", {}))):
        try:
            logger.info("Merging LoRA weights into base model for faster inference...")
            atlas.llm = atlas.llm.merge_and_unload()
            logger.info("LoRA merge complete.")
        except Exception as e:
            logger.warning("LoRA merge failed (%s), continuing with LoRA active.", e)
    streampetr_model = load_frozen_encoder(
        args.streampetr_config, args.streampetr_ckpt, "streampetr", device,
    )
    topomlp_model = load_frozen_encoder(
        args.topomlp_config, args.topomlp_ckpt, "topomlp", device,
    )
    dataset = AtlasDataset(
        json_file=args.data_json,
        image_root=args.data_root,
        tokenizer=tokenizer,
        max_length=args.max_length,
        is_training=False,
        image_path_remap=args.image_path_remap,
    )
    # Samples must be iterated scene-by-scene so StreamPETR's temporal
    # memory (prev_exists / reset_memory below) stays coherent.
    from src.dataset.scene_sampler import SceneSequentialSampler
    scene_groups = dataset.get_scene_groups()
    sampler = SceneSequentialSampler(scene_groups)
    collate_fn = make_atlas_collate_fn(tokenizer.pad_token_id)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    # Map sample_id -> raw dataset item for GT lookup during decoding.
    data_by_id = {str(item.get("id", i)): item for i, item in enumerate(dataset.data)}
    task_preds: Dict[str, List] = defaultdict(list)
    task_gts: Dict[str, List] = defaultdict(list)
    all_outputs: List[Dict] = []
    sample_count = 0
    prev_scene_id = None
    logger.info("Starting evaluation on %d samples...", len(dataset))
    for batch_idx, batch in enumerate(dataloader):
        if args.max_samples > 0 and sample_count >= args.max_samples:
            break
        B = batch["input_ids"].shape[0]
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        # Scene continuity is decided from the first sample of the batch;
        # assumes a batch never spans two scenes (sampler is scene-sequential).
        cur_scene = batch["scene_id"][0] if "scene_id" in batch else None
        if cur_scene == prev_scene_id and cur_scene is not None:
            prev_exists = torch.ones(B, device=device)
        else:
            prev_exists = torch.zeros(B, device=device)
            if streampetr_model is not None:
                # New scene: drop the detector's temporal memory.
                streampetr_model.pts_bbox_head.reset_memory()
        prev_scene_id = cur_scene
        visual_features = extract_visual_tokens(
            streampetr_model,
            topomlp_model,
            topomlp_adapter,
            batch,
            device,
            args.num_det_queries,
            args.visual_hidden_size,
            prev_exists=prev_exists,
            query_token_id=tokenizer.convert_tokens_to_ids(""),
        )
        with torch.no_grad():
            generated_ids = atlas.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                visual_features=visual_features,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
            )
        for b in range(B):
            if args.max_samples > 0 and sample_count >= args.max_samples:
                break
            gen_text = tokenizer.decode(generated_ids[b], skip_special_tokens=True)
            sample_id = batch["sample_id"][b] if "sample_id" in batch else str(sample_count)
            item = data_by_id.get(sample_id)
            if item is None:
                logger.warning("sample_id %s not found in data_by_id, skipping", sample_id)
                sample_count += 1
                continue
            task = infer_task(item)
            gt = parse_gt_from_item(item, task)
            record = {
                "sample_id": sample_id,
                "task": task,
                "generated_text": gen_text,
            }
            if task == "detection":
                preds = parse_atlas_output(gen_text)
                det_preds = [p for p in preds if p.get("type") == "detection"]
                gt_dets = gt.get("detections", [])
                task_preds["detection"].append(det_preds)
                task_gts["detection"].append(gt_dets)
                record["num_preds"] = len(det_preds)
                record["num_gt"] = len(gt_dets)
            elif task == "lane":
                preds = parse_atlas_output(gen_text)
                lane_preds = [p for p in preds if p.get("type") == "lane"]
                gt_lanes = gt.get("lanes", [])
                task_preds["lane"].append(lane_preds)
                task_gts["lane"].append(gt_lanes)
                record["num_preds"] = len(lane_preds)
                record["num_gt"] = len(gt_lanes)
            elif task == "planning":
                plan_pred = parse_planning_output(gen_text)
                gt_wps = gt.get("waypoints", [])
                gt_boxes = gt.get("gt_boxes", [])
                if plan_pred is None:
                    # Unparseable plan: substitute a stationary trajectory so
                    # the sample still counts (penalized) in planning metrics.
                    plan_pred = {"waypoints": [[0.0, 0.0]] * max(len(gt_wps), 6)}
                    record["planning_parse_failed"] = True
                else:
                    record["planning_parse_failed"] = False
                task_preds["planning"].append(plan_pred)
                task_gts["planning"].append({
                    "waypoints": gt_wps,
                    "gt_boxes": gt_boxes,
                })
                record["has_plan"] = not record["planning_parse_failed"]
            all_outputs.append(record)
            sample_count += 1
        if (batch_idx + 1) % 50 == 0:
            logger.info("Processed %d / %d samples", sample_count, len(dataset))
    logger.info("Evaluation complete. Total samples: %d", sample_count)
    results = {}
    # ---- Detection: micro-averaged P/R/F1 at multiple distance thresholds.
    if task_preds["detection"] and task_gts["detection"]:
        from src.eval.metrics import match_detections
        thresholds = (0.5, 1.0, 2.0, 4.0)
        agg_counts = {t: {"tp": 0, "fp": 0, "fn": 0} for t in thresholds}
        for s_preds, s_gts in zip(task_preds["detection"], task_gts["detection"]):
            for t in thresholds:
                matches, fp_list, fn_list = match_detections(s_preds, s_gts, threshold=t)
                agg_counts[t]["tp"] += len(matches)
                agg_counts[t]["fp"] += len(fp_list)
                agg_counts[t]["fn"] += len(fn_list)
        det_results = {}
        f1_vals = []
        for t in thresholds:
            tp, fp, fn = agg_counts[t]["tp"], agg_counts[t]["fp"], agg_counts[t]["fn"]
            p = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            r = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
            det_results[f"P@{t}m"] = p
            det_results[f"R@{t}m"] = r
            det_results[f"F1@{t}m"] = f1
            f1_vals.append(f1)
        det_results["F1_avg"] = float(np.mean(f1_vals)) if f1_vals else 0.0
        det_results["num_samples"] = len(task_preds["detection"])
        results["detection"] = det_results
        logger.info("Detection results (micro-averaged):")
        for k, v in sorted(results["detection"].items()):
            if isinstance(v, float):
                logger.info(" %s: %.4f", k, v)
    # ---- Lane: micro-averaged P/R/F1 at a single 1.5 threshold.
    if task_preds["lane"] and task_gts["lane"]:
        from src.eval.metrics import match_lanes
        lane_tp, lane_fp, lane_fn = 0, 0, 0
        for pl, gl in zip(task_preds["lane"], task_gts["lane"]):
            matches, fp_list, fn_list = match_lanes(pl, gl, threshold=1.5)
            lane_tp += len(matches)
            lane_fp += len(fp_list)
            lane_fn += len(fn_list)
        lane_p = lane_tp / (lane_tp + lane_fp) if (lane_tp + lane_fp) > 0 else 0.0
        lane_r = lane_tp / (lane_tp + lane_fn) if (lane_tp + lane_fn) > 0 else 0.0
        lane_f1 = 2 * lane_p * lane_r / (lane_p + lane_r) if (lane_p + lane_r) > 0 else 0.0
        results["lane"] = {
            "lane_precision": lane_p,
            "lane_recall": lane_r,
            "lane_f1": lane_f1,
            "lane_tp": lane_tp,
            "lane_fp": lane_fp,
            "lane_fn": lane_fn,
        }
        logger.info("Lane results (micro-averaged):")
        for k, v in sorted(results["lane"].items()):
            if isinstance(v, float):
                logger.info(" %s: %.4f", k, v)
    # ---- Planning: delegated to calculate_planning_metrics, augmented with
    # parse-failure accounting.
    if task_preds["planning"] and task_gts["planning"]:
        results["planning"] = calculate_planning_metrics(
            task_preds["planning"],
            task_gts["planning"],
        )
        n_plan_total = len(task_preds["planning"])
        n_plan_failed = sum(
            1 for r in all_outputs if r.get("planning_parse_failed", False)
        )
        results["planning"]["num_samples"] = n_plan_total
        results["planning"]["parse_fail_count"] = n_plan_failed
        results["planning"]["parse_fail_rate"] = (
            n_plan_failed / max(n_plan_total, 1)
        )
        logger.info("Planning results:")
        for k, v in sorted(results["planning"].items()):
            if isinstance(v, float):
                logger.info(" %s: %.4f", k, v)
            else:
                logger.info(" %s: %s", k, v)
    # Default output path: next to the checkpoint.
    output_path = args.output_json
    if output_path is None:
        ckpt_dir = Path(args.checkpoint).parent
        output_path = str(ckpt_dir / "eval_results.json")
    with open(output_path, "w") as f:
        # Only the first 100 raw predictions are persisted to keep the file small.
        json.dump({
            "metrics": results,
            "num_samples": sample_count,
            "args": vars(args),
            "predictions": all_outputs[:100],
        }, f, indent=2, ensure_ascii=False)
    logger.info("Results saved to %s", output_path)


if __name__ == "__main__":
    main()