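"""Training, validation, and test loops for surface-pressure regression on
the DriveAerNet dataset.

Runs under HuggingFace Accelerate for (multi-)GPU execution, with optional
mixed-precision training, a choice of LR schedulers, and TensorBoard logging.
"""
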
import numpy as np
import time, json, os
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from omegaconf import OmegaConf
import math
from torch.cuda.amp import GradScaler

# For sampling before training starts
# from datasets.DriveAerNet.data_loader import get_dataloaders, PRESSURE_MEAN, PRESSURE_STD
# For full mesh load and then sample in each training iteration
from datasets.DriveAerNet.data_loader_full import get_dataloaders, PRESSURE_MEAN, PRESSURE_STD


def train(model, train_loader, optimizer, scheduler, cfg, accelerator, scaler):
    model.train()
    criterion = nn.MSELoss()
    losses_press = 0.0
    pressure_mean = torch.tensor(PRESSURE_MEAN, device=accelerator.device)
    pressure_std = torch.tensor(PRESSURE_STD, device=accelerator.device)

    for data in train_loader:
        targets = data['output_feat']
        targets = (targets - pressure_mean) / pressure_std
        optimizer.zero_grad()

        if cfg.mixed_precision:
            with torch.autocast(device_type=accelerator.device.type):
                out = model(data)
                total_loss = criterion(out, targets)
            # Unscale before clipping so the norm is measured on the true
            # gradients; scaler.step() then skips the update on overflow.
            scaler.scale(total_loss).backward()
            scaler.unscale_(optimizer)
            if cfg.max_grad_norm is not None:
                accelerator.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(data)
            total_loss = criterion(out, targets)
            accelerator.backward(total_loss)
            if cfg.max_grad_norm is not None:
                accelerator.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            optimizer.step()

        # Per-batch schedulers step every iteration; ReduceLROnPlateau is
        # stepped once per validation in the main loop instead.
        if cfg.scheduler in ("OneCycleLR", "LinearWarmupCosineAnnealing"):
            scheduler.step()
        losses_press += total_loss.item()

    return losses_press / len(train_loader)


@torch.no_grad()
def val(model, val_loader, cfg, accelerator):
    model.eval()
    criterion = nn.MSELoss()
    losses_press = 0.0
    pressure_mean = torch.tensor(PRESSURE_MEAN, device=accelerator.device)
    pressure_std = torch.tensor(PRESSURE_STD, device=accelerator.device)

    for data in val_loader:
        targets = data['output_feat']
        targets = (targets - pressure_mean) / pressure_std
        out = model(data)
        total_loss = criterion(out, targets)
        losses_press += total_loss.item()

    return losses_press / len(val_loader)


def RelL2loss(x, y):
    """Calculate relative L2 error: mean(||x - y||_2 / ||y||_2) over the batch.

    Args:
        x: Predicted values (B, N, C)
        y: Target values (B, N, C)

    Returns:
        Mean relative L2 error across the batch
    """
    # Flatten each sample to a single vector: (B, N*C)
    batch_size = x.size(0)
    x_flat = x.view(batch_size, -1)
    y_flat = y.view(batch_size, -1)

    # Per-sample L2 norms of the error and of the target
    diff_norms = torch.norm(x_flat - y_flat, p=2, dim=1)  # (B,)
    y_norms = torch.norm(y_flat, p=2, dim=1)  # (B,)

    # Mean relative error over the batch
    return torch.mean(diff_norms / y_norms)
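

# Illustrative sanity check (hypothetical tensors): scaling a target by
# (1 + eps) yields a relative L2 error of exactly |eps| for every sample:
#   y = torch.randn(4, 1000, 1)
#   RelL2loss(1.01 * y, y)  # ~= 0.01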


def test_model(model, test_dataloader, criterion, path, cfg, accelerator):
    """Test the model and calculate metrics."""
    model.eval()
    total_mse = 0.0
    total_mae = 0.0
    total_rel_l2 = 0.0
    total_rel_l1 = 0.0
    total_inference_time = 0.0
    num_batches = 0
    pressure_mean = torch.tensor(PRESSURE_MEAN, device=accelerator.device)
    pressure_std = torch.tensor(PRESSURE_STD, device=accelerator.device)

    # Store outputs and targets on all processes
    all_outputs = []
    all_targets = []

    with torch.no_grad():
        for data in tqdm(test_dataloader, desc="[Testing]", disable=not accelerator.is_local_main_process):
            start_time = time.time()
            targets = data['output_feat']
            targets = (targets - pressure_mean) / pressure_std  # Match train/val normalization
            outputs = model(data)
            total_inference_time += time.time() - start_time

            # Keep metrics as tensors for proper reduction across processes
            mse = criterion(outputs, targets)
            mae = F.l1_loss(outputs, targets)
            rel_l1 = torch.mean(torch.norm(outputs.squeeze(-1) - targets.squeeze(-1), p=1, dim=-1) /
                                torch.norm(targets.squeeze(-1), p=1, dim=-1))
            rel_l2 = RelL2loss(outputs, targets)

            total_mse += mse
            total_mae += mae
            total_rel_l2 += rel_l2
            total_rel_l1 += rel_l1
            num_batches += 1

            # Store outputs and targets on all processes
            all_outputs.append(outputs.cpu())
            all_targets.append(targets.cpu())

            # Clear references to tensors
            del outputs, targets, mse, mae, rel_l2, rel_l1

    # Clear GPU cache after all testing
    torch.cuda.empty_cache()

    # Collect accumulated scalars as tensors for reduction
    metrics = {
        "total_mse": total_mse,
        "total_mae": total_mae,
        "total_rel_l2": total_rel_l2,
        "total_rel_l1": total_rel_l1,
        "num_batches": torch.tensor(num_batches, device=accelerator.device),
        "total_inference_time": torch.tensor(total_inference_time, device=accelerator.device)
    }
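    # Note (assumption about Accelerate semantics): accelerator.gather
    # concatenates tensors across processes, promoting 0-dim scalars to
    # shape (num_processes,), so the .sum() calls below aggregate the
    # per-process totals.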
    # Gather metrics from all processes
    gathered_metrics = accelerator.gather(metrics)

    # Only calculate averages if we have data
    if gathered_metrics["num_batches"].sum().item() > 0:
        total_batches = gathered_metrics["num_batches"].sum().item()
        avg_mse = gathered_metrics["total_mse"].sum().item() / total_batches
        avg_mae = gathered_metrics["total_mae"].sum().item() / total_batches
        avg_rel_l2 = gathered_metrics["total_rel_l2"].sum().item() / total_batches
        avg_rel_l1 = gathered_metrics["total_rel_l1"].sum().item() / total_batches
        total_inference_time = gathered_metrics["total_inference_time"].sum().item()
        avg_inference_time = total_inference_time / total_batches

        # Concatenate local outputs and targets, then gather across processes
        all_outputs = torch.cat(all_outputs, dim=0)
        all_targets = torch.cat(all_targets, dim=0)
        all_outputs = accelerator.gather(all_outputs.to(accelerator.device))
        all_targets = accelerator.gather(all_targets.to(accelerator.device))
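        # Caveat (assumption about Accelerate's default sharding): the prepared
        # test dataloader may duplicate a few samples so that every rank sees
        # complete batches; if so, the gathered tensors here contain those
        # duplicates and the R² below is computed over the padded set.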
        # Calculate R² score using the complete dataset
        if accelerator.is_main_process:
            all_outputs = all_outputs.to(torch.float32).cpu().numpy()
            all_targets = all_targets.to(torch.float32).cpu().numpy()
            ss_tot = np.sum((all_targets - np.mean(all_targets)) ** 2)
            ss_res = np.sum((all_targets - all_outputs) ** 2)
            r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0

            print(f"Test MSE: {avg_mse:.6f}, Test MAE: {avg_mae:.6f}, R²: {r_squared:.4f}")
            print(f"Relative L2 Error: {avg_rel_l2:.6f}, Relative L1 Error: {avg_rel_l1:.6f}")
            print(f"Average inference time per batch: {avg_inference_time:.4f}s")
            print(f"Total inference time: {total_inference_time:.2f}s for {total_batches} batches")

            # Save metrics to a text file
            metrics_file = os.path.join(path, 'test_metrics.txt')
            with open(metrics_file, 'w') as f:
                f.write(f"Test MSE: {avg_mse:.6f}\n")
                f.write(f"Test MAE: {avg_mae:.6f}\n")
                f.write(f"R² Score: {r_squared:.6f}\n")
                f.write(f"Relative L2 Error: {avg_rel_l2:.6f}\n")
                f.write(f"Relative L1 Error: {avg_rel_l1:.6f}\n")
                f.write(f"Average inference time per batch: {avg_inference_time:.4f}s\n")
                f.write(f"Total inference time: {total_inference_time:.2f}s for {total_batches} batches\n")
        else:
            r_squared = 0.0  # Placeholder; R² is only computed on the main process
    else:
        print("Warning: No data in test_dataloader")
        avg_mse = avg_mae = avg_rel_l2 = avg_rel_l1 = avg_inference_time = r_squared = 0.0

    return avg_mse, avg_mae, avg_rel_l2, avg_rel_l1, r_squared, avg_inference_time


def train_DriveAerNet_main(model, path, cfg, accelerator):
    train_loader, val_loader, test_loader = get_dataloaders(cfg, cfg.data_dir, cfg.subset_dir, cfg.num_points,
                                                            cfg.batch_size, cfg.cache_dir, cfg.num_workers, cfg.model)
    if accelerator.is_main_process:
        print(
            f"Data loaded: {len(train_loader)} training batches, "
            f"{len(val_loader)} validation batches, "
            f"{len(test_loader)} test batches")

    # Select optimizer
    if cfg.optimizer.type == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=1e-4)
    elif cfg.optimizer.type == 'AdamW':
        optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, betas=(0.9, 0.95), weight_decay=0.05)
    else:
        raise ValueError(f"Unsupported optimizer type: {cfg.optimizer.type}")

    # Select scheduler
    if cfg.scheduler == "ReduceLROnPlateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.1, verbose=True)
    elif cfg.scheduler == "LinearWarmupCosineAnnealing":
        # Linear warmup followed by cosine annealing
        warmup_steps = len(train_loader) * 5  # 5 epochs of warmup
        total_steps = len(train_loader) * cfg.epochs

        def lr_lambda(step):
            if step < warmup_steps:
                return float(step) / float(max(1, warmup_steps))
            else:
                progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
                return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
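        # Worked example (hypothetical numbers): with 100 batches per epoch
        # and cfg.epochs = 50, warmup_steps = 500 and total_steps = 5000; the
        # LR multiplier ramps linearly from 0 to 1 over the first 5 epochs,
        # then decays as 0.5 * (1 + cos(pi * progress)) toward 0 over the
        # remaining 45. This schedule is stepped once per batch in train().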
    else:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            pct_start=0.05,
            max_lr=cfg.lr,
            total_steps=len(train_loader) * cfg.epochs
        )

    scaler = GradScaler()
    model, optimizer, train_loader, val_loader, test_loader, scheduler, scaler = accelerator.prepare(
        model, optimizer, train_loader, val_loader, test_loader, scheduler, scaler)
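    # Assumption about Accelerate's prepare(): it wraps the model, optimizer,
    # dataloaders, and scheduler; objects it does not recognize, such as the
    # GradScaler above, should be passed through unchanged.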
    criterion = torch.nn.MSELoss()

    if cfg.eval:
        # Load the saved state dict and create a fresh model
        state_dict = torch.load(os.path.join(path, 'best_model.pt'))
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.load_state_dict(state_dict)
        best_model = accelerator.prepare(unwrapped_model)
        best_mse, best_mae, best_rel_l2, best_rel_l1, best_r2, inf_time = test_model(
            best_model, test_loader, criterion, path, cfg, accelerator)
    else:
        # Calculate total parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        # Reset memory stats before training
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()  # Clear any existing cached memory

        best_val_loss = 1e5
        val_MSE_list = []
        start = time.time()

        # Only initialize tensorboard on the main process
        if accelerator.is_main_process:
            # Create a descriptive run name using model type and timestamp
            run_name = f"{cfg.model}{cfg.test_name}_{time.strftime('%Y%m%d_%H%M%S')}"
            project_name = os.path.join("tensorboard_logs", cfg.project_name)
            log_dir = os.path.join(project_name, run_name)
            writer = SummaryWriter(log_dir)

            # Add the full config in yaml format for better readability
            config_text = "```yaml\n" + OmegaConf.to_yaml(cfg) + "```"
            writer.add_text('hyperparameters/full_config', config_text)

            pbar_train = tqdm(range(cfg.epochs), position=0)
        else:
            writer = None
            pbar_train = range(cfg.epochs)

        for epoch in pbar_train:
            train_loss = train(model, train_loader, optimizer, scheduler, cfg, accelerator, scaler)

            if cfg.val_iter is not None and (epoch == cfg.epochs - 1 or epoch % cfg.val_iter == 0):
                val_loss_MSE = val(model, val_loader, cfg, accelerator)
                if cfg.scheduler == "ReduceLROnPlateau":
                    scheduler.step(val_loss_MSE)
                val_MSE_list.append(val_loss_MSE)

                if accelerator.is_main_process:
                    # Get peak GPU memory in GB
                    peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024)

                    # Log metrics to tensorboard
                    writer.add_scalar('Loss/train_MSE', train_loss, epoch)
                    writer.add_scalar('Loss/val_MSE', val_loss_MSE, epoch)
                    writer.add_scalar('Learning_rate', scheduler.get_last_lr()[0], epoch)
                    writer.add_scalar('Memory/GPU', peak_mem_gb, epoch)

                    with open(os.path.join(path, 'MSE.json'), 'w') as f:
                        json.dump(val_MSE_list, f, indent=2)

                    pbar_train.set_postfix({
                        'train_loss': train_loss,
                        'val_loss': val_loss_MSE,
                        'lr': scheduler.get_last_lr()[0],
                        'mem_gb': f'{peak_mem_gb:.1f}'
                    })

                    if val_loss_MSE < best_val_loss:
                        best_val_loss = val_loss_MSE
                        unwrapped_model = accelerator.unwrap_model(model)
                        torch.save(unwrapped_model.state_dict(), os.path.join(path, 'best_model.pt'))
            elif accelerator.is_main_process:
                # Simple progress display without validation
                peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024)
                pbar_train.set_postfix({
                    'train_loss': train_loss,
                    'mem_gb': f'{peak_mem_gb:.1f}'
                })

        end = time.time()
        time_elapsed = end - start

        # Get final peak memory for reporting
        if accelerator.is_main_process:
            peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024)

        # Reset memory stats before final evaluation
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

        # Test final model
        final_mse, final_mae, final_rel_l2, final_rel_l1, final_r2, inf_time = test_model(
            model, test_loader, criterion, path, cfg, accelerator)

        # Get peak memory during testing and report final-model metrics
        if accelerator.is_main_process:
            test_peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024)

            # Create metrics text for final model
            metrics_text = f"Test MSE: {final_mse:.6f}\n"
            metrics_text += f"Test MAE: {final_mae:.6f}\n"
            metrics_text += f"Test RelL1: {final_rel_l1:.6f}\n"
            metrics_text += f"Test RelL2: {final_rel_l2:.6f}\n"
            metrics_text += f"Test R2: {final_r2:.6f}\n"
            metrics_text += f"Inference time: {inf_time:.6f}s\n"
            metrics_text += f"Total training time: {time_elapsed:.2f}s\n"
            metrics_text += f"Average epoch time: {time_elapsed / cfg.epochs:.2f}s\n"
            metrics_text += f"Total parameters: {total_params}\n"
            metrics_text += f"Trainable parameters: {trainable_params}\n"
            metrics_text += "Peak GPU memory usage:\n"
            metrics_text += f" - During training: {peak_mem_gb:.1f} GB\n"
            metrics_text += f" - During testing: {test_peak_mem_gb:.1f} GB\n"

            # Write to file and add to tensorboard
            metrics_file = os.path.join(path, 'final_test_metrics.txt')
            with open(metrics_file, 'w') as f:
                f.write(metrics_text)

            # Add final metrics to tensorboard as text (two-space markdown line break)
            writer.add_text('metrics/final_metrics', metrics_text.replace('\n', '  \n'))

        # Load the best model using state_dict for compatibility
        state_dict = torch.load(os.path.join(path, 'best_model.pt'))
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.load_state_dict(state_dict)
        best_model = accelerator.prepare(unwrapped_model)
        best_mse, best_mae, best_rel_l2, best_rel_l1, best_r2, inf_time = test_model(
            best_model, test_loader, criterion, path, cfg, accelerator)

        if accelerator.is_main_process:
            # Create metrics text for best model
            metrics_text = f"Test MSE: {best_mse:.6f}\n"
            metrics_text += f"Test MAE: {best_mae:.6f}\n"
            metrics_text += f"Test RelL1: {best_rel_l1:.6f}\n"
            metrics_text += f"Test RelL2: {best_rel_l2:.6f}\n"
            metrics_text += f"Test R2: {best_r2:.6f}\n"
            metrics_text += f"Inference time: {inf_time:.6f}s\n"
            metrics_text += f"Total training time: {time_elapsed:.2f}s\n"
            metrics_text += f"Average epoch time: {time_elapsed / cfg.epochs:.2f}s\n"
            metrics_text += f"Total parameters: {total_params}\n"
            metrics_text += f"Trainable parameters: {trainable_params}\n"
            metrics_text += "Peak GPU memory usage:\n"
            metrics_text += f" - During training: {peak_mem_gb:.1f} GB\n"
            metrics_text += f" - During testing: {test_peak_mem_gb:.1f} GB\n"

            # Write to file and add to tensorboard
            metrics_file = os.path.join(path, 'best_test_metrics.txt')
            with open(metrics_file, 'w') as f:
                f.write(metrics_text)

            # Add best metrics to tensorboard as text (two-space markdown line break)
            writer.add_text('metrics/best_metrics', metrics_text.replace('\n', '  \n'))

            print(f"\nFinal model metrics - MSE: {final_mse:.6f}, MAE: {final_mae:.6f}, R²: {final_r2:.4f}")
            print(f"Best model metrics - MSE: {best_mse:.6f}, MAE: {best_mae:.6f}, R²: {best_r2:.4f}")

            # Close tensorboard writer
            writer.close()
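

# Minimal launch sketch (illustrative only; the real entry point, config
# schema, and model factory live elsewhere in this repo):
#
#   from accelerate import Accelerator
#   cfg = OmegaConf.load("config.yaml")   # hypothetical config path
#   accelerator = Accelerator()
#   model = build_model(cfg)              # hypothetical model factory
#   train_DriveAerNet_main(model, cfg.save_dir, cfg, accelerator)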