import numpy as np
import time, json, os
import torch
import torch.nn as nn
from lion_pytorch import Lion
# from torch_geometric.loader import DataLoader
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import datasets
from omegaconf import OmegaConf

from datasets.elasticity.dataset_elasticity import get_dataloaders


def train(model, train_loader, optimizer, scheduler, cfg, accelerator):
    model.train()
    model_dtype = next(model.parameters()).dtype
    train_loader.dataset.y_normalizer.mean = train_loader.dataset.y_normalizer.mean.to(model_dtype)
    train_loader.dataset.y_normalizer.std = train_loader.dataset.y_normalizer.std.to(model_dtype)
    criterion = TestLoss(size_average=True)
    losses_field = 0.0

    for data in train_loader:
        optimizer.zero_grad()
        field = data['output_feat'].to(dtype=model_dtype)
        out = model(data)
        out = train_loader.dataset.y_normalizer.decode(out)
        targets = train_loader.dataset.y_normalizer.decode(field)
        total_loss = criterion(out.squeeze(-1), targets.squeeze(-1))
        accelerator.backward(total_loss)
        if cfg.max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
        optimizer.step()
        if cfg.scheduler == "OneCycleLR":
            scheduler.step()
        losses_field += total_loss.item()

    return losses_field / len(train_loader)


# def RelL2loss(x, y):
#     """Calculate relative L2 error: mean(||x-y||_2 / ||y||_2) over the batch.
#
#     Args:
#         x: Predicted values (B, N, C)
#         y: Target values (B, N, C)
#
#     Returns:
#         Mean relative L2 error across the batch
#     """
#     # Make inputs (B, N)
#     batch_size = x.size(0)
#     x_flat = x.view(batch_size, -1)  # (B, N)
#     y_flat = y.view(batch_size, -1)  # (B, N)
#     # Calculate L2 norm for each sample in batch
#     diff_norms = torch.norm(x_flat - y_flat, p=2, dim=1)  # (B,)
#     y_norms = torch.norm(y_flat, p=2, dim=1)  # (B,)
#     # Calculate RelL2 for each sample and take mean over batch
#     return torch.mean(diff_norms / y_norms)


class TestLoss(object):
    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(TestLoss, self).__init__()
        assert d > 0 and p > 0
        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def abs(self, x, y):
        num_examples = x.size()[0]
        h = 1.0 / (x.size()[1] - 1.0)
        all_norms = (h ** (self.d / self.p)) * torch.norm(
            x.view(num_examples, -1) - y.view(num_examples, -1), self.p, 1)
        if self.reduction:
            if self.size_average:
                return torch.mean(all_norms)
            else:
                return torch.sum(all_norms)
        return all_norms

    def rel(self, x, y):
        num_examples = x.size()[0]
        diff_norms = torch.norm(
            x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1)
        if self.reduction:
            if self.size_average:
                return torch.mean(diff_norms / y_norms)
            else:
                return torch.sum(diff_norms / y_norms)
        return diff_norms / y_norms

    def __call__(self, x, y):
        return self.rel(x, y)


@torch.no_grad()
def test(model, test_loader, y_normalizer, cfg, accelerator):
    model.eval()
    total_mse = 0.0
    total_mae = 0.0
    total_rell2 = 0.0
    num_batches = 0
    r_squared = 0.0
    criterion = TestLoss(size_average=False)
    model_dtype = next(model.parameters()).dtype
    y_normalizer.mean = y_normalizer.mean.to(model_dtype)
    y_normalizer.std = y_normalizer.std.to(model_dtype)

    # Only store outputs and targets on main process
    if accelerator.is_main_process:
        all_outputs = []
        all_targets = []

    for data in test_loader:
        field = data['output_feat'].to(dtype=model_dtype)
        out = model(data)
        out = y_normalizer.decode(out)
        targets = y_normalizer.decode(field)
        # Ensure proper shapes for loss calculation
        out_flat = out.view(-1, 1)
        targets_flat = targets.view(-1, 1)

        # Keep losses as tensors for proper reduction
        loss_mse = nn.MSELoss()(out_flat, targets_flat)
        loss_mae = torch.mean(torch.abs(out_flat - targets_flat))
        loss_rell2 = criterion(out.squeeze(-1), targets.squeeze(-1))
        # loss_rell2 = RelL2loss(out, targets)

        total_mse += loss_mse
        total_mae += loss_mae
        total_rell2 += loss_rell2
        num_batches += 1

        # Only store outputs and targets on main process
        if accelerator.is_main_process:
            all_outputs.append(out.cpu())
            all_targets.append(targets.cpu())

        # Clear references to tensors
        del out, targets, out_flat, targets_flat, loss_mse, loss_mae, loss_rell2

    # Clear GPU cache after all testing
    torch.cuda.empty_cache()

    # Convert to tensors for reduction
    metrics = {
        "total_mse": total_mse,
        "total_mae": total_mae,
        "total_rell2": total_rell2,
        "num_batches": torch.tensor(num_batches, device=accelerator.device),
    }

    # Gather metrics from all processes
    gathered_metrics = accelerator.gather(metrics)

    # Sum up metrics from all processes
    total_batches = gathered_metrics["num_batches"].sum().item()
    avg_mse = gathered_metrics["total_mse"].sum().item() / total_batches
    avg_mae = gathered_metrics["total_mae"].sum().item() / total_batches
    avg_rell2 = gathered_metrics["total_rell2"].sum().item() / total_batches

    # Calculate R² score only on main process
    if accelerator.is_main_process:
        # Concatenate and reshape all outputs and targets
        all_outputs = torch.cat(all_outputs, dim=0).to(torch.float32).cpu().numpy().reshape(-1)
        all_targets = torch.cat(all_targets, dim=0).to(torch.float32).cpu().numpy().reshape(-1)

        # Calculate R² score
        ss_tot = np.sum((all_targets - np.mean(all_targets)) ** 2)
        ss_res = np.sum((all_targets - all_outputs) ** 2)
        r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0

        # Clear large numpy arrays
        del all_outputs, all_targets

    # Broadcast r_squared to all processes: non-main processes hold 0, so a sum
    # reduction leaves every process with the main-process value
    r_squared = torch.tensor(r_squared, device=accelerator.device)
    r_squared = accelerator.reduce(r_squared, reduction="sum")
    r_squared = r_squared.item()

    return avg_mse, avg_mae, avg_rell2, r_squared


def train_elasticity_main(model, path, cfg, accelerator):
    train_loader, val_loader = get_dataloaders(cfg)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=1e-5)

    if cfg.scheduler == "OneCycleLR":
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            pct_start=0.05,
            max_lr=cfg.lr,
            total_steps=len(train_loader) * cfg.epochs,
        )
    elif cfg.scheduler == "CosineAnnealingLR":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.epochs)
    elif cfg.scheduler == "CosineAnnealingWarmRestarts":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=int(cfg.epochs * 0.05))
    elif cfg.scheduler == "LinearWarmupCosineAnnealingLR":
        warmup_epochs = int(cfg.epochs * 0.05)  # Convert back to epochs
        # Linear warmup scheduler
        warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=1e-6,  # Start very low (almost zero)
            end_factor=1.0,     # End at base lr
            total_iters=warmup_epochs,
        )
        # Cosine decay scheduler
        cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=cfg.epochs - warmup_epochs,  # Remaining epochs
            eta_min=1e-6,                      # End at 1e-6 learning rate
        )
        # Combine schedulers
        scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[warmup_epochs],
        )
    elif cfg.scheduler == "ConstantLR":
        scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
    else:
        raise ValueError(f"Unknown scheduler: {cfg.scheduler}")
    model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, val_loader, scheduler)

    if cfg.eval:
        state_dict = torch.load(os.path.join(path, 'final_model.pt'))
        # state_dict = torch.load('../Transolver/PDE-Solving-StandardBenchmark/checkpoints/elas_Transolver.pt')
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.load_state_dict(state_dict)
        # model.load_state_dict(state_dict)
        final_model = accelerator.prepare(unwrapped_model)

        final_mse, final_mae, final_rell2, final_r2 = test(
            final_model, val_loader, train_loader.dataset.y_normalizer, cfg, accelerator)
        # final_mse, final_mae, final_rell2, final_r2 = test(model, val_loader, train_loader.dataset.mean, train_loader.dataset.std, cfg, accelerator)
        print(f"Final model metrics - MSE: {final_mse:.6f}, MAE: {final_mae:.6f}, RelL2: {final_rell2:.6f}, R²: {final_r2:.4f}")

        with open(os.path.join(path, 'final_test_metrics.txt'), 'w') as f:
            f.write(f"Test MSE: {final_mse:.6f}\n")
            f.write(f"Test MAE: {final_mae:.6f}\n")
            f.write(f"Test RelL2: {final_rell2:.6f}\n")
            f.write(f"Test R2: {final_r2:.4f}\n")
    else:
        # Calculate total parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        # Reset memory stats before training
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()  # Clear any existing cached memory

        best_val_loss = 1e5
        val_MSE_list = []
        val_MAE_list = []
        val_RelL2_list = []
        val_R2_list = []

        start = time.time()

        # Only initialize tensorboard on the main process
        if accelerator.is_main_process:
            # Create a descriptive run name using model type and timestamp
            run_name = f"{cfg.model}{cfg.test_name}_{time.strftime('%Y%m%d_%H%M%S')}"
            project_name = os.path.join("tensorboard_logs", f"{cfg.project_name}")
            log_dir = os.path.join(project_name, run_name)
            writer = SummaryWriter(log_dir)

            # Add full config
            config_text = "```yaml\n"  # Using yaml format for better readability
            config_text += OmegaConf.to_yaml(cfg)
            config_text += "```"
            writer.add_text('hyperparameters/full_config', config_text)

            pbar_train = tqdm(range(cfg.epochs), position=0)
        else:
            writer = None
            pbar_train = range(cfg.epochs)

        for epoch in pbar_train:
            train_loss = train(model, train_loader, optimizer, scheduler, cfg, accelerator)

            # if accelerator.is_main_process:
            #     # After a forward and backward pass
            #     grad_stats = model.check_gradients()
            #     # Print gradient statistics
            #     for layer_name, stats in grad_stats.items():
            #         print(f"\nLayer: {layer_name}")
            #         print(f"Gradient norm: {stats['norm']:.4f}")
            #         print(f"Gradient mean: {stats['mean']:.4f}")
            #         print(f"Gradient std: {stats['std']:.4f}")
            #         if stats['is_exploding']:
            #             print("WARNING: Gradients are exploding!")
            #         if stats['is_vanishing']:
            #             print("WARNING: Gradients are vanishing!")

            if cfg.scheduler in ("CosineAnnealingLR", "CosineAnnealingWarmRestarts", "LinearWarmupCosineAnnealingLR"):
                scheduler.step()

            if cfg.val_iter is not None and (epoch == cfg.epochs - 1 or epoch % cfg.val_iter == 0):
                val_loss_MSE, val_loss_MAE, val_loss_RelL2, val_r_squared = test(
                    model, val_loader, train_loader.dataset.y_normalizer, cfg, accelerator)
                val_MSE_list.append(val_loss_MSE)
                val_MAE_list.append(val_loss_MAE)
                val_RelL2_list.append(val_loss_RelL2)
                val_R2_list.append(val_r_squared)

                # Only log to tensorboard and save files on main process
                if accelerator.is_main_process:
                    # Get peak GPU memory in GB
                    peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024)

                    # Log metrics to tensorboard
                    writer.add_scalar('Loss/train_RelL2', train_loss, epoch)
                    writer.add_scalar('Loss/test_MSE', val_loss_MSE, epoch)
                    writer.add_scalar('Loss/test_MAE', val_loss_MAE, epoch)
                    writer.add_scalar('Loss/test_RelL2', val_loss_RelL2, epoch)
                    writer.add_scalar('Metrics/test_R2', val_r_squared, epoch)
                    writer.add_scalar('Learning_rate', scheduler.get_last_lr()[0], epoch)
                    writer.add_scalar('Memory/peak_gpu_memory_gb', peak_mem_gb, epoch)

                    with open(os.path.join(path, 'MSE.json'), 'w') as f:
                        json.dump(val_MSE_list, f, indent=2)
                    with open(os.path.join(path, 'MAE.json'), 'w') as f:
                        json.dump(val_MAE_list, f, indent=2)
                    with open(os.path.join(path, 'RelL2.json'), 'w') as f:
                        json.dump(val_RelL2_list, f, indent=2)
                    with open(os.path.join(path, 'R2.json'), 'w') as f:
                        json.dump(val_R2_list, f, indent=2)

                    pbar_train.set_postfix({
                        'train_loss': train_loss,
                        'val_loss': val_loss_RelL2,
                        'lr': scheduler.get_last_lr()[0],
                        'mem_gb': f'{peak_mem_gb:.1f}',
                    })

                    # Save best model
                    if val_loss_MSE < best_val_loss:
                        best_val_loss = val_loss_MSE
                        unwrapped_model = accelerator.unwrap_model(model)
                        torch.save(unwrapped_model.state_dict(), os.path.join(path, 'best_model.pt'))
            elif accelerator.is_main_process:
                # Simple progress display without validation
                peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024)
                pbar_train.set_postfix({
                    'train_loss': train_loss,
                    'mem_gb': f'{peak_mem_gb:.1f}',
                    'lr': scheduler.get_last_lr()[0],
                })

        end = time.time()
        time_elapsed = end - start

        # Get final peak memory for reporting
        if accelerator.is_main_process:
            peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024)

        # Reset memory stats before final evaluation
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

        # Test final model
        inf_time_start = time.time()
        final_mse, final_mae, final_rell2, final_r2 = test(
            model, val_loader, train_loader.dataset.y_normalizer, cfg, accelerator)
        inf_time_end = time.time()
        inf_time = inf_time_end - inf_time_start

        # Get peak memory during testing
        if accelerator.is_main_process:
            test_peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024)

        # Only print and save logs on main process
        if accelerator.is_main_process:
            print('Time elapsed: {0:.2f} seconds'.format(time_elapsed))
            with open(os.path.join(path, 'logs.json'), 'w') as f:
                json.dump([time_elapsed], f, indent=2)

            metrics_file = os.path.join(path, 'final_test_metrics.txt')
            with open(metrics_file, 'w') as f:
                metrics_text = f"Test MSE: {final_mse:.6f}\n"
                metrics_text += f"Test MAE: {final_mae:.6f}\n"
                metrics_text += f"Test RelL2: {final_rell2:.6f}\n"
                metrics_text += f"Test R2: {final_r2:.6f}\n"
                metrics_text += f"Total training time: {time_elapsed:.2f}s\n"
                metrics_text += f"Average epoch time: {time_elapsed/cfg.epochs:.2f}s\n"
                metrics_text += f"Total Inference time: {inf_time:.2f}s\n"
                metrics_text += f"Total parameters: {total_params}\n"
                metrics_text += f"Trainable parameters: {trainable_params}\n"
                metrics_text += f"Peak GPU memory usage:\n"
                metrics_text += f"  - During training: {peak_mem_gb:.1f} GB\n"
                metrics_text += f"  - During testing: {test_peak_mem_gb:.1f} GB\n"
                f.write(metrics_text)

            # Add final metrics to tensorboard
            writer.add_text('metrics/final_metrics', metrics_text.replace('\n', ' \n'))

            unwrapped_model = accelerator.unwrap_model(model)
            torch.save(unwrapped_model.state_dict(), os.path.join(path, 'final_model.pt'))

        # Load the best model using state_dict for compatibility
        state_dict = torch.load(os.path.join(path, 'best_model.pt'))
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.load_state_dict(state_dict)
        best_model = accelerator.prepare(unwrapped_model)
        best_mse, best_mae, best_rell2, best_r2 = test(
            best_model, val_loader, train_loader.dataset.y_normalizer, cfg, accelerator)

        if accelerator.is_main_process:
            metrics_file = os.path.join(path, 'best_test_metrics.txt')
            with open(metrics_file, 'w') as f:
                metrics_text = f"Test MSE: {best_mse:.6f}\n"
                metrics_text += f"Test MAE: {best_mae:.6f}\n"
                metrics_text += f"RelL2: {best_rell2:.6f}\n"
                metrics_text += f"R2: {best_r2:.6f}\n"
                metrics_text += f"Total training time: {time_elapsed:.2f}s\n"
                metrics_text += f"Average epoch time: {time_elapsed/cfg.epochs:.2f}s\n"
                metrics_text += f"Total Inference time: {inf_time:.2f}s\n"
                metrics_text += f"Total parameters: {total_params}\n"
                metrics_text += f"Trainable parameters: {trainable_params}\n"
                metrics_text += f"Peak GPU memory usage:\n"
                metrics_text += f"  - During training: {peak_mem_gb:.1f} GB\n"
                metrics_text += f"  - During testing: {test_peak_mem_gb:.1f} GB\n"
                f.write(metrics_text)

            # Add best metrics to tensorboard
            writer.add_text('metrics/best_metrics', metrics_text.replace('\n', ' \n'))

            print(f"\nFinal model metrics - MSE: {final_mse:.6f}, MAE: {final_mae:.6f}, RelL2: {final_rell2:.6f}, R²: {final_r2:.4f}")
            print(f"Best model metrics - MSE: {best_mse:.6f}, MAE: {best_mae:.6f}, RelL2: {best_rell2:.6f}, R²: {best_r2:.4f}")

            writer.close()
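

# --- Usage sketch (not part of the original script) ---
# train_elasticity_main() is only defined in this file, so the lines below sketch
# one way it could be wired up with Accelerate and an OmegaConf config; the config
# path and `build_model` factory are hypothetical placeholders, and the cfg keys
# mirror the ones referenced above (lr, epochs, scheduler, eval, val_iter, ...).
#
# from accelerate import Accelerator
#
# if __name__ == "__main__":
#     cfg = OmegaConf.load("configs/elasticity.yaml")   # hypothetical config path
#     accelerator = Accelerator()
#     model = build_model(cfg)                          # hypothetical model factory
#     out_dir = os.path.join("outputs", f"{cfg.model}{cfg.test_name}")
#     os.makedirs(out_dir, exist_ok=True)
#     train_elasticity_main(model, out_dir, cfg, accelerator)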