"""Training, validation and test routines for the DriveAerNet pressure task.

The entry point is :func:`train_DriveAerNet_main`, which builds the data
loaders, optimizer and LR scheduler from ``cfg``, runs the training loop under
a Hugging Face ``Accelerator`` and evaluates both the final and the best
(lowest validation MSE) weights on the test split.
"""
import json
import math
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from torch.cuda.amp import GradScaler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# For sampling before training starts
# from datasets.DriveAerNet.data_loader import get_dataloaders, PRESSURE_MEAN, PRESSURE_STD
# For full mesh load and then sample in each training iteration
from datasets.DriveAerNet.data_loader_full import get_dataloaders, PRESSURE_MEAN, PRESSURE_STD

# Schedulers that are advanced from the epoch loop (on the validation loss)
# instead of once per optimizer step.
_PER_EPOCH_SCHEDULERS = ("ReduceLROnPlateau",)

_BYTES_PER_GB = 1024 * 1024 * 1024


def _pressure_stats(device):
    """Return ``(mean, std)`` normalization constants as tensors on ``device``."""
    mean = torch.tensor(PRESSURE_MEAN, device=device)
    std = torch.tensor(PRESSURE_STD, device=device)
    return mean, std


def _reset_cuda_memory_stats():
    """Reset the CUDA peak-memory counters and free cached blocks (no-op without CUDA)."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()


def _peak_cuda_mem_gb():
    """Return the peak allocated CUDA memory in GB, or 0.0 when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.max_memory_allocated() / _BYTES_PER_GB


def train(model, train_loader, optimizer, scheduler, cfg, accelerator, scaler):
    """Run one training epoch and return the mean per-batch MSE loss.

    Args:
        model: Network called as ``model(data)`` on each batch.
        train_loader: Iterable of batches; each batch is indexed with
            ``'output_feat'`` to obtain the raw targets.
            NOTE(review): targets are assumed to already live on
            ``accelerator.device`` (they are combined with tensors created
            there) - confirm against the data loader.
        optimizer: Optimizer updating ``model``.
        scheduler: LR scheduler. It is stepped once per batch unless
            ``cfg.scheduler`` names a scheduler driven from the epoch loop
            (``ReduceLROnPlateau``).
        cfg: Config providing ``mixed_precision``, ``max_grad_norm`` and
            ``scheduler``.
        accelerator: ``Accelerator`` used for backward and gradient clipping.
        scaler: ``GradScaler`` used only on the mixed-precision path.

    Returns:
        Sum of the per-batch losses divided by ``len(train_loader)``.
    """
    model.train()
    criterion = nn.MSELoss()
    losses_press = 0.0
    pressure_mean, pressure_std = _pressure_stats(accelerator.device)

    # Both OneCycleLR and the warmup+cosine LambdaLR are defined in units of
    # optimizer steps (their horizons are len(train_loader) * epochs), so they
    # must advance every batch. Previously only "OneCycleLR" was stepped, which
    # left the warmup schedule stuck at its initial factor of 0.
    step_scheduler_per_batch = cfg.scheduler not in _PER_EPOCH_SCHEDULERS

    for data in train_loader:
        targets = data['output_feat']
        targets = (targets - pressure_mean) / pressure_std
        optimizer.zero_grad()

        if cfg.mixed_precision:
            with torch.autocast(device_type=accelerator.device.type):
                out = model(data)
                total_loss = criterion(out, targets)
            scaler.scale(total_loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            if cfg.max_grad_norm is not None:
                accelerator.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(data)
            total_loss = criterion(out, targets)
            accelerator.backward(total_loss)
            if cfg.max_grad_norm is not None:
                accelerator.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            optimizer.step()

        if step_scheduler_per_batch:
            scheduler.step()
        losses_press += total_loss.item()

    return losses_press / len(train_loader)


@torch.no_grad()
def val(model, val_loader, cfg, accelerator):
    """Return the validation MSE (on normalized targets) averaged over all processes.

    The per-process mean batch loss is gathered and averaged across processes
    so that every rank sees the same value; this keeps ``ReduceLROnPlateau``
    and the best-checkpoint decision consistent in multi-process runs.
    This function performs a collective call and must therefore be invoked by
    every process.

    Args:
        model: Network called as ``model(data)`` on each batch.
        val_loader: Iterable of batches indexed with ``'output_feat'``.
        cfg: Unused; kept for interface compatibility.
        accelerator: ``Accelerator`` providing the device and ``gather``.

    Returns:
        Mean validation loss as a Python float.
    """
    model.eval()
    criterion = nn.MSELoss()
    losses_press = 0.0
    pressure_mean, pressure_std = _pressure_stats(accelerator.device)

    for data in val_loader:
        targets = data['output_feat']
        targets = (targets - pressure_mean) / pressure_std
        out = model(data)
        total_loss = criterion(out, targets)
        losses_press += total_loss.item()

    local_mean = losses_press / len(val_loader)
    local_mean_t = torch.tensor([local_mean], device=accelerator.device)
    return accelerator.gather(local_mean_t).mean().item()


def RelL2loss(x, y, eps=0.0):
    """Calculate relative L2 error: mean(||x-y||_2 / ||y||_2) over the batch.

    Args:
        x: Predicted values; the first dimension is treated as the batch and
            all remaining dimensions are flattened.
        y: Target values with the same shape as ``x``.
        eps: Optional constant added to the target norm to avoid division by
            zero for all-zero targets. The default of 0.0 reproduces the plain
            ratio.

    Returns:
        Mean relative L2 error across the batch (0-dim tensor).
    """
    batch_size = x.size(0)
    # reshape (not view) so that non-contiguous inputs are accepted.
    x_flat = x.reshape(batch_size, -1)
    y_flat = y.reshape(batch_size, -1)

    diff_norms = torch.norm(x_flat - y_flat, p=2, dim=1)  # (B,)
    y_norms = torch.norm(y_flat, p=2, dim=1)  # (B,)

    return torch.mean(diff_norms / (y_norms + eps))


def test_model(model, test_dataloader, criterion, path, cfg, accelerator):
    """Evaluate ``model`` on the test split and write ``test_metrics.txt``.

    Metrics are computed on normalized targets (same normalization as
    train/val), accumulated per batch, gathered from all processes and averaged
    over the total number of batches. R² is computed from the gathered
    predictions on every process so the return value is the same on all ranks;
    printing and file output happen on the main process only.

    Args:
        model: Network called as ``model(data)`` on each batch.
        test_dataloader: Iterable of batches indexed with ``'output_feat'``.
        criterion: Callable returning the MSE-style loss for
            ``(outputs, targets)``.
        path: Directory in which ``test_metrics.txt`` is written.
        cfg: Unused; kept for interface compatibility.
        accelerator: ``Accelerator`` providing device, ``gather`` and process
            information.

    Returns:
        Tuple ``(avg_mse, avg_mae, avg_rel_l2, avg_rel_l1, r_squared,
        avg_inference_time)``. All entries are 0.0 when no batch was
        processed on any process.
    """
    model.eval()
    device = accelerator.device

    # Tensor accumulators: accelerator.gather only accepts tensors, so plain
    # Python floats would break the reduction when a process saw no batches.
    total_mse = torch.zeros((), device=device)
    total_mae = torch.zeros((), device=device)
    total_rel_l2 = torch.zeros((), device=device)
    total_rel_l1 = torch.zeros((), device=device)
    total_inference_time = 0.0
    num_batches = 0

    pressure_mean, pressure_std = _pressure_stats(device)
    # CUDA kernels run asynchronously; synchronize so the wall-clock timing
    # covers the actual forward pass rather than just the kernel launches.
    sync_cuda = device.type == "cuda"

    # Store outputs and targets on all processes
    all_outputs = []
    all_targets = []

    with torch.no_grad():
        for data in tqdm(test_dataloader, desc="[Testing]", disable=not accelerator.is_local_main_process):
            if sync_cuda:
                torch.cuda.synchronize()
            start_time = time.time()
            targets = data['output_feat']
            targets = (targets - pressure_mean) / pressure_std  # Match train/val normalization
            outputs = model(data)
            if sync_cuda:
                torch.cuda.synchronize()
            total_inference_time += time.time() - start_time

            mse = criterion(outputs, targets)
            mae = F.l1_loss(outputs, targets)
            # NOTE(review): the squeeze(-1) suggests a trailing singleton
            # channel dimension - confirm against the model output.
            rel_l1 = torch.mean(
                torch.norm(outputs.squeeze(-1) - targets.squeeze(-1), p=1, dim=-1)
                / torch.norm(targets.squeeze(-1), p=1, dim=-1)
            )
            rel_l2 = RelL2loss(outputs, targets)

            total_mse = total_mse + mse.detach().float()
            total_mae = total_mae + mae.detach().float()
            total_rel_l2 = total_rel_l2 + rel_l2.detach().float()
            total_rel_l1 = total_rel_l1 + rel_l1.detach().float()
            num_batches += 1

            all_outputs.append(outputs.cpu())
            all_targets.append(targets.cpu())

            # Drop references so the batch tensors can be freed promptly.
            del outputs, targets, mse, mae, rel_l2, rel_l1

    # Clear GPU cache after all testing
    torch.cuda.empty_cache()

    metrics = {
        "total_mse": total_mse,
        "total_mae": total_mae,
        "total_rel_l2": total_rel_l2,
        "total_rel_l1": total_rel_l1,
        "num_batches": torch.tensor(num_batches, device=device),
        "total_inference_time": torch.tensor(total_inference_time, device=device),
    }

    # Gather metrics from all processes
    gathered_metrics = accelerator.gather(metrics)
    total_batches = gathered_metrics["num_batches"].sum().item()

    if total_batches <= 0:
        print("Warning: No data in test_dataloader")
        return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0

    avg_mse = gathered_metrics["total_mse"].sum().item() / total_batches
    avg_mae = gathered_metrics["total_mae"].sum().item() / total_batches
    avg_rel_l2 = gathered_metrics["total_rel_l2"].sum().item() / total_batches
    avg_rel_l1 = gathered_metrics["total_rel_l1"].sum().item() / total_batches
    total_inference_time = gathered_metrics["total_inference_time"].sum().item()
    avg_inference_time = total_inference_time / total_batches

    # Gather outputs and targets across processes
    all_outputs = torch.cat(all_outputs, dim=0)
    all_targets = torch.cat(all_targets, dim=0)
    all_outputs = accelerator.gather(all_outputs.to(device))
    all_targets = accelerator.gather(all_targets.to(device))

    # Calculate R² score using the complete gathered dataset. Every process
    # holds the gathered tensors, so each computes the same value locally.
    all_outputs = all_outputs.to(torch.float32).cpu().numpy()
    all_targets = all_targets.to(torch.float32).cpu().numpy()
    ss_tot = np.sum((all_targets - np.mean(all_targets)) ** 2)
    ss_res = np.sum((all_targets - all_outputs) ** 2)
    r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0

    if accelerator.is_main_process:
        print(f"Test MSE: {avg_mse:.6f}, Test MAE: {avg_mae:.6f}, R²: {r_squared:.4f}")
        print(f"Relative L2 Error: {avg_rel_l2:.6f}, Relative L1 Error: {avg_rel_l1:.6f}")
        print(f"Average inference time per batch: {avg_inference_time:.4f}s")
        print(f"Total inference time: {total_inference_time:.2f}s for {total_batches} batches")

        # Save metrics to a text file
        metrics_file = os.path.join(path, 'test_metrics.txt')
        with open(metrics_file, 'w') as f:
            f.write(f"Test MSE: {avg_mse:.6f}\n")
            f.write(f"Test MAE: {avg_mae:.6f}\n")
            f.write(f"R² Score: {r_squared:.6f}\n")
            f.write(f"Relative L2 Error: {avg_rel_l2:.6f}\n")
            f.write(f"Relative L1 Error: {avg_rel_l1:.6f}\n")
            f.write(f"Average inference time per batch: {avg_inference_time:.4f}s\n")
            f.write(f"Total inference time: {total_inference_time:.2f}s for {total_batches} batches\n")

    return avg_mse, avg_mae, avg_rel_l2, avg_rel_l1, r_squared, avg_inference_time


def _format_metrics_report(mse, mae, rel_l1, rel_l2, r2, inf_time, time_elapsed,
                           epochs, total_params, trainable_params,
                           train_peak_gb, test_peak_gb):
    """Build the plain-text summary written to the ``*_test_metrics.txt`` files."""
    # max(1, epochs) avoids a ZeroDivisionError for a zero-epoch run.
    avg_epoch_time = time_elapsed / max(1, epochs)
    lines = [
        f"Test MSE: {mse:.6f}",
        f"Test MAE: {mae:.6f}",
        f"Test RelL1: {rel_l1:.6f}",
        f"Test RelL2: {rel_l2:.6f}",
        f"Test R2: {r2:.6f}",
        f"Inference time: {inf_time:.6f}s",
        f"Total training time: {time_elapsed:.2f}s",
        f"Average epoch time: {avg_epoch_time:.2f}s",
        f"Total parameters: {total_params}",
        f"Trainable parameters: {trainable_params}",
        "Peak GPU memory usage:",
        f" - During training: {train_peak_gb:.1f} GB",
        f" - During testing: {test_peak_gb:.1f} GB",
    ]
    return "\n".join(lines) + "\n"


def _build_optimizer(model, cfg):
    """Create the optimizer selected by ``cfg.optimizer.type``.

    Raises:
        ValueError: If the optimizer type is not ``'Adam'`` or ``'AdamW'``.
    """
    if cfg.optimizer.type == 'Adam':
        return torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=1e-4)
    if cfg.optimizer.type == 'AdamW':
        return torch.optim.AdamW(model.parameters(), lr=cfg.lr, betas=(0.9, 0.95), weight_decay=0.05)
    raise ValueError(f"Unsupported optimizer type: {cfg.optimizer.type!r}")


def _build_scheduler(optimizer, cfg, steps_per_epoch):
    """Create the LR scheduler selected by ``cfg.scheduler``.

    ``"ReduceLROnPlateau"`` is driven by the validation loss; the warmup+cosine
    schedule and the default ``OneCycleLR`` are defined per optimizer step over
    ``steps_per_epoch * cfg.epochs`` steps.
    """
    if cfg.scheduler == "ReduceLROnPlateau":
        # NOTE(review): `verbose` is deprecated in recent PyTorch releases -
        # confirm it is still accepted by the installed version.
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', patience=10, factor=0.1, verbose=True)

    total_steps = steps_per_epoch * cfg.epochs

    if cfg.scheduler == "LinearWarmupCosineAnnealing":
        # Linear warmup followed by cosine annealing
        warmup_steps = steps_per_epoch * 5  # 5 epochs of warmup

        def lr_lambda(step):
            if step < warmup_steps:
                return float(step) / float(max(1, warmup_steps))
            progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    return torch.optim.lr_scheduler.OneCycleLR(
        optimizer, pct_start=0.05, max_lr=cfg.lr, total_steps=total_steps)


def train_DriveAerNet_main(model, path, cfg, accelerator):
    """Train (or, with ``cfg.eval``, only evaluate) a model on DriveAerNet.

    In training mode the function logs to TensorBoard from the main process,
    validates every ``cfg.val_iter`` epochs (and on the last epoch), keeps the
    weights with the lowest validation MSE in ``<path>/best_model.pt`` and
    finally evaluates both the final and the best weights on the test split,
    writing ``final_test_metrics.txt`` and ``best_test_metrics.txt``.

    Args:
        model: Network to train/evaluate.
        path: Output directory for checkpoints and metric files.
        cfg: Run configuration (data locations, optimizer/scheduler choice,
            ``epochs``, ``val_iter``, ``eval`` ...).
        accelerator: ``Accelerator`` managing devices and processes.

    Raises:
        ValueError: If ``cfg.optimizer.type`` is not supported.
    """
    train_loader, val_loader, test_loader = get_dataloaders(
        cfg, cfg.data_dir, cfg.subset_dir, cfg.num_points, cfg.batch_size,
        cfg.cache_dir, cfg.num_workers, cfg.model)

    if accelerator.is_main_process:
        print(
            f"Data loaded: {len(train_loader)} training batches, "
            f"{len(val_loader)} validation batches, "
            f"{len(test_loader)} test batches")

    optimizer = _build_optimizer(model, cfg)
    # Step counts use the loader length before `prepare`, as in the original
    # schedule definition.
    scheduler = _build_scheduler(optimizer, cfg, len(train_loader))
    scaler = GradScaler()

    model, optimizer, train_loader, val_loader, test_loader, scheduler, scaler = accelerator.prepare(
        model, optimizer, train_loader, val_loader, test_loader, scheduler, scaler)

    criterion = torch.nn.MSELoss()
    best_path = os.path.join(path, 'best_model.pt')

    if cfg.eval:
        # Load the saved weights in place; the prepared `model` wraps the same
        # underlying module, so it can be evaluated directly without being
        # prepared a second time.
        state_dict = torch.load(best_path, map_location="cpu")
        accelerator.unwrap_model(model).load_state_dict(state_dict)
        test_model(model, test_loader, criterion, path, cfg, accelerator)
        return

    # Calculate total parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Reset memory stats before training
    _reset_cuda_memory_stats()

    best_val_loss = 1e5
    checkpoint_saved = False
    val_MSE_list = []
    start = time.time()

    # Only initialize tensorboard on the main process
    if accelerator.is_main_process:
        os.makedirs(path, exist_ok=True)
        # Create a descriptive run name using model type and timestamp
        run_name = f"{cfg.model}{cfg.test_name}_{time.strftime('%Y%m%d_%H%M%S')}"
        project_name = os.path.join("tensorboard_logs", f"{cfg.project_name}")
        log_dir = os.path.join(project_name, run_name)
        writer = SummaryWriter(log_dir)

        # Add full config (yaml format for better readability)
        config_text = "```yaml\n"
        config_text += OmegaConf.to_yaml(cfg)
        config_text += "```"
        writer.add_text('hyperparameters/full_config', config_text)
        pbar_train = tqdm(range(cfg.epochs), position=0)
    else:
        writer = None
        pbar_train = range(cfg.epochs)

    for epoch in pbar_train:
        train_loss = train(model, train_loader, optimizer, scheduler, cfg, accelerator, scaler)

        run_validation = cfg.val_iter is not None and (
            epoch == cfg.epochs - 1 or epoch % cfg.val_iter == 0)

        if run_validation:
            val_loss_MSE = val(model, val_loader, cfg, accelerator)
            if cfg.scheduler == "ReduceLROnPlateau":
                scheduler.step(val_loss_MSE)
            val_MSE_list.append(val_loss_MSE)

            if accelerator.is_main_process:
                peak_mem_gb = _peak_cuda_mem_gb()
                # Read the LR from the optimizer: it is valid for every
                # scheduler type, including ReduceLROnPlateau.
                current_lr = optimizer.param_groups[0]['lr']

                writer.add_scalar('Loss/train_MSE', train_loss, epoch)
                writer.add_scalar('Loss/val_MSE', val_loss_MSE, epoch)
                writer.add_scalar('Learning_rate', current_lr, epoch)
                writer.add_scalar('Memory/GPU', peak_mem_gb, epoch)

                with open(os.path.join(path, 'MSE.json'), 'w') as f:
                    json.dump(val_MSE_list, f, indent=2)

                pbar_train.set_postfix({
                    'train_loss': train_loss,
                    'val_loss': val_loss_MSE,
                    'lr': current_lr,
                    'mem_gb': f'{peak_mem_gb:.1f}'
                })

            if val_loss_MSE < best_val_loss:
                best_val_loss = val_loss_MSE
                checkpoint_saved = True
                # Only one process writes the file to avoid concurrent writes
                # to the same path.
                if accelerator.is_main_process:
                    unwrapped_model = accelerator.unwrap_model(model)
                    torch.save(unwrapped_model.state_dict(), best_path)
        elif accelerator.is_main_process:
            # Simple progress display without validation
            peak_mem_gb = _peak_cuda_mem_gb()
            pbar_train.set_postfix({
                'train_loss': train_loss,
                'mem_gb': f'{peak_mem_gb:.1f}'
            })

    time_elapsed = time.time() - start

    # Get final peak memory for reporting
    peak_mem_gb = _peak_cuda_mem_gb()

    # Reset memory stats before final evaluation
    _reset_cuda_memory_stats()

    # Test final model
    final_mse, final_mae, final_rel_l2, final_rel_l1, final_r2, inf_time = test_model(
        model, test_loader, criterion, path, cfg, accelerator)

    # Get peak memory during testing
    test_peak_mem_gb = _peak_cuda_mem_gb()

    if accelerator.is_main_process:
        metrics_text = _format_metrics_report(
            final_mse, final_mae, final_rel_l1, final_rel_l2, final_r2, inf_time,
            time_elapsed, cfg.epochs, total_params, trainable_params,
            peak_mem_gb, test_peak_mem_gb)

        with open(os.path.join(path, 'final_test_metrics.txt'), 'w') as f:
            f.write(metrics_text)

        # Two trailing spaces before the newline form a markdown line break.
        writer.add_text('metrics/final_metrics', metrics_text.replace('\n', '  \n'))

    # If validation never produced a checkpoint (e.g. cfg.val_iter is None),
    # fall back to the final weights so the "best" evaluation can still run.
    if not checkpoint_saved and accelerator.is_main_process:
        torch.save(accelerator.unwrap_model(model).state_dict(), best_path)

    # Make sure the checkpoint is fully written before any process reads it.
    accelerator.wait_for_everyone()

    # Load the best weights in place and evaluate the already-prepared model.
    state_dict = torch.load(best_path, map_location="cpu")
    accelerator.unwrap_model(model).load_state_dict(state_dict)
    best_mse, best_mae, best_rel_l2, best_rel_l1, best_r2, inf_time = test_model(
        model, test_loader, criterion, path, cfg, accelerator)

    if accelerator.is_main_process:
        metrics_text = _format_metrics_report(
            best_mse, best_mae, best_rel_l1, best_rel_l2, best_r2, inf_time,
            time_elapsed, cfg.epochs, total_params, trainable_params,
            peak_mem_gb, test_peak_mem_gb)

        with open(os.path.join(path, 'best_test_metrics.txt'), 'w') as f:
            f.write(metrics_text)

        writer.add_text('metrics/best_metrics', metrics_text.replace('\n', '  \n'))

        print(f"\nFinal model metrics - MSE: {final_mse:.6f}, MAE: {final_mae:.6f}, R²: {final_r2:.4f}")
        print(f"Best model metrics - MSE: {best_mse:.6f}, MAE: {best_mae:.6f}, R²: {best_r2:.4f}")

    # Close tensorboard writer (it only exists on the main process).
    if writer is not None:
        writer.close()