# AnsysLPFMTrame-App / dataset_loader.py
# udbhav
# Recreate Trame_app branch with clean history
# 67fb03c
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pyvista as pv
import json
import glob
class Data_loader(Dataset):
    """Map-style dataset yielding one item per ``.vtp`` surface mesh.

    In ``'train'``/``'val'`` mode an item holds ``num_points`` points drawn
    without replacement (all points, shuffled, if the mesh is smaller); in
    ``'test'`` mode it holds every mesh point in a random order. The per-item
    RNG is seeded with ``epoch_seed + idx``, so a given (epoch, item) pair is
    reproducible.
    """

    _MODES = ('train', 'val', 'test')

    def __init__(self, cfg, split, epoch_seed=None, mode='train'):
        """
        cfg: config object. ``data_dir``, ``num_points`` and ``json_file`` are
            read here; ``normalization``, ``input_normalization``,
            ``pos_embed_sincos`` and the optional ``diff_input_velocity`` are
            read per item.
        split: iterable of sample ids (ints or strings, e.g. lines of a split
            file). Sample ``idx`` is loaded from ``<data_dir>/<idx>/<idx>.vtp``.
        epoch_seed: base seed for point sampling/shuffling; ``None`` acts as 0.
            Typically updated through ``set_epoch``.
        mode: 'train', 'val', or 'test'

        Raises FileNotFoundError if a sample folder or its ``.vtp`` file is
        missing.
        """
        self.data_dir = cfg.data_dir
        self.split = split
        self.num_points = cfg.num_points
        self.epoch_seed = epoch_seed
        self.mode = mode
        self.cfg = cfg
        self.meshes = []
        self.mesh_names = []
        for idx in split:
            folder = os.path.join(self.data_dir, f"{idx}")
            if not os.path.isdir(folder):
                raise FileNotFoundError(f"No folder matching '{idx}' found in {self.data_dir}")
            # Exact path rather than glob(): ids are literal names, and glob
            # would treat characters such as '[' or '*' in an id as wildcards.
            vtp_file = os.path.join(folder, f"{idx}.vtp")
            if not os.path.isfile(vtp_file):
                raise FileNotFoundError(f"No file matching '{idx}.vtp' found in {folder}")
            self.meshes.append(pv.read(vtp_file))
            self.mesh_names.append(os.path.splitext(os.path.basename(vtp_file))[0])
        # Validation-chunking state; reset by set_epoch but not otherwise read
        # in this class.
        self.val_indices = None
        self.val_chunk_ptr = 0
        # Statistics file: scalars.pressure.{mean,std} normalise the target and
        # mesh_stats.{min,max} scale the inputs when pos_embed_sincos is set.
        with open(cfg.json_file, "r", encoding="utf-8") as f:
            self.json_data = json.load(f)

    def set_epoch(self, epoch):
        """Use *epoch* as the sampling seed and reset the validation-chunk state."""
        self.epoch_seed = epoch
        self.val_indices = None
        self.val_chunk_ptr = 0

    def _check_mode(self):
        """Raise ValueError unless ``self.mode`` is one of the supported modes."""
        if self.mode not in self._MODES:
            raise ValueError(f"Unknown mode: {self.mode}")

    def __len__(self):
        """Return the number of meshes; every mode yields one item per mesh."""
        self._check_mode()
        return len(self.meshes)

    def __getitem__(self, idx):
        """Return the sample for mesh *idx*.

        Keys: ``input_pos`` (float32 tensor, one row per selected point),
        ``output_feat`` (float32 pressure target of shape (k, 1)), ``data_id``
        (mesh name) and, in test mode only, ``physical_coordinates`` (raw mesh
        coordinates of the same points, in the same order).

        Raises ValueError for an unknown mode.
        """
        self._check_mode()
        mesh = self.meshes[idx]
        n_pts = mesh.points.shape[0]
        # A missing seed means set_epoch() was never called (e.g. val/test
        # loaders); use 0 instead of failing on ``None + idx``.
        base_seed = 0 if self.epoch_seed is None else self.epoch_seed
        rng = np.random.default_rng(base_seed + idx)
        if self.mode == 'test':
            indices = rng.permutation(n_pts)
        else:
            # Sampling without replacement cannot draw more points than exist.
            indices = rng.choice(n_pts, min(self.num_points, n_pts), replace=False)
        target = self._target(mesh, indices)
        # Input transforms use whole-mesh statistics (min/max), so they are
        # applied to all points before the sampled rows are selected.
        pos = self._input_positions(mesh)[indices]
        sample = {"input_pos": pos, "output_feat": target, "data_id": self.mesh_names[idx]}
        if self.mode == 'test':
            sample["physical_coordinates"] = mesh.points[indices]
        return sample

    def _target(self, mesh, indices):
        """Return the pressure at *indices* as a (k, 1) float32 tensor.

        Standardised with the JSON mean/std when ``cfg.normalization`` is
        ``"std_norm"``; returned unnormalised for any other setting.
        """
        pressure = torch.tensor(mesh["pressure"][indices], dtype=torch.float32).unsqueeze(-1)
        if self.cfg.normalization == "std_norm":
            stats = self.json_data["scalars"]["pressure"]
            return (pressure - stats["mean"]) / stats["std"]
        return pressure

    def _input_positions(self, mesh):
        """Return the float32 model-input tensor for every point of *mesh*.

        Columns are the point coordinates, plus the per-point
        ``inlet_x_velocity`` array appended as an extra column when
        ``cfg.diff_input_velocity`` is set.

        Raises ValueError if ``pos_embed_sincos`` is combined with
        ``diff_input_velocity``, or if the scaled positions leave [0, 1000].
        """
        pos = torch.tensor(mesh.points, dtype=torch.float32)
        use_velocity = getattr(self.cfg, "diff_input_velocity", False)
        if use_velocity:
            inlet_x_vel = torch.tensor(mesh["inlet_x_velocity"], dtype=torch.float32).unsqueeze(-1)
            pos = torch.cat((pos, inlet_x_vel), dim=1)
        if self.cfg.input_normalization == "shift_axis":
            coords = pos[:, :3].clone()
            # Shift x: set minimum x (front bumper) to 0
            coords[:, 0] = coords[:, 0] - coords[:, 0].min()
            # Shift z: set minimum z (ground) to 0
            coords[:, 2] = coords[:, 2] - coords[:, 2].min()
            # Shift y: center about 0 (left/right symmetry)
            y_center = (coords[:, 1].max() + coords[:, 1].min()) / 2.0
            coords[:, 1] = coords[:, 1] - y_center
            pos[:, :3] = coords
        if self.cfg.pos_embed_sincos:
            if use_velocity:
                raise ValueError("pos_embed_sincos not supported with diff_input_velocity=True")
            input_pos_mins = torch.tensor(self.json_data["mesh_stats"]["min"])
            input_pos_maxs = torch.tensor(self.json_data["mesh_stats"]["max"])
            pos = 1000 * (pos - input_pos_mins) / (input_pos_maxs - input_pos_mins)
            # Explicit check instead of assert so it survives ``python -O``.
            if not (torch.all(pos >= 0) and torch.all(pos <= 1000)):
                raise ValueError(
                    "pos_embed_sincos: scaled positions fall outside [0, 1000]; "
                    "check mesh_stats min/max in the JSON file"
                )
        return pos
def _read_split(path):
    """Return the whitespace-stripped, non-blank lines of the split file at *path*."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def get_dataloaders(cfg):
    """Build the train/val/test DataLoaders described by *cfg*.

    ``cfg.splits_file`` is the directory holding ``train.txt`` and
    ``test.txt`` (and optionally ``val.txt``), one sample id per line. The
    validation loader uses ``val.txt`` when it exists; otherwise it reuses the
    ids from ``test.txt``. The full *cfg* is forwarded to ``Data_loader``.

    Returns (train_loader, val_loader, test_loader), each with batch_size=1;
    only the train loader shuffles.

    Raises FileNotFoundError if ``train.txt`` or ``test.txt`` is missing.
    """
    splits_dir = cfg.splits_file
    train_split = _read_split(os.path.join(splits_dir, "train.txt"))
    test_split = _read_split(os.path.join(splits_dir, "test.txt"))
    # A dedicated validation split is optional; without one, validation keeps
    # running on the test ids.
    val_path = os.path.join(splits_dir, "val.txt")
    val_split = _read_split(val_path) if os.path.isfile(val_path) else list(test_split)
    print("Indices in test_split:", test_split[:5])  # Print first 5 indices for verification
    train_dataset = Data_loader(cfg, train_split, mode='train')
    val_dataset = Data_loader(cfg, val_split, mode='val')
    test_dataset = Data_loader(cfg, test_split, mode='test')
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    return train_loader, val_loader, test_loader