Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| import pyvista as pv | |
| import json | |
| import glob | |
class Data_loader(Dataset):
    """Dataset of surface meshes read from ``.vtp`` files, one mesh per item.

    Every entry ``idx`` of ``split`` is loaded from
    ``<cfg.data_dir>/<idx>/<idx>.vtp`` when the dataset is constructed, so
    all meshes are held in memory for the lifetime of the object.

    Items are dictionaries:

    * ``'train'`` / ``'val'``: ``cfg.num_points`` points drawn without
      replacement -> ``input_pos``, ``output_feat``, ``data_id``.
    * ``'test'``: every point of the mesh in a shuffled order, plus
      ``physical_coordinates`` (the untransformed ``mesh.points`` rows in
      the same order).
    """

    _MODES = ('train', 'val', 'test')

    def __init__(self, cfg, split, epoch_seed=None, mode='train'):
        """
        cfg: config object; this class reads data_dir, num_points, json_file,
            normalization, input_normalization, pos_embed_sincos and
            (optionally) diff_input_velocity from it
        split: iterable of case identifiers; each one names both the folder
            and the .vtp file inside it
        epoch_seed: base seed for point sampling; see _make_rng for what
            happens when it is left as None
        mode: 'train', 'val', or 'test'

        Raises FileNotFoundError when a case folder or its .vtp file is missing.
        """
        self.data_dir = cfg.data_dir
        self.split = split
        self.num_points = cfg.num_points
        self.epoch_seed = epoch_seed
        self.mode = mode
        self.cfg = cfg
        self.meshes = []
        self.mesh_names = []
        for idx in split:
            folder = os.path.join(self.data_dir, f"{idx}")
            if not os.path.exists(folder):
                raise FileNotFoundError(f"No folder matching '{idx}' found in {self.data_dir}")
            # The file name is fully determined by idx, so a plain existence
            # check is enough (no wildcard matching is involved).
            vtp_file = os.path.join(folder, f"{idx}.vtp")
            if not os.path.isfile(vtp_file):
                raise FileNotFoundError(f"No file matching '{idx}.vtp' found in {folder}")
            self.meshes.append(pv.read(vtp_file))
            self.mesh_names.append(os.path.splitext(os.path.basename(vtp_file))[0])
        # Not used inside this class; kept because external code may read
        # them (set_epoch resets both).
        self.val_indices = None
        self.val_chunk_ptr = 0
        # Normalisation statistics: this class reads
        # ["scalars"]["pressure"]["mean"/"std"] and ["mesh_stats"]["min"/"max"].
        with open(cfg.json_file, "r", encoding="utf-8") as f:
            self.json_data = json.load(f)

    def set_epoch(self, epoch):
        """Use ``epoch`` as the base sampling seed and reset the validation bookkeeping."""
        self.epoch_seed = epoch
        self.val_indices = None
        self.val_chunk_ptr = 0

    def __len__(self):
        """Return the number of loaded meshes (one item per mesh in every mode).

        Raises ValueError when ``self.mode`` is not a known mode.
        """
        if self.mode not in self._MODES:
            raise ValueError(f"Unknown mode: {self.mode}")
        return len(self.meshes)

    def __getitem__(self, idx):
        """Build the sample dictionary for mesh number ``idx``.

        Raises ValueError for an unknown mode. ``rng.choice`` itself raises
        ValueError when the mesh has fewer than ``num_points`` points.
        """
        if self.mode not in self._MODES:
            raise ValueError(f"Unknown mode: {self.mode}")

        mesh = self.meshes[idx]
        n_pts = mesh.points.shape[0]
        rng = self._make_rng(idx)
        if self.mode == 'test':
            # Keep every point, only shuffle the order.
            indices = rng.permutation(n_pts)
        else:
            indices = rng.choice(n_pts, self.num_points, replace=False)

        target = self._build_target(mesh, indices)
        # Inputs are normalised on the full mesh first (the axis shifts use
        # whole-mesh extrema) and only then restricted to the chosen points.
        pos = self._build_inputs(mesh)[indices]

        sample = {"input_pos": pos, "output_feat": target, "data_id": self.mesh_names[idx]}
        if self.mode == 'test':
            sample["physical_coordinates"] = mesh.points[indices]
        return sample

    def _make_rng(self, idx):
        """Return the random generator used to pick points for item ``idx``.

        With an epoch seed the generator is seeded with ``epoch_seed + idx``.
        Without one (``epoch_seed is None``) training draws fresh entropy on
        every call, while 'val' and 'test' fall back to ``idx`` alone so that
        evaluation stays reproducible.
        """
        if self.epoch_seed is not None:
            return np.random.default_rng(self.epoch_seed + idx)
        if self.mode == 'train':
            return np.random.default_rng()
        return np.random.default_rng(idx)

    def _build_target(self, mesh, indices):
        """Return the selected pressure values as a float32 column tensor.

        The values are standardised with the JSON mean/std when
        ``cfg.normalization == "std_norm"``; for any other setting the raw
        pressure is returned unchanged.
        """
        pressure = torch.tensor(mesh["pressure"][indices], dtype=torch.float32).unsqueeze(-1)
        if self.cfg.normalization == "std_norm":
            stats = self.json_data["scalars"]["pressure"]
            return (pressure - stats["mean"]) / stats["std"]
        return pressure

    def _build_inputs(self, mesh):
        """Return the model input tensor for every point of ``mesh``.

        Starts from ``mesh.points`` (first three columns are treated as
        x, y, z) and applies the options enabled in ``cfg``.

        Raises ValueError when pos_embed_sincos and diff_input_velocity are
        both enabled.
        """
        pos = torch.tensor(mesh.points, dtype=torch.float32)

        # A config without the attribute is treated as "feature disabled".
        use_velocity = bool(getattr(self.cfg, "diff_input_velocity", False))
        if use_velocity:
            inlet_x_vel = torch.tensor(mesh["inlet_x_velocity"], dtype=torch.float32).unsqueeze(-1)
            pos = torch.cat((pos, inlet_x_vel), dim=1)

        if self.cfg.input_normalization == "shift_axis":
            coords = pos[:, :3].clone()
            # Shift x: set minimum x (front bumper) to 0
            coords[:, 0] = coords[:, 0] - coords[:, 0].min()
            # Shift z: set minimum z (ground) to 0
            coords[:, 2] = coords[:, 2] - coords[:, 2].min()
            # Shift y: center about 0 (left/right symmetry)
            y_center = (coords[:, 1].max() + coords[:, 1].min()) / 2.0
            coords[:, 1] = coords[:, 1] - y_center
            pos[:, :3] = coords

        if self.cfg.pos_embed_sincos:
            if use_velocity:
                raise ValueError("pos_embed_sincos not supported with diff_input_velocity=True")
            input_pos_mins = torch.tensor(self.json_data["mesh_stats"]["min"])
            input_pos_maxs = torch.tensor(self.json_data["mesh_stats"]["max"])
            # Rescale each coordinate to the range [0, 1000] using the dataset-wide bounds.
            pos = 1000 * (pos - input_pos_mins) / (input_pos_maxs - input_pos_mins)
            assert torch.all(pos >= 0), "rescaled position below 0"
            assert torch.all(pos <= 1000), "rescaled position above 1000"

        return pos
def get_dataloaders(cfg):
    """Create the train, validation and test DataLoaders described by ``cfg``.

    ``cfg.splits_file`` is used as a directory holding one case identifier
    per line in ``train.txt`` and ``test.txt``. If a ``val.txt`` file is also
    present there it defines the validation split; otherwise validation
    reuses the test identifiers.

    Returns:
        (train_loader, val_loader, test_loader). All use ``batch_size=1``;
        only the training loader shuffles.
    """

    def read_split(filename):
        # One identifier per line; blank lines are ignored.
        with open(os.path.join(cfg.splits_file, filename), encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]

    train_split = read_split("train.txt")
    test_split = read_split("test.txt")
    if os.path.isfile(os.path.join(cfg.splits_file, "val.txt")):
        val_split = read_split("val.txt")
    else:
        # No dedicated validation list: validate on the test cases.
        val_split = list(test_split)

    print("Indices in test_split:", test_split[:5])  # Print first 5 indices for verification

    train_dataset = Data_loader(cfg, train_split, mode='train')
    val_dataset = Data_loader(cfg, val_split, mode='val')
    test_dataset = Data_loader(cfg, test_split, mode='test')

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    return train_loader, val_loader, test_loader