import glob
import json
import math
import os
import re
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from omegaconf import OmegaConf
from lion_pytorch import Lion

# Load the full mesh once, then sample points in each training iteration
from datasets.plane.dataset_plane_engine_test import get_dataloaders
def save_checkpoint(model, optimizer, scheduler, epoch, best_val_loss, val_MSE_list,
                    cfg, path, accelerator, log_dir=None):
    """Save a complete training checkpoint."""
    if accelerator.is_main_process:
        rng_state = torch.get_rng_state()
        cuda_rng_states = None
        if torch.cuda.is_available():
            cuda_rng_states = []
            for i in range(torch.cuda.device_count()):
                cuda_rng_states.append(torch.cuda.get_rng_state(device=i))
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': accelerator.unwrap_model(model).state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_val_loss,
            'val_MSE_list': val_MSE_list,
            'cfg': cfg,
            'log_dir': log_dir,
            'rng_state': rng_state,
            'cuda_rng_states': cuda_rng_states,
        }
        checkpoint_path = os.path.join(path, f'checkpoint_epoch_{epoch}.pt')
        torch.save(checkpoint, checkpoint_path)
        # Also save a copy as the latest checkpoint
        latest_checkpoint_path = os.path.join(path, 'latest_checkpoint.pt')
        torch.save(checkpoint, latest_checkpoint_path)
        print(f"Checkpoint saved at epoch {epoch}")
def load_checkpoint(path, model, optimizer, scheduler, accelerator):
    """Load the latest checkpoint and return the training state."""
    latest_checkpoint_path = os.path.join(path, 'latest_checkpoint.pt')
    if not os.path.exists(latest_checkpoint_path):
        print("No checkpoint found, starting from scratch")
        return None, 0, 1e5, [], None
    print(f"Loading checkpoint from {latest_checkpoint_path}")
    checkpoint = torch.load(latest_checkpoint_path, map_location='cpu')
    # Load the model state
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.load_state_dict(checkpoint['model_state_dict'])
    # Load the optimizer and scheduler states
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # Restore the random states for reproducibility, with error handling
    try:
        if 'rng_state' in checkpoint and checkpoint['rng_state'] is not None:
            torch.set_rng_state(checkpoint['rng_state'])
    except Exception as e:
        print(f"Warning: Could not restore CPU RNG state: {e}")
    try:
        # Handle both old and new checkpoint formats
        cuda_rng_key = 'cuda_rng_states' if 'cuda_rng_states' in checkpoint else 'cuda_rng_state'
        if cuda_rng_key in checkpoint and checkpoint[cuda_rng_key] is not None and torch.cuda.is_available():
            cuda_rng_states = checkpoint[cuda_rng_key]
            if isinstance(cuda_rng_states, list) and len(cuda_rng_states) > 0:
                # Set the RNG state for each device
                for i, state in enumerate(cuda_rng_states):
                    if i < torch.cuda.device_count() and state is not None:
                        torch.cuda.set_rng_state(state, device=i)
    except Exception as e:
        print(f"Warning: Could not restore CUDA RNG state: {e}")
    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint['best_val_loss']
    val_MSE_list = checkpoint['val_MSE_list']
    log_dir = checkpoint.get('log_dir', None)
    print(f"Resumed from epoch {checkpoint['epoch']}, best val loss: {best_val_loss:.6f}")
    return checkpoint, start_epoch, best_val_loss, val_MSE_list, log_dir
def cleanup_old_checkpoints(path, keep_last=3):
    """Remove old checkpoint files, keeping only the most recent ones."""
    checkpoint_pattern = os.path.join(path, '*_epoch_*.pt')
    checkpoint_files = glob.glob(checkpoint_pattern)
    if len(checkpoint_files) <= keep_last:
        return
    # Sort by modification time and remove the oldest files
    checkpoint_files.sort(key=os.path.getmtime)
    files_to_remove = checkpoint_files[:-keep_last]
    for file_path in files_to_remove:
        try:
            os.remove(file_path)
            print(f"Removed old checkpoint: {os.path.basename(file_path)}")
        except OSError:
            pass
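# Illustrative usage sketch (comments only, not executed): how the three checkpoint
# helpers above are meant to compose inside a training loop. `ckpt_dir` is a
# hypothetical name for the output directory; everything else mirrors
# train_plane_engine_test_main below.
#
#   _, start_epoch, best_val_loss, val_hist, log_dir = load_checkpoint(
#       ckpt_dir, model, optimizer, scheduler, accelerator)
#   for epoch in range(start_epoch, cfg.epochs):
#       ...  # train one epoch
#       if accelerator.is_main_process and epoch % checkpoint_freq == 0:
#           save_checkpoint(model, optimizer, scheduler, epoch, best_val_loss, val_hist,
#                           cfg, ckpt_dir, accelerator, log_dir)
#           cleanup_old_checkpoints(ckpt_dir, keep_last=3)  # keep disk usage bounded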
def train(model, train_loader, optimizer, scheduler, criterion, cfg, accelerator, scaler):
    model.train()
    losses_press = 0.0
    for data in train_loader:
        optimizer.zero_grad()
        targets = data['output_feat']
        if cfg.mixed_precision:
            with torch.autocast(device_type=accelerator.device.type):
                out = model(data)
                total_loss = criterion(out, targets)
            scaler.scale(total_loss).backward()
            scaler.unscale_(optimizer)
            if cfg.max_grad_norm is not None:
                accelerator.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(data)
            total_loss = criterion(out, targets)
            accelerator.backward(total_loss)
            if cfg.max_grad_norm is not None:
                accelerator.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            optimizer.step()
        # OneCycleLR is stepped every batch; the other schedulers step once per epoch
        if cfg.scheduler == "OneCycleLR":
            scheduler.step()
        losses_press += total_loss.item()
    return losses_press / len(train_loader)
def val(model, val_loader, criterion, cfg, accelerator):
    model.eval()
    losses_press = 0.0
    with torch.no_grad():  # Gradients are not needed during validation
        for data in val_loader:
            targets = data['output_feat']
            out = model(data)
            # Loss computation in FP32 for maximum stability
            targets = targets.float()
            out = out.float()
            total_loss = criterion(out, targets)
            losses_press += total_loss.item()
    return losses_press / len(val_loader)
def test_model(model, test_dataloader, criterion, path, cfg, accelerator):
    """Run the model on the test set and compute evaluation metrics."""
    model.eval()
    total_mse = 0.0
    total_mae = 0.0
    total_huber = 0.0
    total_rel_l2 = 0.0
    total_rel_l1 = 0.0
    total_mse_list = []
    total_mae_list = []
    total_huber_list = []
    total_rel_l2_list = []
    total_rel_l1_list = []
    r_2_squared_list = []
    total_inference_time = 0.0
    num_batches = 0
    if cfg.normalization == "std_norm":
        with open(cfg.json_file, 'r') as f:
            json_data = json.load(f)
        pressure_mean = torch.tensor(json_data["scalars"]["pressure"]["mean"], device=accelerator.device)
        pressure_std = torch.tensor(json_data["scalars"]["pressure"]["std"], device=accelerator.device)
    # Store outputs and targets on all processes
    all_outputs = []
    all_targets = []
    all_physical_coordinates = []
    with torch.no_grad():
        for data in tqdm(test_dataloader, desc="[Testing]", disable=not accelerator.is_local_main_process):
            start_time = time.time()
            targets = data['output_feat']
            if cfg.chunked_eval:
                input_pos = data['input_pos']
                B, N, C = input_pos.shape
                chunk_size = cfg.num_points
                outputs = []
                for i in range(0, N, chunk_size):
                    # Start with the raw slice
                    chunk = input_pos[:, i:i + chunk_size, :]  # (B, n_valid, C)
                    n_valid = chunk.shape[1]
                    if n_valid < chunk_size:
                        # Pad a short last chunk by wrapping points from the beginning
                        shape_diff = chunk_size - n_valid
                        pad = input_pos[:, :shape_diff, :]      # (B, shape_diff, C)
                        chunk = torch.cat([chunk, pad], dim=1)  # (B, chunk_size, C)
                        data['input_pos'] = chunk
                        out_chunk = model(data)                 # (B, chunk_size, D)
                        # Keep only the part that corresponds to real points
                        out_chunk = out_chunk[:, :n_valid, :]   # (B, n_valid, D)
                    else:
                        data['input_pos'] = chunk
                        out_chunk = model(data)                 # (B, chunk_size, D)
                    outputs.append(out_chunk)
                outputs = torch.cat(outputs, dim=1)             # (B, N, D)
            else:
                outputs = model(data)
            # Metric computations in FP32 for maximum stability
            targets = targets.float()
            outputs = outputs.float()
            if cfg.physical_scale_for_test:
                targets[:, :, 0] = targets[:, :, 0] * pressure_std + pressure_mean
                outputs[:, :, 0] = outputs[:, :, 0] * pressure_std + pressure_mean
            inference_time = time.time() - start_time
            total_inference_time += inference_time
            # Compute all relevant losses and metrics for documentation and analysis
            criterion_mse = nn.MSELoss()
            criterion_mae = nn.L1Loss()
            criterion_huber = nn.HuberLoss(delta=1.0)
            mse = criterion_mse(outputs, targets)
            mae = criterion_mae(outputs, targets)
            huber = criterion_huber(outputs, targets)
            # Relative L2 error: mean over the batch of (L2 norm of error / L2 norm of target)
            rel_l2 = torch.mean(torch.norm(outputs.squeeze(-1) - targets.squeeze(-1), p=2, dim=-1) /
                                torch.norm(targets.squeeze(-1), p=2, dim=-1))
            # Relative L1 error: mean over the batch of (L1 norm of error / L1 norm of target)
            rel_l1 = torch.mean(torch.norm(outputs.squeeze(-1) - targets.squeeze(-1), p=1, dim=-1) /
                                torch.norm(targets.squeeze(-1), p=1, dim=-1))
            # Per-batch R²: total sum of squares from the targets, residual from the error
            ss_tot = torch.sum((targets - torch.mean(targets)) ** 2)
            ss_res = torch.sum((targets - outputs) ** 2)
            r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0
            total_mse_list.append(mse.item())
            total_mae_list.append(mae.item())
            total_huber_list.append(huber.item())
            total_rel_l2_list.append(rel_l2.item())
            total_rel_l1_list.append(rel_l1.item())
            r_2_squared_list.append(r_squared.item())
            total_mse += mse
            total_mae += mae
            total_huber += huber
            total_rel_l2 += rel_l2
            total_rel_l1 += rel_l1
            num_batches += 1
            # Store outputs and targets on all processes for later aggregation and R² computation
            all_outputs.append(outputs.cpu())
            all_targets.append(targets.cpu())
            all_physical_coordinates.append(data['physical_coordinates'].cpu())
            path_vtk = path + "/vtk_files"
            # Save VTK files for each plane (if any data exists)
            if len(all_outputs) > 0 and len(all_targets) > 0 and len(all_physical_coordinates) > 0:
                try:
                    from utils.vtk_writer import vtk_writer
                    vtk_writer(outputs, targets, data['physical_coordinates'].cpu(), path_vtk,
                               prefix=data["data_id"][0], config_json_path=cfg.json_file)
                except Exception as e:
                    print(f"[Warning] Could not save VTK files: {e}")
            # Clear references to the per-batch tensors
            del outputs, targets, mse, mae, huber, rel_l2, rel_l1
    metrics_list = {
        "total_mse_list": total_mse_list,
        "total_mae_list": total_mae_list,
        "total_huber_list": total_huber_list,
        "total_rel_l2_list": total_rel_l2_list,
        "total_rel_l1_list": total_rel_l1_list,
        "r_2_squared_list": r_2_squared_list,
    }
    # Save metrics_list as a JSON file for per-batch analysis
    metrics_list_file = os.path.join(path, 'test_metrics_list.txt')
    with open(metrics_list_file, 'w') as f:
        json.dump(metrics_list, f, indent=2)
    # Convert to tensors for reduction
    metrics = {
        "total_mse": total_mse,
        "total_mae": total_mae,
        "total_huber": total_huber,
        "total_rel_l2": total_rel_l2,
        "total_rel_l1": total_rel_l1,
        "num_batches": torch.tensor(num_batches, device=accelerator.device),
        "total_inference_time": torch.tensor(total_inference_time, device=accelerator.device)
    }
    # Gather metrics from all processes
    gathered_metrics = accelerator.gather(metrics)
    # Only calculate averages if we have data
    if gathered_metrics["num_batches"].sum().item() > 0:
        total_batches = gathered_metrics["num_batches"].sum().item()
        avg_mse = gathered_metrics["total_mse"].sum().item() / total_batches
        avg_mae = gathered_metrics["total_mae"].sum().item() / total_batches
        avg_huber = gathered_metrics["total_huber"].sum().item() / total_batches
        avg_rel_l2 = gathered_metrics["total_rel_l2"].sum().item() / total_batches
        avg_rel_l1 = gathered_metrics["total_rel_l1"].sum().item() / total_batches
        total_inference_time = gathered_metrics["total_inference_time"].sum().item()
        avg_inference_time = total_inference_time / total_batches
        # Concatenate outputs and targets, then gather them across processes
        all_outputs = torch.cat(all_outputs, dim=1)
        all_targets = torch.cat(all_targets, dim=1)
        all_outputs = accelerator.gather(all_outputs.to(accelerator.device))
        all_targets = accelerator.gather(all_targets.to(accelerator.device))
        # Calculate the R² score over the complete dataset
        if accelerator.is_main_process:
            all_outputs = all_outputs.to(torch.float32).cpu().numpy()
            all_targets = all_targets.to(torch.float32).cpu().numpy()
            ss_tot = np.sum((all_targets - np.mean(all_targets)) ** 2)
            ss_res = np.sum((all_targets - all_outputs) ** 2)
            r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0
            print(f"Test MSE: {avg_mse:.6f}, Test MAE: {avg_mae:.6f}, Test Huber: {avg_huber:.6f}, R²: {r_squared:.4f}")
            print(f"Relative L2 Error: {avg_rel_l2:.6f}, Relative L1 Error: {avg_rel_l1:.6f}")
            print(f"Average inference time per batch: {avg_inference_time:.4f}s")
            print(f"Total inference time: {total_inference_time:.2f}s for {total_batches} batches")
            # Save metrics to a text file
            metrics_file = os.path.join(path, 'test_metrics.txt')
            with open(metrics_file, 'w') as f:
                f.write(f"Test MSE: {avg_mse:.6f}\n")
                f.write(f"Test MAE: {avg_mae:.6f}\n")
                f.write(f"Test Huber: {avg_huber:.6f}\n")
                f.write(f"R2 Score: {r_squared:.6f}\n")
                f.write(f"Relative L2 Error: {avg_rel_l2:.6f}\n")
                f.write(f"Relative L1 Error: {avg_rel_l1:.6f}\n")
                f.write(f"Average inference time per batch: {avg_inference_time:.4f}s\n")
                f.write(f"Total inference time: {total_inference_time:.2f}s for {total_batches} batches\n")
            # Optionally save outputs and targets as .npy files for further analysis
            # np.save(os.path.join(path, 'test_outputs.npy'), all_outputs)
            # np.save(os.path.join(path, 'test_targets.npy'), all_targets)
        else:
            r_squared = 0.0  # Placeholder; the dataset-level R² is only computed on the main process
    else:
        print("Warning: No data in test_dataloader")
        avg_mse = avg_mae = avg_huber = avg_rel_l2 = avg_rel_l1 = r_squared = 0.0
        avg_inference_time = 0.0
    # Clear the GPU cache after testing
    torch.cuda.empty_cache()
    return avg_mse, avg_mae, avg_huber, avg_rel_l2, avg_rel_l1, r_squared, avg_inference_time
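# Hedged sketch of the OmegaConf config this module expects. The keys are taken from
# the cfg.* accesses in this file; the values shown are placeholders, not project
# defaults:
#
#   project_name: <project name>
#   model: <model name>              # used to build run and checkpoint names
#   test_name: <experiment tag>
#   epochs: 100
#   lr: 1.0e-3
#   optimizer: {type: AdamW}         # 'Adam' | 'AdamW' | 'LION'
#   scheduler: OneCycleLR            # 'ReduceLROnPlateau' | 'LinearWarmupCosineAnnealingLR' | OneCycleLR
#   loss_type: mse                   # 'mse' | 'mae' | 'huber'
#   mixed_precision: true
#   max_grad_norm: 1.0
#   val_iter: 5                      # validate every N epochs (null disables validation)
#   checkpoint_freq: 10
#   normalization: std_norm
#   json_file: <path to dataset statistics JSON>
#   chunked_eval: true
#   num_points: <points per chunk>
#   physical_scale_for_test: true
#   eval: false
#   train_ckpt_load: false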
def train_plane_engine_test_main(model, path, cfg, accelerator):
    train_loader, val_loader, test_loader = get_dataloaders(cfg)
    if accelerator.is_main_process:
        print(
            f"Data loaded: {len(train_loader)} training batches, "
            f"{len(val_loader)} validation batches, "
            f"{len(test_loader)} test batches")
    # Select the optimizer
    if cfg.optimizer.type == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=1e-4)
    elif cfg.optimizer.type == 'AdamW':
        optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=0.05)
    elif cfg.optimizer.type == 'LION':
        optimizer = Lion(model.parameters(), lr=cfg.lr, weight_decay=0.05)
    else:
        raise ValueError(f"Unknown optimizer type: {cfg.optimizer.type}")
    # Select the scheduler
    if cfg.scheduler == "ReduceLROnPlateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.1, verbose=True)
    elif cfg.scheduler == "LinearWarmupCosineAnnealingLR":
        warmup_epochs = int(cfg.epochs * 0.05)  # 5% of the total epochs
        # Linear warmup scheduler
        warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=1e-6,  # Start very low (almost zero)
            end_factor=1.0,     # End at the base learning rate
            total_iters=warmup_epochs
        )
        # Cosine decay scheduler
        cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=cfg.epochs - warmup_epochs,  # Remaining epochs
            eta_min=1e-6                       # Final learning rate
        )
        # Combine the two schedulers
        scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[warmup_epochs]
        )
    else:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            pct_start=0.05,
            max_lr=cfg.lr,
            total_steps=len(train_loader) * cfg.epochs
        )
    if cfg.loss_type == "mse":
        criterion = nn.MSELoss()
    elif cfg.loss_type == "mae":
        criterion = nn.L1Loss()
    elif cfg.loss_type == "huber":
        criterion = nn.HuberLoss(delta=1.0)
    else:
        raise ValueError(f"Unknown loss_type: {cfg.loss_type}")
    scaler = GradScaler()
    model, optimizer, train_loader, val_loader, test_loader, scheduler, scaler = accelerator.prepare(
        model, optimizer, train_loader, val_loader, test_loader, scheduler, scaler)
    best_epoch = 0
    # Try to load a checkpoint before evaluation or training
    checkpoint, start_epoch, best_val_loss, val_MSE_list, resumed_log_dir = load_checkpoint(
        path, model, optimizer, scheduler, accelerator)
    if cfg.eval:
        # For evaluation, use the model from the training checkpoint if requested,
        # otherwise fall back to the most recent best_model_epoch_*.pt
        if cfg.train_ckpt_load:
            print("Using model from checkpoint for evaluation")
        else:
            # Load the saved state dict into a fresh (unwrapped) model
            load_path = f'metrics/{cfg.project_name}/{cfg.model}_{cfg.test_name}'
            # Find all best_model_epoch_*.pt files and take the latest epoch
            pattern = os.path.join(os.getcwd(), load_path.lstrip('/'), 'best_case', 'best_model_epoch_*.pt')
            best_model_files = glob.glob(pattern)
            if not best_model_files:
                raise FileNotFoundError(f"No best_model_epoch_*.pt files found in {os.path.join(load_path, 'best_case')}")
            # Extract the epoch numbers
            epoch_numbers = []
            for fname in best_model_files:
                match = re.search(r'best_model_epoch_(\d+)\.pt', os.path.basename(fname))
                if match:
                    epoch_numbers.append(int(match.group(1)))
            if not epoch_numbers:
                raise ValueError("No epoch numbers found in best_model_epoch_*.pt filenames")
            last_best_epoch = max(epoch_numbers)
            state_dict = torch.load(os.path.join(load_path, 'best_case', f'best_model_epoch_{last_best_epoch}.pt'))
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.load_state_dict(state_dict)
            model = accelerator.prepare(unwrapped_model)
            path = os.path.join(path, "best_case")  # Point the output path at the best-model directory
            print("Using best model for evaluation at epoch", last_best_epoch)
        best_mse, best_mae, best_huber, best_rel_l2, best_rel_l1, best_r2, inf_time = test_model(
            model, test_loader, criterion, path, cfg, accelerator)
    else:
        # Count model parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        # Reset memory stats before training
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()  # Clear any existing cached memory
        start = time.time()
        # Only initialize tensorboard on the main process
        if accelerator.is_main_process:
            if resumed_log_dir is not None:
                # Resume logging to the same directory
                log_dir = resumed_log_dir
                print(f"Resuming tensorboard logging to: {log_dir}")
            else:
                # Create a new log directory with a descriptive run name (model type and timestamp)
                run_name = f"{cfg.model}_{cfg.test_name}_{time.strftime('%Y%m%d_%H%M%S')}"
                project_name = os.path.join("tensorboard_logs", f"{cfg.project_name}")
                log_dir = os.path.join(project_name, run_name)
                print(f"Starting new tensorboard logging to: {log_dir}")
            writer = SummaryWriter(log_dir)
            # Log the full config (only if starting fresh)
            if checkpoint is None:
                config_text = "```yaml\n"  # YAML format for better readability
                config_text += OmegaConf.to_yaml(cfg)
                config_text += "```"
                writer.add_text('hyperparameters/full_config', config_text)
            pbar_train = tqdm(range(start_epoch, cfg.epochs), position=0)
            pbar_train.set_description(f"Training (resumed from epoch {start_epoch})" if checkpoint else "Training")
        else:
            writer = None
            log_dir = None
            pbar_train = tqdm(range(start_epoch, cfg.epochs), position=0)
        # Checkpoint saving frequency (default: every 10 epochs)
        checkpoint_freq = getattr(cfg, 'checkpoint_freq', 10)
        for epoch in pbar_train:
            train_loss = train(model, train_loader, optimizer, scheduler, criterion, cfg, accelerator, scaler)
            if cfg.val_iter is not None and (epoch == cfg.epochs - 1 or epoch % cfg.val_iter == 0):
                val_loss_MSE = val(model, val_loader, criterion, cfg, accelerator)
                if cfg.scheduler == "ReduceLROnPlateau":
                    scheduler.step(val_loss_MSE)
                elif cfg.scheduler == "LinearWarmupCosineAnnealingLR":
                    scheduler.step()
                val_MSE_list.append(val_loss_MSE)
                if accelerator.is_main_process:
                    # Peak GPU memory in GB
                    peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024)
                    # Log metrics to tensorboard
                    writer.add_scalar('Loss/train_MSE', train_loss, epoch)
                    writer.add_scalar('Loss/val_MSE', val_loss_MSE, epoch)
                    writer.add_scalar('Learning_rate', scheduler.get_last_lr()[0], epoch)
                    writer.add_scalar('Memory/GPU', peak_mem_gb, epoch)
                    with open(os.path.join(path, 'MSE.json'), 'w') as f:
                        json.dump(val_MSE_list, f, indent=2)
                    pbar_train.set_postfix({
                        'train_loss': train_loss,
                        'val_loss': val_loss_MSE,
                        'lr': scheduler.get_last_lr()[0],
                        'mem_gb': f'{peak_mem_gb:.1f}'
                    })
                    if val_loss_MSE < best_val_loss:
                        best_val_loss = val_loss_MSE
                        unwrapped_model = accelerator.unwrap_model(model)
                        os.makedirs(os.path.join(path, 'best_case'), exist_ok=True)
                        best_epoch = epoch
                        # Save the best model state_dict, keeping only the newest file
                        cleanup_old_checkpoints(os.path.join(path, 'best_case'), keep_last=1)
                        torch.save(unwrapped_model.state_dict(),
                                   os.path.join(path, f'best_case/best_model_epoch_{best_epoch}.pt'))
                        print("saving best model at epoch", epoch)
            elif accelerator.is_main_process:
                # Simple progress display without validation
                peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024)
                pbar_train.set_postfix({
                    'train_loss': train_loss,
                    'mem_gb': f'{peak_mem_gb:.1f}'
                })
            # Save a checkpoint periodically
            if accelerator.is_main_process and (epoch % checkpoint_freq == 0 or epoch == cfg.epochs - 1):
                save_checkpoint(model, optimizer, scheduler, epoch, best_val_loss, val_MSE_list,
                                cfg, path, accelerator, log_dir)
                # Clean up old checkpoints to save disk space
                cleanup_old_checkpoints(path, keep_last=3)
        end = time.time()
        time_elapsed = end - start
        # Record the final peak memory for reporting
        if accelerator.is_main_process:
            peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024)
        # Reset memory stats before the final evaluation
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        # Save the final checkpoint BEFORE loading the best model for evaluation
        if accelerator.is_main_process:
            save_checkpoint(model, optimizer, scheduler, cfg.epochs - 1, best_val_loss, val_MSE_list,
                            cfg, path, accelerator, log_dir)
        # Test the final model (last epoch)
        final_mse, final_mae, final_huber, final_rel_l2, final_rel_l1, final_r2, inf_time = test_model(
            model, test_loader, criterion, path, cfg, accelerator)
        # Report peak memory during testing and log the final-model metrics
        if accelerator.is_main_process:
            test_peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024)
            # Assemble the metrics text for the final model
            metrics_text = f"Test MSE: {final_mse:.6f}\n"
            metrics_text += f"Test MAE: {final_mae:.6f}\n"
            metrics_text += f"Test Huber: {final_huber:.6f}\n"
            metrics_text += f"Test RelL1: {final_rel_l1:.6f}\n"
            metrics_text += f"Test RelL2: {final_rel_l2:.6f}\n"
            metrics_text += f"Test R2: {final_r2:.6f}\n"
            metrics_text += f"Inference time: {inf_time:.6f}s\n"
            metrics_text += f"Total training time: {time_elapsed:.2f}s\n"
            metrics_text += f"Average epoch time: {time_elapsed / cfg.epochs:.2f}s\n"
            metrics_text += f"Total parameters: {total_params}\n"
            metrics_text += f"Trainable parameters: {trainable_params}\n"
            metrics_text += f"Peak GPU memory usage:\n"
            metrics_text += f"  - During training: {peak_mem_gb:.1f} GB\n"
            metrics_text += f"  - During testing: {test_peak_mem_gb:.1f} GB\n"
            # Write to file and add to tensorboard
            metrics_file = os.path.join(path, 'final_test_metrics.txt')
            with open(metrics_file, 'w') as f:
                f.write(metrics_text)
            # Add the final metrics to tensorboard as text (two trailing spaces force a markdown line break)
            writer.add_text('metrics/final_metrics', metrics_text.replace('\n', '  \n'))
            # Log per-batch test metrics for the final model
            metrics_list_file = os.path.join(path, 'test_metrics_list.txt')
            if os.path.exists(metrics_list_file):
                with open(metrics_list_file, 'r') as f:
                    metrics_list = json.load(f)
                for metric_name, values in metrics_list.items():
                    for i, v in enumerate(values):
                        writer.add_scalar(f'per_batch_test_metrics/final/{metric_name}', v, i)
        # Load the best model via its state_dict for compatibility (into a separate model instance)
        from copy import deepcopy
        best_model = deepcopy(model)
        state_dict = torch.load(os.path.join(path, f'best_case/best_model_epoch_{best_epoch}.pt'))
        unwrapped_best_model = accelerator.unwrap_model(best_model)
        unwrapped_best_model.load_state_dict(state_dict)
        best_model = accelerator.prepare(unwrapped_best_model)
        # Test the best model; do not overwrite `path`, which is still used for final-model logging
        path_best = os.path.join(path, 'best_case')
        best_mse, best_mae, best_huber, best_rel_l2, best_rel_l1, best_r2, inf_time = test_model(
            best_model, test_loader, criterion, path_best, cfg, accelerator)
        if accelerator.is_main_process:
            # Assemble the metrics text for the best model
            metrics_text = f"Test MSE: {best_mse:.6f}\n"
            metrics_text += f"Test MAE: {best_mae:.6f}\n"
            metrics_text += f"Test Huber: {best_huber:.6f}\n"
            metrics_text += f"Test RelL1: {best_rel_l1:.6f}\n"
            metrics_text += f"Test RelL2: {best_rel_l2:.6f}\n"
            metrics_text += f"Test R2: {best_r2:.6f}\n"
            metrics_text += f"Inference time: {inf_time:.6f}s\n"
            metrics_text += f"Total training time: {time_elapsed:.2f}s\n"
            metrics_text += f"Average epoch time: {time_elapsed / cfg.epochs:.2f}s\n"
            metrics_text += f"Total parameters: {total_params}\n"
            metrics_text += f"Trainable parameters: {trainable_params}\n"
            metrics_text += f"Peak GPU memory usage:\n"
            metrics_text += f"  - During training: {peak_mem_gb:.1f} GB\n"
            metrics_text += f"  - During testing: {test_peak_mem_gb:.1f} GB\n"
            # Write to file and add to tensorboard
            metrics_file = os.path.join(path_best, 'best_test_metrics.txt')
            with open(metrics_file, 'w') as f:
                f.write(metrics_text)
            # Add the best-model metrics to tensorboard as text (two trailing spaces force a markdown line break)
            writer.add_text('metrics/best_metrics', metrics_text.replace('\n', '  \n'))
            # Log per-batch test metrics for the best-case model
            metrics_list_file = os.path.join(path_best, 'test_metrics_list.txt')
            if os.path.exists(metrics_list_file):
                with open(metrics_list_file, 'r') as f:
                    metrics_list = json.load(f)
                for metric_name, values in metrics_list.items():
                    for i, v in enumerate(values):
                        writer.add_scalar(f'per_batch_test_metrics/best/{metric_name}', v, i)
            print(f"\nFinal model metrics - MSE: {final_mse:.6f}, MAE: {final_mae:.6f}, Huber: {final_huber:.6f}, R²: {final_r2:.4f}")
            print(f"Best model metrics  - MSE: {best_mse:.6f}, MAE: {best_mae:.6f}, Huber: {best_huber:.6f}, R²: {best_r2:.4f}")
            # Close the tensorboard writer
            writer.close()
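# Minimal entry-point sketch. This module does not ship a launcher; the snippet below
# only illustrates how train_plane_engine_test_main is expected to be driven by
# Accelerate and OmegaConf. `build_model` and 'configs/plane.yaml' are hypothetical
# placeholders, and the output directory mirrors the `metrics/...` layout used above.
#
# if __name__ == "__main__":
#     from accelerate import Accelerator
#     cfg = OmegaConf.load("configs/plane.yaml")  # hypothetical config path
#     accelerator = Accelerator()
#     model = build_model(cfg)                    # hypothetical model factory
#     out_dir = f"metrics/{cfg.project_name}/{cfg.model}_{cfg.test_name}"
#     os.makedirs(out_dir, exist_ok=True)
#     train_plane_engine_test_main(model, out_dir, cfg, accelerator)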