import numpy as np import time, json, os, math import torch import torch.nn as nn import torch.nn.functional as F from tqdm import tqdm import json from torch.utils.tensorboard import SummaryWriter from omegaconf import OmegaConf from lion_pytorch import Lion import glob from utils.vtk_writer import vtk_writer # For full mesh load and then sample in each training iteration from dataset_loader import get_dataloaders from torch.cuda.amp import GradScaler import re import glob import h5py import pyvista as pv from omegaconf import OmegaConf import models def save_checkpoint(model, optimizer, scheduler, epoch, best_val_loss, val_MSE_list, cfg, path, accelerator, log_dir=None): """Save a complete training checkpoint.""" if accelerator.is_main_process: rng_state = torch.get_rng_state() cuda_rng_states = None if torch.cuda.is_available(): cuda_rng_states = [] for i in range(torch.cuda.device_count()): cuda_rng_states.append(torch.cuda.get_rng_state(device=i)) checkpoint = { 'epoch': epoch, 'model_state_dict': accelerator.unwrap_model(model).state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'best_val_loss': best_val_loss, 'val_MSE_list': val_MSE_list, 'cfg': cfg, 'log_dir': log_dir, 'rng_state': rng_state, 'cuda_rng_states': cuda_rng_states, } checkpoint_path = os.path.join(path, f'checkpoint_epoch_{epoch}.pt') torch.save(checkpoint, checkpoint_path) # Also save as latest checkpoint latest_checkpoint_path = os.path.join(path, 'latest_checkpoint.pt') torch.save(checkpoint, latest_checkpoint_path) print(f"Checkpoint saved at epoch {epoch}") def load_checkpoint(path, model, optimizer, scheduler, accelerator): """Load the latest checkpoint and return training state.""" latest_checkpoint_path = os.path.join(path, 'latest_checkpoint.pt') if not os.path.exists(latest_checkpoint_path): print("No checkpoint found, starting from scratch") return None, 0, 1e5, [], None print(f"Loading checkpoint from {latest_checkpoint_path}") checkpoint = torch.load(latest_checkpoint_path, map_location='cpu') # Load model state unwrapped_model = accelerator.unwrap_model(model) unwrapped_model.load_state_dict(checkpoint['model_state_dict']) # Load optimizer and scheduler states optimizer.load_state_dict(checkpoint['optimizer_state_dict']) scheduler.load_state_dict(checkpoint['scheduler_state_dict']) # Restore random states for reproducibility with error handling try: if 'rng_state' in checkpoint and checkpoint['rng_state'] is not None: torch.set_rng_state(checkpoint['rng_state']) except Exception as e: print(f"Warning: Could not restore CPU RNG state: {e}") try: # Handle both old and new checkpoint formats cuda_rng_key = 'cuda_rng_states' if 'cuda_rng_states' in checkpoint else 'cuda_rng_state' if cuda_rng_key in checkpoint and checkpoint[cuda_rng_key] is not None and torch.cuda.is_available(): cuda_rng_states = checkpoint[cuda_rng_key] if isinstance(cuda_rng_states, list) and len(cuda_rng_states) > 0: # Set RNG state for each device for i, state in enumerate(cuda_rng_states): if i < torch.cuda.device_count() and state is not None: torch.cuda.set_rng_state(state, device=i) except Exception as e: print(f"Warning: Could not restore CUDA RNG state: {e}") start_epoch = checkpoint['epoch'] + 1 best_val_loss = checkpoint['best_val_loss'] val_MSE_list = checkpoint['val_MSE_list'] log_dir = checkpoint.get('log_dir', None) print(f"Resumed from epoch {checkpoint['epoch']}, best val loss: {best_val_loss:.6f}") return checkpoint, start_epoch, best_val_loss, val_MSE_list, log_dir def cleanup_old_checkpoints(path, keep_last=3): """Remove old checkpoint files, keeping only the most recent ones.""" checkpoint_pattern = os.path.join(path, '*_epoch_*.pt') checkpoint_files = glob.glob(checkpoint_pattern) if len(checkpoint_files) <= keep_last: return # Sort by modification time and remove oldest checkpoint_files.sort(key=os.path.getmtime) files_to_remove = checkpoint_files[:-keep_last] for file_path in files_to_remove: try: os.remove(file_path) print(f"Removed old checkpoint: {os.path.basename(file_path)}") except OSError: pass def train(model, train_loader, optimizer, scheduler, criterion, cfg, accelerator, scaler): model.train() losses_press = 0.0 for data in train_loader: optimizer.zero_grad() targets = data['output_feat'] if cfg.mixed_precision: with torch.autocast(device_type = accelerator.device.type): out = model(data) total_loss = criterion(out, targets) scaler.scale(total_loss).backward() scaler.unscale_(optimizer) if cfg.max_grad_norm is not None: accelerator.clip_grad_norm_(model.parameters(), cfg.max_grad_norm) scaler.step(optimizer) scaler.update() else: out = model(data) total_loss = criterion(out, targets) accelerator.backward(total_loss) if cfg.max_grad_norm is not None: accelerator.clip_grad_norm_(model.parameters(), cfg.max_grad_norm) optimizer.step() # Only step OneCycleLR every batch if cfg.scheduler == "OneCycleLR": scheduler.step() losses_press += total_loss.item() return losses_press / len(train_loader) @torch.no_grad() def val(model, val_loader, criterion, cfg, accelerator): model.eval() losses_press = 0.0 for data in val_loader: targets = data['output_feat'] out = model(data) # Loss computation in FP32 for maximum stability targets = targets.float() # Ensure FP32 out = out.float() # Ensure FP32 total_loss = criterion(out, targets) losses_press += total_loss.item() return losses_press / len(val_loader) def test_model(model, test_dataloader, criterion, path, cfg, accelerator): """Test the model and calculate metrics.""" ## You reported test models everywhere? Complete evaluation? model.eval() total_mse = 0.0 total_mae = 0.0 total_huber = 0.0 total_rel_l2 = 0.0 total_rel_l1 = 0.0 total_mse_list = [] total_mae_list = [] total_huber_list = [] total_rel_l2_list = [] total_rel_l1_list = [] r_2_squared_list = [] data_id_list = [] total_inference_time = 0.0 num_batches = 0 if cfg.normalization == "std_norm": with open(cfg.json_file, 'r') as f: json_data = json.load(f) pressure_mean = torch.tensor(json_data["scalars"]["pressure"]["mean"], device=accelerator.device) pressure_std = torch.tensor(json_data["scalars"]["pressure"]["std"], device=accelerator.device) # Store outputs and targets on all processes all_outputs = [] all_targets = [] all_physical_coordinates = [] latent_dict = {} with torch.no_grad(): for data in tqdm(test_dataloader, desc="[Testing]", disable=not accelerator.is_local_main_process): start_time = time.time() targets = data['output_feat'] if cfg.chunked_eval: input_pos = data['input_pos'] B, N, C = input_pos.shape chunk_size = cfg.num_points outputs = [] latent_outputs = [] for i in range(0, N, chunk_size): # start with the raw slice chunk = input_pos[:, i:i+chunk_size, :] # (B, n_valid, C) n_valid = chunk.shape[1] # Pad if last chunk is short if n_valid < chunk_size: shape_diff = chunk_size - n_valid # Wrap from the beginning to make a full chunk pad = input_pos[:, :shape_diff, :] # (B, shape_diff, C) chunk = torch.cat([chunk, pad], dim=1) # (B, chunk_size, C) data['input_pos'] = chunk if cfg.save_latent: out_chunk, latent = model(data) # (B, chunk_size, D) else: out_chunk = model(data) # (B, chunk_size, D) # Keep only the valid part that corresponds to real points out_chunk = out_chunk[:, :n_valid, :] # (B, n_valid, D) latent = latent[:, :n_valid, :] else: data['input_pos'] = chunk if cfg.save_latent: out_chunk, latent = model(data) # (B, chunk_size, D) else: out_chunk = model(data) # (B, chunk_size, D) outputs.append(out_chunk) latent_outputs.append(latent) if cfg.save_latent else None outputs = torch.cat(outputs, dim=1) # (B, N, 3) latent_outputs = torch.cat(latent_outputs, dim=1) if cfg.save_latent else None data_id = data['data_id'][0] # Move latent_outputs to CPU and convert to numpy latent_np = latent_outputs.cpu().numpy() if cfg.save_latent else None latent_dict[data_id] = latent_np if cfg.save_latent else None else: if cfg.save_latent: outputs, latent = model(data) else: outputs = model(data) latent_outputs = torch.cat(latent_outputs, dim=1) if cfg.save_latent else None data_id = data['data_id'][0] # Move latent_outputs to CPU and convert to numpy latent_np = latent_outputs.cpu().numpy() if cfg.save_latent else None latent_dict[data_id] = latent_np if cfg.save_latent else None # Metric computations in FP32 for maximum stability targets = targets.float() # Ensure FP32 outputs = outputs.float() # Ensure FP32 if cfg.physical_scale_for_test == True: targets[:,:,0] = targets[:,:,0] * pressure_std + pressure_mean outputs[:,:,0] = outputs[:,:,0] * pressure_std + pressure_mean inference_time = time.time() - start_time total_inference_time += inference_time # Compute all relevant losses and metrics for documentation and analysis criterion_mse = nn.MSELoss() criterion_mae = nn.L1Loss() criterion_huber = nn.HuberLoss(delta=1.0) mse = criterion_mse(outputs, targets) mae = criterion_mae(outputs, targets) huber = criterion_huber(outputs, targets) # Relative L2 error: mean over batch of (L2 norm of error / L2 norm of target) rel_l2 = torch.mean(torch.norm(outputs.squeeze(-1) - targets.squeeze(-1), p=2, dim=-1) / torch.norm(targets.squeeze(-1), p=2, dim=-1)) # Relative L1 error: mean over batch of (L1 norm of error / L1 norm of target) rel_l1 = torch.mean(torch.norm(outputs.squeeze(-1) - targets.squeeze(-1), p=1, dim=-1) / torch.norm(targets.squeeze(-1), p=1, dim=-1)) ss_tot = torch.sum((outputs - torch.mean(targets)) ** 2) ss_res = torch.sum((targets - outputs) ** 2) r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 total_mse_list.append(mse.item()) total_mae_list.append(mae.item()) total_huber_list.append(huber.item()) total_rel_l2_list.append(rel_l2.item()) total_rel_l1_list.append(rel_l1.item()) r_2_squared_list.append(r_squared.item()) data_id_list.append(data['data_id'][0]) # Add data ID to track which test case total_mse += mse total_mae += mae total_huber += huber total_rel_l2 += rel_l2 total_rel_l1 += rel_l1 num_batches += 1 # Store outputs and targets on all processes for later aggregation and R² computation all_outputs.append(outputs.cpu()) all_targets.append(targets.cpu()) all_physical_coordinates.append(data['physical_coordinates'].cpu()) path_vtk = path + "/vtk_files" # Save VTK files for each data(if any data exists) if len(all_outputs) > 0 and len(all_targets) > 0 and len(all_physical_coordinates) > 0: if cfg.physical_scale_for_test == False: targets[:,:,0] = targets[:,:,0] * pressure_std + pressure_mean outputs[:,:,0] = outputs[:,:,0] * pressure_std + pressure_mean try: vtk_writer(outputs, targets, data['physical_coordinates'].cpu(), path_vtk, prefix=data["data_id"][0], config_json_path=cfg.json_file) except Exception as e: print(f"[Warning] Could not save VTK files: {e}") # Clear references to tensors del outputs, targets, mse, mae, huber, rel_l2, rel_l1 # Save latent dictionary as an HDF5 file if cfg.save_latent: if accelerator.is_main_process: latent_file = os.path.join(path, f'latent_{cfg.test_name}.h5') with h5py.File(latent_file, 'w') as f: for k, v in latent_dict.items(): f.create_dataset(str(k), data=v) print(f"Saved latent representations to {latent_file}") metrics_list = { "data_id_list": data_id_list, "total_mse_list": total_mse_list, "total_mae_list": total_mae_list, "total_huber_list": total_huber_list, "total_rel_l2_list": total_rel_l2_list, "total_rel_l1_list": total_rel_l1_list, "r_2_squared_list": r_2_squared_list, } # Save metrics_list as a JSON file for per-batch analysis metrics_list_file = os.path.join(path, 'test_metrics_list.txt') with open(metrics_list_file, 'w') as f: json.dump(metrics_list, f, indent=2) # Create test directory for individual metric files test_dir = os.path.join(path, 'test') os.makedirs(test_dir, exist_ok=True) # Save individual metric files with data IDs as first column individual_metrics = { 'mse': total_mse_list, 'mae': total_mae_list, 'huber': total_huber_list, 'rel_l2': total_rel_l2_list, 'rel_l1': total_rel_l1_list, 'r_squared': r_2_squared_list } for metric_name, metric_values in individual_metrics.items(): metric_file = os.path.join(test_dir, f'{metric_name}.txt') with open(metric_file, 'w') as f: for data_id, value in zip(data_id_list, metric_values): f.write(f"{data_id} {value:.6f}\n") # Convert to tensors for reduction metrics = { "total_mse": total_mse, "total_mae": total_mae, "total_huber": total_huber, "total_rel_l2": total_rel_l2, "total_rel_l1": total_rel_l1, "num_batches": torch.tensor(num_batches, device=accelerator.device), "total_inference_time": torch.tensor(total_inference_time, device=accelerator.device) } # Gather metrics from all processes gathered_metrics = accelerator.gather(metrics) # Only calculate averages if we have data if gathered_metrics["num_batches"].sum().item() > 0: total_batches = gathered_metrics["num_batches"].sum().item() avg_mse = gathered_metrics["total_mse"].sum().item() / total_batches avg_mae = gathered_metrics["total_mae"].sum().item() / total_batches avg_huber = gathered_metrics["total_huber"].sum().item() / total_batches avg_rel_l2 = gathered_metrics["total_rel_l2"].sum().item() / total_batches avg_rel_l1 = gathered_metrics["total_rel_l1"].sum().item() / total_batches total_inference_time = gathered_metrics["total_inference_time"].sum().item() avg_inference_time = total_inference_time / total_batches # Gather all outputs and targets from all processes all_outputs = torch.cat(all_outputs, dim=1) all_targets = torch.cat(all_targets, dim=1) # Gather outputs and targets across processes all_outputs = accelerator.gather(all_outputs.to(accelerator.device)) all_targets = accelerator.gather(all_targets.to(accelerator.device)) # Calculate R² score using complete dataset if accelerator.is_main_process: all_outputs = all_outputs.to(torch.float32).cpu().numpy() all_targets = all_targets.to(torch.float32).cpu().numpy() ss_tot = np.sum((all_targets - np.mean(all_targets)) ** 2) ss_res = np.sum((all_targets - all_outputs) ** 2) r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 print(f"Test MSE: {avg_mse:.6f}, Test MAE: {avg_mae:.6f}, Test Huber: {avg_huber:.6f}, R²: {r_squared:.4f}") print(f"Relative L2 Error: {avg_rel_l2:.6f}, Relative L1 Error: {avg_rel_l1:.6f}") print(f"Average inference time per batch: {avg_inference_time:.4f}s") print(f"Total inference time: {total_inference_time:.2f}s for {total_batches} batches") # Save metrics to a text file metrics_file = os.path.join(path, 'test_metrics.txt') with open(metrics_file, 'w') as f: f.write(f"Test MSE: {avg_mse:.6f}\n") f.write(f"Test MAE: {avg_mae:.6f}\n") f.write(f"Test Huber: {avg_huber:.6f}\n") f.write(f"R2 Score: {r_squared:.6f}\n") f.write(f"Relative L2 Error: {avg_rel_l2:.6f}\n") f.write(f"Relative L1 Error: {avg_rel_l1:.6f}\n") f.write(f"Average inference time per batch: {avg_inference_time:.4f}s\n") f.write(f"Total inference time: {total_inference_time:.2f}s for {total_batches} batches\n") # Save outputs and targets as .npy files for further analysis #np.save(os.path.join(path, 'test_outputs.npy'), all_outputs) #np.save(os.path.join(path, 'test_targets.npy'), all_targets) else: r_squared = 0.0 # Will be overwritten by broadcast else: print("Warning: No data in test_dataloader") avg_mse = avg_mae = avg_huber = avg_rel_l2 = avg_rel_l1 = r_squared = 0.0 # Clear GPU cache after all testing torch.cuda.empty_cache() return avg_mse, avg_mae, avg_huber, avg_rel_l2, avg_rel_l1, r_squared, avg_inference_time def train_main(model, path, cfg, accelerator): train_loader, val_loader, test_loader = get_dataloaders(cfg) if accelerator.is_main_process: print( f"Data loaded: {len(train_loader)} training batches, " f"{len(val_loader)} validation batches, " f"{len(test_loader)} test batches") #Select optimizer if cfg.optimizer.type == 'Adam': optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=1e-4) elif cfg.optimizer.type == 'AdamW': optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=0.05) elif cfg.optimizer.type == 'LION': optimizer = Lion(model.parameters(), lr=cfg.lr, weight_decay=0.05) #Select scheduler if cfg.scheduler == "ReduceLROnPlateau": scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.1, verbose=True) elif cfg.scheduler == "LinearWarmupCosineAnnealingLR": warmup_epochs = int(cfg.epochs * 0.05) # Convert back to epochs # Linear warmup scheduler warmup_scheduler = torch.optim.lr_scheduler.LinearLR( optimizer, start_factor=1e-6, # Start very low (almost zero) end_factor=1.0, # End at base lr total_iters=warmup_epochs ) # Cosine decay scheduler cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=cfg.epochs - warmup_epochs, # Remaining epochs eta_min=1e-6 # End at 1e-6 learning rate ) # Combine schedulers scheduler = torch.optim.lr_scheduler.SequentialLR( optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_epochs] ) else: scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, pct_start=0.05, max_lr=cfg.lr, total_steps = len(train_loader) * cfg.epochs ) if cfg.loss_type == "mse": criterion = nn.MSELoss() elif cfg.loss_type == "mae": criterion = nn.L1Loss() elif cfg.loss_type == "huber": criterion = nn.HuberLoss(delta=1.0) else: raise ValueError(f"Unknown loss_type: {cfg.loss_type}") scaler = GradScaler() model, optimizer, train_loader, val_loader, test_loader, scheduler, scaler = accelerator.prepare( model, optimizer, train_loader, val_loader, test_loader, scheduler, scaler) best_epoch = 0 # Try to load checkpoint before evaluation or training checkpoint, start_epoch, best_val_loss, val_MSE_list, resumed_log_dir = load_checkpoint( path, model, optimizer, scheduler, accelerator) # Before the training loop, after loading checkpoint: # or -1 if you want to indicate "not set" if cfg.eval: # For evaluation, try to load from checkpoint first, then fall back to best_model.pt if (cfg.train_ckpt_load): print("Using model from checkpoint for evaluation") else: # Load the saved state dict and create a fresh model load_path = f'metrics/{cfg.project_name}/{cfg.model}_{cfg.test_name}' # Find all best_model_epoch_*.pt files and get the epoch number from the last one pattern = os.path.join(os.getcwd(), load_path.lstrip('/'), 'best_case', 'best_model_epoch_*.pt') best_model_files = glob.glob(pattern) if not best_model_files: raise FileNotFoundError(f"No best_model_epoch_*.pt files found in {os.path.join(load_path, 'best_case')}") # Extract epoch numbers epoch_numbers = [] for fname in best_model_files: match = re.search(r'best_model_epoch_(\d+)\.pt', os.path.basename(fname)) if match: epoch_numbers.append(int(match.group(1))) if not epoch_numbers: raise ValueError("No epoch numbers found in best_model_epoch_*.pt filenames") last_best_epoch = max(epoch_numbers) state_dict = torch.load(os.path.join(load_path, 'best_case', f'best_model_epoch_{last_best_epoch}.pt')) unwrapped_model = accelerator.unwrap_model(model) unwrapped_model.load_state_dict(state_dict) model = accelerator.prepare(unwrapped_model) path = os.path.join(path, "best_case") # Update path to point to best model directory print("Using best model for evaluation at epoch", last_best_epoch) best_mse, best_mae, best_huber, best_rel_l2, best_rel_l1, best_r2, inf_time = test_model(model, test_loader, criterion, path, cfg, accelerator) else: # Calculate total parameters total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) # Reset memory stats before training torch.cuda.reset_peak_memory_stats() torch.cuda.empty_cache() # Clear any existing cached memory start = time.time() # Only initialize tensorboard on the main process if accelerator.is_main_process: # Create a descriptive run name using model type and timestamp if resumed_log_dir is not None: # Resume logging to the same directory log_dir = resumed_log_dir print(f"Resuming tensorboard logging to: {log_dir}") else: # Create new log directory run_name = f"{cfg.model}_{cfg.test_name}_{time.strftime('%Y%m%d_%H%M%S')}" project_name = os.path.join("tensorboard_logs", f"{cfg.project_name}") log_dir = os.path.join(project_name, run_name) print(f"Starting new tensorboard logging to: {log_dir}") writer = SummaryWriter(log_dir) # Add full config (only if starting fresh) if checkpoint is None: config_text = "```yaml\n" # Using yaml format for better readability config_text += OmegaConf.to_yaml(cfg) config_text += "```" writer.add_text('hyperparameters/full_config', config_text) pbar_train = tqdm(range(start_epoch, cfg.epochs), position=0) pbar_train.set_description(f"Training (resumed from epoch {start_epoch})" if checkpoint else "Training") else: writer = None log_dir = None pbar_train = tqdm(range(start_epoch, cfg.epochs), position=0) # Add checkpoint saving frequency to config (default every 10 epochs) checkpoint_freq = getattr(cfg, 'checkpoint_freq', 10) for epoch in pbar_train: train_loss = train(model, train_loader, optimizer, scheduler, criterion, cfg, accelerator, scaler) if cfg.val_iter is not None and (epoch == cfg.epochs - 1 or epoch % cfg.val_iter == 0): val_loss_MSE = val(model, val_loader,criterion, cfg, accelerator) if cfg.scheduler == "ReduceLROnPlateau": scheduler.step(val_loss_MSE) elif cfg.scheduler == "LinearWarmupCosineAnnealingLR": scheduler.step() val_MSE_list.append(val_loss_MSE) if accelerator.is_main_process: # Get peak GPU memory in GB peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024) # Log metrics to tensorboard writer.add_scalar('Loss/train_MSE', train_loss, epoch) writer.add_scalar('Loss/val_MSE', val_loss_MSE, epoch) writer.add_scalar('Learning_rate', scheduler.get_last_lr()[0], epoch) writer.add_scalar('Memory/GPU', peak_mem_gb, epoch) with open(os.path.join(path,'MSE.json'), 'w') as f: json.dump(val_MSE_list, f, indent=2) pbar_train.set_postfix({ 'train_loss': train_loss, 'val_loss': val_loss_MSE, 'lr': scheduler.get_last_lr()[0], 'mem_gb': f'{peak_mem_gb:.1f}' }) if val_loss_MSE < best_val_loss: best_val_loss = val_loss_MSE unwrapped_model = accelerator.unwrap_model(model) os.makedirs(os.path.join(path, 'best_case'), exist_ok=True) best_epoch = epoch # Save the best model state_dict cleanup_old_checkpoints(os.path.join(path, 'best_case'), keep_last=1) torch.save(unwrapped_model.state_dict(), os.path.join(path, f'best_case/best_model_epoch_{best_epoch}.pt')) print("saving best model at epoch", epoch) elif accelerator.is_main_process: # Simple progress display without validation peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024) pbar_train.set_postfix({ 'train_loss': train_loss, 'mem_gb': f'{peak_mem_gb:.1f}' }) # Save checkpoint periodically if accelerator.is_main_process and (epoch % checkpoint_freq == 0 or epoch == cfg.epochs - 1): save_checkpoint(model, optimizer, scheduler, epoch, best_val_loss, val_MSE_list, cfg, path, accelerator, log_dir) # Clean up old checkpoints to save disk space cleanup_old_checkpoints(path, keep_last=3) end = time.time() time_elapsed = end - start # Get final peak memory for reporting if accelerator.is_main_process: peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024) # Reset memory stats before final evaluation torch.cuda.reset_peak_memory_stats() torch.cuda.empty_cache() # Save final checkpoint BEFORE loading best model for evaluation if accelerator.is_main_process: save_checkpoint(model, optimizer, scheduler, cfg.epochs - 1, best_val_loss, val_MSE_list, cfg, path, accelerator, log_dir) # Test final model (last epoch) final_mse, final_mae, final_huber, final_rel_l2, final_rel_l1, final_r2, inf_time = test_model( model, test_loader, criterion, path, cfg, accelerator) # Get peak memory during testing if accelerator.is_main_process: test_peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024) # Create metrics text for final model metrics_text = f"Test MSE: {final_mse:.6f}\n" metrics_text += f"Test MAE: {final_mae:.6f}\n" metrics_text += f"Test Huber: {final_huber:.6f}\n" metrics_text += f"Test RelL1: {final_rel_l1:.6f}\n" metrics_text += f"Test RelL2: {final_rel_l2:.6f}\n" metrics_text += f"Test R2: {final_r2:.6f}\n" metrics_text += f"Inference time: {inf_time:.6f}s\n" metrics_text += f"Total training time: {time_elapsed:.2f}s\n" metrics_text += f"Average epoch time: {time_elapsed/cfg.epochs:.2f}s\n" metrics_text += f"Total parameters: {total_params}\n" metrics_text += f"Trainable parameters: {trainable_params}\n" metrics_text += f"Peak GPU memory usage:\n" metrics_text += f" - During training: {peak_mem_gb:.1f} GB\n" metrics_text += f" - During testing: {test_peak_mem_gb:.1f} GB\n" # Write to file and add to tensorboard metrics_file = os.path.join(path, 'final_test_metrics.txt') with open(metrics_file, 'w') as f: f.write(metrics_text) # Add final metrics to tensorboard as text (replace \n with markdown line break) writer.add_text('metrics/final_metrics', metrics_text.replace('\n', ' \n')) # --- Log per-batch test metrics for final model --- metrics_list_file = os.path.join(path, 'test_metrics_list.txt') if os.path.exists(metrics_list_file): with open(metrics_list_file, 'r') as f: metrics_list = json.load(f) for metric_name, values in metrics_list.items(): # Only log numeric metrics if all(isinstance(v, (int, float)) for v in values): for i, v in enumerate(values): writer.add_scalar(f'per_batch_test_metrics/final/{metric_name}', float(v), i) # Load the best model using state_dict for compatibility (into a separate model instance) from copy import deepcopy best_model = deepcopy(model) state_dict = torch.load(os.path.join(path, f'best_case/best_model_epoch_{best_epoch}.pt')) unwrapped_best_model = accelerator.unwrap_model(best_model) unwrapped_best_model.load_state_dict(state_dict) best_model = accelerator.prepare(unwrapped_best_model) # Test the best model path_best = os.path.join(path, 'best_case') # Do not overwrite path for final model logging best_mse, best_mae, best_huber, best_rel_l2, best_rel_l1, best_r2, inf_time = test_model( best_model, test_loader, criterion, path_best, cfg, accelerator) if accelerator.is_main_process: # Create metrics text for best model metrics_text = f"Test MSE: {best_mse:.6f}\n" metrics_text += f"Test MAE: {best_mae:.6f}\n" metrics_text += f"Test Huber: {best_huber:.6f}\n" metrics_text += f"Test RelL1: {best_rel_l1:.6f}\n" metrics_text += f"Test RelL2: {best_rel_l2:.6f}\n" metrics_text += f"Test R2: {best_r2:.6f}\n" metrics_text += f"Inference time: {inf_time:.6f}s\n" metrics_text += f"Total training time: {time_elapsed:.2f}s\n" metrics_text += f"Average epoch time: {time_elapsed/cfg.epochs:.2f}s\n" metrics_text += f"Total parameters: {total_params}\n" metrics_text += f"Trainable parameters: {trainable_params}\n" metrics_text += f"Peak GPU memory usage:\n" metrics_text += f" - During training: {peak_mem_gb:.1f} GB\n" metrics_text += f" - During testing: {test_peak_mem_gb:.1f} GB\n" # Write to file and add to tensorboard metrics_file = os.path.join(path_best, 'best_test_metrics.txt') with open(metrics_file, 'w') as f: f.write(metrics_text) # Add best metrics to tensorboard as text (replace \n with markdown line break) writer.add_text('metrics/best_metrics', metrics_text.replace('\n', ' \n')) # --- Log per-batch test metrics for best-case model --- metrics_list_file = os.path.join(path_best, 'test_metrics_list.txt') if os.path.exists(metrics_list_file): with open(metrics_list_file, 'r') as f: metrics_list = json.load(f) for metric_name, values in metrics_list.items(): for i, v in enumerate(values): # skip if not numeric (e.g., strings like "cadillac_73") if isinstance(v, (int, float)): writer.add_scalar(f'per_batch_test_metrics/best/{metric_name}', float(v), i) print(f"\nFinal model metrics - MSE: {final_mse:.6f}, MAE: {final_mae:.6f}, huber: {final_huber:.6f}, R²: {final_r2:.4f}") print(f"Best model metrics - MSE: {best_mse:.6f}, MAE: {best_mae:.6f}, huber: {best_huber:.6f}, R²: {best_r2:.4f}") # Close tensorboard writer writer.close() @torch.no_grad() def get_single_latent(mesh_path: str, config_path: str, device: str = None, custom_velocity: float = None, use_training_velocity: bool = True, model=None) -> np.ndarray: """ Load a trained model and extract the latent representation for a single mesh. Args: mesh_path: Path to a mesh file (.vtp/.vtm/.stl/.vtk/.ply/.obj/.vtu) config_path: Path to the training config.yaml used for this model device: 'cuda' or 'cpu'. Defaults to cuda if available custom_velocity: Custom velocity value to use (overrides mesh velocity if provided) use_training_velocity: Whether to use velocity from mesh data (True) or default velocity (False) Returns: Numpy array of shape [N_points, hidden_dim] with the latent vectors """ import glob import re # Set deterministic behavior for reproducible results torch.manual_seed(42) torch.cuda.manual_seed(42) torch.cuda.manual_seed_all(42) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False np.random.seed(42) # Load config and device cfg = OmegaConf.load(config_path) use_device = torch.device(device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")) # Ensure latent is returned from the model cfg.save_latent = True # Build and load model if model is None: if not hasattr(models, cfg.model): raise ValueError(f"Model '{cfg.model}' not found in models module") model_cls = getattr(models, cfg.model) model = model_cls(cfg).to(use_device) # Auto-find the best checkpoint (same logic as in train.py) load_path = f'metrics/{cfg.project_name}/{cfg.model}_{cfg.test_name}' # Find all best_model_epoch_*.pt files and get the epoch number from the last one pattern = os.path.join(os.path.dirname(os.path.dirname(os.getcwd())), load_path.lstrip('/'), 'best_case', 'best_model_epoch_*.pt') best_model_files = glob.glob(pattern) print("pattern", pattern) print("best model files", best_model_files) if not best_model_files: raise FileNotFoundError(f"No best_model_epoch_*.pt files found in {pattern}") # Extract epoch numbers epoch_numbers = [] for fname in best_model_files: match = re.search(r'best_model_epoch_(\d+)\.pt', os.path.basename(fname)) if match: epoch_numbers.append(int(match.group(1))) if not epoch_numbers: raise ValueError("No epoch numbers found in best_model_epoch_*.pt filenames") last_best_epoch = max(epoch_numbers) checkpoint_path = os.path.join(os.path.dirname(os.path.dirname(os.getcwd())), load_path, 'best_case', f'best_model_epoch_{last_best_epoch}.pt') if not os.path.exists(checkpoint_path): raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") state = torch.load(checkpoint_path, map_location=use_device) model.load_state_dict(state) model.eval() else: # Use the provided model and ensure it's in eval mode model.eval() # Read mesh if not os.path.exists(mesh_path): raise FileNotFoundError(f"Mesh not found: {mesh_path}") mesh = pv.read(mesh_path) points_np = np.asarray(mesh.points, dtype=np.float32) # Deterministic permutation of all points (matching test mode) # Use a fixed seed to ensure reproducible results np.random.seed(42) indices = np.random.permutation(points_np.shape[0]) # Base positions tensor pos = torch.from_numpy(points_np) # Optional: append inlet velocity as extra channel if cfg.diff_input_velocity: print("mesh.point_data.keys()", mesh.point_data.keys()) if use_training_velocity and ('inlet_x_velocity' in mesh.point_data.keys() or 'inlet_velocity' in mesh.point_data.keys()): # Use velocity from mesh data print("Using training velocity from mesh data") key = 'inlet_x_velocity' if 'inlet_x_velocity' in mesh.point_data.keys() else 'inlet_velocity' print("Inlet velocity values", mesh[key]) inlet_x_vel = torch.tensor(mesh[key], dtype=torch.float32).unsqueeze(-1) pos = torch.cat((pos, inlet_x_vel), dim=1) # Determine velocity to use based on parameters elif not use_training_velocity and custom_velocity is not None: # Use custom velocity value print(f"Using custom velocity: {custom_velocity}") inlet_x_vel = custom_velocity * torch.ones(pos.shape[0], 1, dtype=torch.float32) pos = torch.cat((pos, inlet_x_vel), dim=1) else: # Use default velocity raise ValueError("No velocity field found and no custom velocity provided") # Optional: input normalization shifts if getattr(cfg, "input_normalization", None) == "shift_axis": coords = pos[:, :3].clone() coords[:, 0] = coords[:, 0] - coords[:, 0].min() # shift x coords[:, 2] = coords[:, 2] - coords[:, 2].min() # shift z y_center = (coords[:, 1].max() + coords[:, 1].min()) / 2.0 coords[:, 1] = coords[:, 1] - y_center # center y pos[:, :3] = coords # Positional scaling with sincos embedding (0..1000) if cfg.pos_embed_sincos: if cfg.diff_input_velocity: raise Exception("pos_embed_sincos not supported with diff_input_velocity=True") print("applying pos_embed_sincos") input_pos_mins = torch.tensor(mesh.bounds[0]) input_pos_maxs = torch.tensor(mesh.bounds[1]) pos = 1000.0 * (pos - input_pos_mins) / (input_pos_maxs - input_pos_mins) # assertions in dataset_loader: ensure within [0,1000] pos = torch.clamp(pos, 0, 1000) # Apply permutation after all transforms (matches dataset_loader) pos = pos[indices] data = {"input_pos": pos.unsqueeze(0).to(use_device)} if cfg.chunked_eval: input_pos = data['input_pos'] B, N, C = input_pos.shape chunk_size = cfg.num_points latent_outputs = [] for i in range(0, N, chunk_size): # start with the raw slice chunk = input_pos[:, i:i+chunk_size, :] # (B, n_valid, C) n_valid = chunk.shape[1] # Pad if last chunk is short if n_valid < chunk_size: shape_diff = chunk_size - n_valid # Wrap from the beginning to make a full chunk pad = input_pos[:, :shape_diff, :] # (B, shape_diff, C) chunk = torch.cat([chunk, pad], dim=1) # (B, chunk_size, C) data['input_pos'] = chunk if cfg.save_latent: out_chunk, latent = model(data) # (B, chunk_size, D) else: out_chunk = model(data) # (B, chunk_size, D) # Keep only the valid part that corresponds to real points out_chunk = out_chunk[:, :n_valid, :] # (B, n_valid, D) latent = latent[:, :n_valid, :] else: data['input_pos'] = chunk if cfg.save_latent: out_chunk, latent = model(data) # (B, chunk_size, D) else: out_chunk = model(data) # (B, chunk_size, D) # outputs.append(out_chunk) latent_outputs.append(latent) if cfg.save_latent else None latent_outputs = torch.cat(latent_outputs, dim=1) if cfg.save_latent else None latent = latent_outputs.cpu().numpy() if cfg.save_latent else None else: if cfg.save_latent: outputs, latent = model(data) else: outputs = model(data) latent_outputs = torch.cat(latent_outputs, dim=1) if cfg.save_latent else None latent = latent_outputs.cpu().numpy() if cfg.save_latent else None # Return [N_points, hidden_dim] numpy array print("latent shape", latent.shape) return latent.squeeze(0)