import numpy as np import time, json, os, math import torch import torch.nn as nn import torch.nn.functional as F from tqdm import tqdm import json from torch.utils.tensorboard import SummaryWriter from omegaconf import OmegaConf from lion_pytorch import Lion import glob # For full mesh load and then sample in each training iteration from datasets.DrivAerML.dataset_drivaerml import get_dataloaders, PRESSURE_MEAN, PRESSURE_STD from torch.cuda.amp import GradScaler def save_checkpoint(model, optimizer, scheduler, epoch, best_val_loss, val_MSE_list, cfg, path, accelerator, log_dir=None): """Save a complete training checkpoint.""" if accelerator.is_main_process: rng_state = torch.get_rng_state() cuda_rng_states = None if torch.cuda.is_available(): cuda_rng_states = [] for i in range(torch.cuda.device_count()): cuda_rng_states.append(torch.cuda.get_rng_state(device=i)) checkpoint = { 'epoch': epoch, 'model_state_dict': accelerator.unwrap_model(model).state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'best_val_loss': best_val_loss, 'val_MSE_list': val_MSE_list, 'cfg': cfg, 'log_dir': log_dir, 'rng_state': rng_state, 'cuda_rng_states': cuda_rng_states, } checkpoint_path = os.path.join(path, f'checkpoint_epoch_{epoch}.pt') torch.save(checkpoint, checkpoint_path) # Also save as latest checkpoint latest_checkpoint_path = os.path.join(path, 'latest_checkpoint.pt') torch.save(checkpoint, latest_checkpoint_path) print(f"Checkpoint saved at epoch {epoch}") def load_checkpoint(path, model, optimizer, scheduler, accelerator): """Load the latest checkpoint and return training state.""" latest_checkpoint_path = os.path.join(path, 'latest_checkpoint.pt') if not os.path.exists(latest_checkpoint_path): print("No checkpoint found, starting from scratch") return None, 0, 1e5, [], None print(f"Loading checkpoint from {latest_checkpoint_path}") checkpoint = torch.load(latest_checkpoint_path, map_location='cpu') # Load model state unwrapped_model = accelerator.unwrap_model(model) unwrapped_model.load_state_dict(checkpoint['model_state_dict']) # Load optimizer and scheduler states optimizer.load_state_dict(checkpoint['optimizer_state_dict']) scheduler.load_state_dict(checkpoint['scheduler_state_dict']) # Restore random states for reproducibility with error handling try: if 'rng_state' in checkpoint and checkpoint['rng_state'] is not None: torch.set_rng_state(checkpoint['rng_state']) except Exception as e: print(f"Warning: Could not restore CPU RNG state: {e}") try: # Handle both old and new checkpoint formats cuda_rng_key = 'cuda_rng_states' if 'cuda_rng_states' in checkpoint else 'cuda_rng_state' if cuda_rng_key in checkpoint and checkpoint[cuda_rng_key] is not None and torch.cuda.is_available(): cuda_rng_states = checkpoint[cuda_rng_key] if isinstance(cuda_rng_states, list) and len(cuda_rng_states) > 0: # Set RNG state for each device for i, state in enumerate(cuda_rng_states): if i < torch.cuda.device_count() and state is not None: torch.cuda.set_rng_state(state, device=i) except Exception as e: print(f"Warning: Could not restore CUDA RNG state: {e}") start_epoch = checkpoint['epoch'] + 1 best_val_loss = checkpoint['best_val_loss'] val_MSE_list = checkpoint['val_MSE_list'] log_dir = checkpoint.get('log_dir', None) print(f"Resumed from epoch {checkpoint['epoch']}, best val loss: {best_val_loss:.6f}") return checkpoint, start_epoch, best_val_loss, val_MSE_list, log_dir def cleanup_old_checkpoints(path, keep_last=3): """Remove old checkpoint files, keeping only the most recent ones.""" checkpoint_pattern = os.path.join(path, 'checkpoint_epoch_*.pt') checkpoint_files = glob.glob(checkpoint_pattern) if len(checkpoint_files) <= keep_last: return # Sort by modification time and remove oldest checkpoint_files.sort(key=os.path.getmtime) files_to_remove = checkpoint_files[:-keep_last] for file_path in files_to_remove: try: os.remove(file_path) print(f"Removed old checkpoint: {os.path.basename(file_path)}") except OSError: pass def train(model, train_loader, optimizer, scheduler, cfg, accelerator, scaler): model.train() criterion = nn.MSELoss() #Any changes here? losses_press = 0.0 for data in train_loader: optimizer.zero_grad() targets = data['output_feat'] ## If batch > 1 whats the drop in accuracy? More epochs in same time negate that? if cfg.mixed_precision: with torch.autocast(device_type = accelerator.device.type): out = model(data) total_loss = criterion(out, targets) scaler.scale(total_loss).backward() scaler.unscale_(optimizer) if cfg.max_grad_norm is not None: accelerator.clip_grad_norm_(model.parameters(), cfg.max_grad_norm) scaler.step(optimizer) scaler.update() else: out = model(data) total_loss = criterion(out, targets) accelerator.backward(total_loss) if cfg.max_grad_norm is not None: accelerator.clip_grad_norm_(model.parameters(), cfg.max_grad_norm) optimizer.step() # Only step OneCycleLR every batch if cfg.scheduler == "OneCycleLR": scheduler.step() losses_press += total_loss.item() return losses_press / len(train_loader) @torch.no_grad() def val(model, val_loader, cfg, accelerator): model.eval() criterion = nn.MSELoss() losses_press = 0.0 for data in val_loader: targets = data['output_feat'] out = model(data) # Loss computation in FP32 for maximum stability targets = targets.float() # Ensure FP32 out = out.float() # Ensure FP32 total_loss = criterion(out, targets) losses_press += total_loss.item() return losses_press / len(val_loader) def test_model(model, test_dataloader, criterion, path, cfg, accelerator): """Test the model and calculate metrics.""" ## You reported test models everywhere? Complete evaluation? model.eval() total_mse = 0.0 total_mae = 0.0 total_rel_l2 = 0.0 total_rel_l1 = 0.0 total_inference_time = 0.0 num_batches = 0 # Convert normalization constants to appropriate dtype and device pressure_mean = torch.tensor(PRESSURE_MEAN, device=accelerator.device) pressure_std = torch.tensor(PRESSURE_STD, device=accelerator.device) # Store outputs and targets on all processes all_outputs = [] all_targets = [] with torch.no_grad(): for data in tqdm(test_dataloader, desc="[Testing]", disable=not accelerator.is_local_main_process): start_time = time.time() targets = data['output_feat'] targets = targets * pressure_std + pressure_mean if cfg.chunked_eval: input_pos = data['input_pos'] B, N, C = input_pos.shape chunk_size = cfg.num_points outputs = [] for i in range(0, N, chunk_size): data['input_pos'] = input_pos[:, i:i+chunk_size, :] # (B, chunk_size, C) out_chunk = model(data) # (B, chunk_size, 3) outputs.append(out_chunk) outputs = torch.cat(outputs, dim=1) # (B, N, 3) else: outputs = model(data) # Metric computations in FP32 for maximum stability targets = targets.float() # Ensure FP32 outputs = outputs.float() # Ensure FP32 outputs = outputs * pressure_std + pressure_mean inference_time = time.time() - start_time total_inference_time += inference_time # Keep metrics as tensors for proper reduction across processes mse = criterion(outputs, targets) mae = F.l1_loss(outputs, targets) rel_l2 = torch.mean(torch.norm(outputs.squeeze(-1) - targets.squeeze(-1), p=2, dim=-1) / torch.norm(targets.squeeze(-1), p=2, dim=-1)) rel_l1 = torch.mean(torch.norm(outputs.squeeze(-1) - targets.squeeze(-1), p=1, dim=-1) / torch.norm(targets.squeeze(-1), p=1, dim=-1)) total_mse += mse total_mae += mae total_rel_l2 += rel_l2 total_rel_l1 += rel_l1 num_batches += 1 # Store outputs and targets on all processes all_outputs.append(outputs.cpu()) all_targets.append(targets.cpu()) # Clear references to tensors del outputs, targets, mse, mae, rel_l2, rel_l1 # Clear GPU cache after all testing torch.cuda.empty_cache() # Convert to tensors for reduction metrics = { "total_mse": total_mse, "total_mae": total_mae, "total_rel_l2": total_rel_l2, "total_rel_l1": total_rel_l1, "num_batches": torch.tensor(num_batches, device=accelerator.device), "total_inference_time": torch.tensor(total_inference_time, device=accelerator.device) } # Gather metrics from all processes gathered_metrics = accelerator.gather(metrics) # Only calculate averages if we have data if gathered_metrics["num_batches"].sum().item() > 0: total_batches = gathered_metrics["num_batches"].sum().item() avg_mse = gathered_metrics["total_mse"].sum().item() / total_batches avg_mae = gathered_metrics["total_mae"].sum().item() / total_batches avg_rel_l2 = gathered_metrics["total_rel_l2"].sum().item() / total_batches avg_rel_l1 = gathered_metrics["total_rel_l1"].sum().item() / total_batches total_inference_time = gathered_metrics["total_inference_time"].sum().item() avg_inference_time = total_inference_time / total_batches # Gather all outputs and targets from all processes all_outputs = torch.cat(all_outputs, dim=1) all_targets = torch.cat(all_targets, dim=1) # Gather outputs and targets across processes all_outputs = accelerator.gather(all_outputs.to(accelerator.device)) all_targets = accelerator.gather(all_targets.to(accelerator.device)) # Calculate R² score using complete dataset if accelerator.is_main_process: all_outputs = all_outputs.to(torch.float32).cpu().numpy() all_targets = all_targets.to(torch.float32).cpu().numpy() ss_tot = np.sum((all_targets - np.mean(all_targets)) ** 2) ss_res = np.sum((all_targets - all_outputs) ** 2) r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 print(f"Test MSE: {avg_mse:.6f}, Test MAE: {avg_mae:.6f}, R²: {r_squared:.4f}") print(f"Relative L2 Error: {avg_rel_l2:.6f}, Relative L1 Error: {avg_rel_l1:.6f}") print(f"Average inference time per batch: {avg_inference_time:.4f}s") print(f"Total inference time: {total_inference_time:.2f}s for {total_batches} batches") # Save metrics to a text file metrics_file = os.path.join(path, 'test_metrics.txt') with open(metrics_file, 'w') as f: f.write(f"Test MSE: {avg_mse:.6f}\n") f.write(f"Test MAE: {avg_mae:.6f}\n") f.write(f"R² Score: {r_squared:.6f}\n") f.write(f"Relative L2 Error: {avg_rel_l2:.6f}\n") f.write(f"Relative L1 Error: {avg_rel_l1:.6f}\n") f.write(f"Average inference time per batch: {avg_inference_time:.4f}s\n") f.write(f"Total inference time: {total_inference_time:.2f}s for {total_batches} batches\n") else: r_squared = 0.0 # Will be overwritten by broadcast else: print("Warning: No data in test_dataloader") avg_mse = avg_mae = avg_rel_l2 = avg_rel_l1 = r_squared = 0.0 return avg_mse, avg_mae, avg_rel_l2, avg_rel_l1, r_squared, avg_inference_time def train_DrivAerML_main(model, path, cfg, accelerator): train_loader, val_loader, test_loader = get_dataloaders(cfg) if accelerator.is_main_process: print( f"Data loaded: {len(train_loader)} training batches, " f"{len(val_loader)} validation batches, " f"{len(test_loader)} test batches") #Select optimizer if cfg.optimizer.type == 'Adam': optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=1e-4) elif cfg.optimizer.type == 'AdamW': optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=0.05) elif cfg.optimizer.type == 'LION': optimizer = Lion(model.parameters(), lr=cfg.lr, weight_decay=0.05) #Select scheduler if cfg.scheduler == "ReduceLROnPlateau": scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.1, verbose=True) elif cfg.scheduler == "LinearWarmupCosineAnnealingLR": warmup_epochs = int(cfg.epochs * 0.05) # Convert back to epochs # Linear warmup scheduler warmup_scheduler = torch.optim.lr_scheduler.LinearLR( optimizer, start_factor=1e-6, # Start very low (almost zero) end_factor=1.0, # End at base lr total_iters=warmup_epochs ) # Cosine decay scheduler cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=cfg.epochs - warmup_epochs, # Remaining epochs eta_min=1e-6 # End at 1e-6 learning rate ) # Combine schedulers scheduler = torch.optim.lr_scheduler.SequentialLR( optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_epochs] ) else: scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, pct_start=0.05, max_lr=cfg.lr, total_steps = len(train_loader) * cfg.epochs ) scaler = GradScaler() model, optimizer, train_loader, val_loader, test_loader, scheduler, scaler = accelerator.prepare( model, optimizer, train_loader, val_loader, test_loader, scheduler, scaler) criterion = torch.nn.MSELoss() # Try to load checkpoint before evaluation or training checkpoint, start_epoch, best_val_loss, val_MSE_list, resumed_log_dir = load_checkpoint( path, model, optimizer, scheduler, accelerator) if cfg.eval: # For evaluation, try to load from checkpoint first, then fall back to best_model.pt if checkpoint is not None: print("Using model from checkpoint for evaluation") else: # Load the saved state dict and create a fresh model load_path = f'metrics/{cfg.project_name}/{cfg.model}{cfg.test_name}' state_dict = torch.load(os.path.join(load_path, f'best_model.pt')) unwrapped_model = accelerator.unwrap_model(model) unwrapped_model.load_state_dict(state_dict) model = accelerator.prepare(unwrapped_model) best_mse, best_mae, best_rel_l2, best_rel_l1, best_r2, inf_time = test_model(model, test_loader, criterion, path, cfg, accelerator) else: # Calculate total parameters total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) # Reset memory stats before training torch.cuda.reset_peak_memory_stats() torch.cuda.empty_cache() # Clear any existing cached memory start = time.time() # Only initialize tensorboard on the main process if accelerator.is_main_process: # Create a descriptive run name using model type and timestamp if resumed_log_dir is not None: # Resume logging to the same directory log_dir = resumed_log_dir print(f"Resuming tensorboard logging to: {log_dir}") else: # Create new log directory run_name = f"{cfg.model}{cfg.test_name}_{time.strftime('%Y%m%d_%H%M%S')}" project_name = os.path.join("tensorboard_logs", f"{cfg.project_name}") log_dir = os.path.join(project_name, run_name) print(f"Starting new tensorboard logging to: {log_dir}") writer = SummaryWriter(log_dir) # Add full config (only if starting fresh) if checkpoint is None: config_text = "```yaml\n" # Using yaml format for better readability config_text += OmegaConf.to_yaml(cfg) config_text += "```" writer.add_text('hyperparameters/full_config', config_text) pbar_train = tqdm(range(start_epoch, cfg.epochs), position=0) pbar_train.set_description(f"Training (resumed from epoch {start_epoch})" if checkpoint else "Training") else: writer = None log_dir = None pbar_train = tqdm(range(start_epoch, cfg.epochs), position=0) # Add checkpoint saving frequency to config (default every 10 epochs) checkpoint_freq = getattr(cfg, 'checkpoint_freq', 10) for epoch in pbar_train: train_loss = train(model, train_loader, optimizer, scheduler, cfg, accelerator, scaler) if cfg.val_iter is not None and (epoch == cfg.epochs - 1 or epoch % cfg.val_iter == 0): val_loss_MSE = val(model, val_loader, cfg, accelerator) if cfg.scheduler == "ReduceLROnPlateau": scheduler.step(val_loss_MSE) elif cfg.scheduler == "LinearWarmupCosineAnnealingLR": scheduler.step() val_MSE_list.append(val_loss_MSE) if accelerator.is_main_process: # Get peak GPU memory in GB peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024) # Log metrics to tensorboard writer.add_scalar('Loss/train_MSE', train_loss, epoch) writer.add_scalar('Loss/val_MSE', val_loss_MSE, epoch) writer.add_scalar('Learning_rate', scheduler.get_last_lr()[0], epoch) writer.add_scalar('Memory/GPU', peak_mem_gb, epoch) with open(os.path.join(path,'MSE.json'), 'w') as f: json.dump(val_MSE_list, f, indent=2) pbar_train.set_postfix({ 'train_loss': train_loss, 'val_loss': val_loss_MSE, 'lr': scheduler.get_last_lr()[0], 'mem_gb': f'{peak_mem_gb:.1f}' }) if val_loss_MSE < best_val_loss: best_val_loss = val_loss_MSE unwrapped_model = accelerator.unwrap_model(model) torch.save(unwrapped_model.state_dict(), os.path.join(path, f'best_model.pt')) elif accelerator.is_main_process: # Simple progress display without validation peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024) pbar_train.set_postfix({ 'train_loss': train_loss, 'mem_gb': f'{peak_mem_gb:.1f}' }) # Save checkpoint periodically if accelerator.is_main_process and (epoch % checkpoint_freq == 0 or epoch == cfg.epochs - 1): save_checkpoint(model, optimizer, scheduler, epoch, best_val_loss, val_MSE_list, cfg, path, accelerator, log_dir) # Clean up old checkpoints to save disk space cleanup_old_checkpoints(path, keep_last=3) end = time.time() time_elapsed = end - start # Get final peak memory for reporting if accelerator.is_main_process: peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024) # Reset memory stats before final evaluation torch.cuda.reset_peak_memory_stats() torch.cuda.empty_cache() # Test final model final_mse, final_mae, final_rel_l2, final_rel_l1, final_r2, inf_time = test_model( model, test_loader, criterion, path, cfg, accelerator) # Get peak memory during testing if accelerator.is_main_process: test_peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024) # Create metrics text for final model metrics_text = f"Test MSE: {final_mse:.6f}\n" metrics_text += f"Test MAE: {final_mae:.6f}\n" metrics_text += f"Test RelL1: {final_rel_l1:.6f}\n" metrics_text += f"Test RelL2: {final_rel_l2:.6f}\n" metrics_text += f"Test R2: {final_r2:.6f}\n" metrics_text += f"Inference time: {inf_time:.6f}s\n" metrics_text += f"Total training time: {time_elapsed:.2f}s\n" metrics_text += f"Average epoch time: {time_elapsed/cfg.epochs:.2f}s\n" metrics_text += f"Total parameters: {total_params}\n" metrics_text += f"Trainable parameters: {trainable_params}\n" metrics_text += f"Peak GPU memory usage:\n" metrics_text += f" - During training: {peak_mem_gb:.1f} GB\n" metrics_text += f" - During testing: {test_peak_mem_gb:.1f} GB\n" # Write to file and add to tensorboard metrics_file = os.path.join(path, 'final_test_metrics.txt') with open(metrics_file, 'w') as f: f.write(metrics_text) # Add final metrics to tensorboard as text (replace \n with markdown line break) writer.add_text('metrics/final_metrics', metrics_text.replace('\n', ' \n')) # Load the best model using state_dict for compatibility state_dict = torch.load(os.path.join(path, f'best_model.pt')) unwrapped_model = accelerator.unwrap_model(model) unwrapped_model.load_state_dict(state_dict) best_model = accelerator.prepare(unwrapped_model) best_mse, best_mae, best_rel_l2, best_rel_l1, best_r2, inf_time = test_model( best_model, test_loader, criterion, path, cfg, accelerator) if accelerator.is_main_process: # Create metrics text for best model metrics_text = f"Test MSE: {best_mse:.6f}\n" metrics_text += f"Test MAE: {best_mae:.6f}\n" metrics_text += f"Test RelL1: {best_rel_l1:.6f}\n" metrics_text += f"Test RelL2: {best_rel_l2:.6f}\n" metrics_text += f"Test R2: {best_r2:.6f}\n" metrics_text += f"Inference time: {inf_time:.6f}s\n" metrics_text += f"Total training time: {time_elapsed:.2f}s\n" metrics_text += f"Average epoch time: {time_elapsed/cfg.epochs:.2f}s\n" metrics_text += f"Total parameters: {total_params}\n" metrics_text += f"Trainable parameters: {trainable_params}\n" metrics_text += f"Peak GPU memory usage:\n" metrics_text += f" - During training: {peak_mem_gb:.1f} GB\n" metrics_text += f" - During testing: {test_peak_mem_gb:.1f} GB\n" # Write to file and add to tensorboard metrics_file = os.path.join(path, 'best_test_metrics.txt') with open(metrics_file, 'w') as f: f.write(metrics_text) # Add best metrics to tensorboard as text (replace \n with markdown line break) writer.add_text('metrics/best_metrics', metrics_text.replace('\n', ' \n')) print(f"\nFinal model metrics - MSE: {final_mse:.6f}, MAE: {final_mae:.6f}, R²: {final_r2:.4f}") print(f"Best model metrics - MSE: {best_mse:.6f}, MAE: {best_mae:.6f}, R²: {best_r2:.4f}") # Save final checkpoint save_checkpoint(model, optimizer, scheduler, cfg.epochs - 1, best_val_loss, val_MSE_list, cfg, path, accelerator, log_dir) # Close tensorboard writer writer.close()