AnsysLPFMTrame-App / trainers /train_driveaerpp.py
udbhav
Recreate Trame_app branch with clean history
67fb03c
import numpy as np
import time, json, os, math
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import json
from torch.utils.tensorboard import SummaryWriter
from omegaconf import OmegaConf
from lion_pytorch import Lion
import glob
from utils.vtk_writer import vtk_writer
# For full mesh load and then sample in each training iteration
from datasets.driveaerpp.dataset_loader import get_dataloaders
from torch.cuda.amp import GradScaler
import re
import glob
def save_checkpoint(model, optimizer, scheduler, epoch, best_val_loss, val_MSE_list,
                    cfg, path, accelerator, log_dir=None):
    """Save a complete training checkpoint (main process only).

    Writes ``checkpoint_epoch_{epoch}.pt`` and refreshes ``latest_checkpoint.pt``
    inside ``path`` (created if missing). Each file is first written under a
    ``.tmp`` name and then moved into place with ``os.replace``. An interrupted
    save therefore never leaves a truncated ``latest_checkpoint.pt`` for
    ``load_checkpoint`` to fail on.

    Args:
        model: Model whose unwrapped ``state_dict`` is stored.
        optimizer: Optimizer whose ``state_dict`` is stored.
        scheduler: LR scheduler whose ``state_dict`` is stored.
        epoch: Epoch number, used in the file name and stored in the checkpoint.
        best_val_loss: Best validation loss seen so far.
        val_MSE_list: History of validation losses.
        cfg: Run configuration, stored as-is.
        path: Output directory.
        accelerator: Object providing ``is_main_process`` and ``unwrap_model``.
        log_dir: TensorBoard log directory, stored so that resumed runs keep logging there.
    """
    if not accelerator.is_main_process:
        return
    import shutil

    rng_state = torch.get_rng_state()
    cuda_rng_states = None
    if torch.cuda.is_available():
        cuda_rng_states = [torch.cuda.get_rng_state(device=i)
                           for i in range(torch.cuda.device_count())]
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': accelerator.unwrap_model(model).state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_val_loss': best_val_loss,
        'val_MSE_list': val_MSE_list,
        'cfg': cfg,
        'log_dir': log_dir,
        'rng_state': rng_state,
        'cuda_rng_states': cuda_rng_states,
    }

    os.makedirs(path, exist_ok=True)
    checkpoint_path = os.path.join(path, f'checkpoint_epoch_{epoch}.pt')
    tmp_path = checkpoint_path + '.tmp'
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, checkpoint_path)

    # Also publish as the latest checkpoint. Copying the finished file avoids
    # pickling everything a second time; the tmp + replace keeps the swap atomic.
    latest_checkpoint_path = os.path.join(path, 'latest_checkpoint.pt')
    latest_tmp_path = latest_checkpoint_path + '.tmp'
    shutil.copyfile(checkpoint_path, latest_tmp_path)
    os.replace(latest_tmp_path, latest_checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch}")
def load_checkpoint(path, model, optimizer, scheduler, accelerator):
    """Restore training state from ``path/latest_checkpoint.pt`` if it exists.

    Loads the model weights into the unwrapped model, then restores optimizer and
    scheduler state. CPU and per-device CUDA RNG states are restored on a
    best-effort basis: failures are printed, not raised.

    Returns:
        ``(checkpoint, start_epoch, best_val_loss, val_MSE_list, log_dir)``.
        Without a checkpoint this is ``(None, 0, inf, [], None)``, so the first
        validation loss always counts as an improvement.
    """
    latest_checkpoint_path = os.path.join(path, 'latest_checkpoint.pt')
    if not os.path.exists(latest_checkpoint_path):
        print("No checkpoint found, starting from scratch")
        return None, 0, float('inf'), [], None

    print(f"Loading checkpoint from {latest_checkpoint_path}")
    # The checkpoint contains non-tensor objects (e.g. the run config), which the
    # `weights_only=True` default of newer PyTorch releases refuses to unpickle.
    # SECURITY: full unpickling can execute code. Only load checkpoints written
    # by this training script.
    try:
        checkpoint = torch.load(latest_checkpoint_path, map_location='cpu', weights_only=False)
    except TypeError:
        # Older PyTorch without the `weights_only` keyword.
        checkpoint = torch.load(latest_checkpoint_path, map_location='cpu')

    accelerator.unwrap_model(model).load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Restore random states for reproducibility; never fatal.
    try:
        if checkpoint.get('rng_state') is not None:
            torch.set_rng_state(checkpoint['rng_state'])
    except Exception as e:
        print(f"Warning: Could not restore CPU RNG state: {e}")
    try:
        # Current checkpoints store a per-device list under 'cuda_rng_states'.
        # Older ones used 'cuda_rng_state', which may hold a single tensor.
        cuda_rng_states = checkpoint.get('cuda_rng_states', checkpoint.get('cuda_rng_state'))
        if cuda_rng_states is not None and torch.cuda.is_available():
            if isinstance(cuda_rng_states, torch.Tensor):
                torch.cuda.set_rng_state(cuda_rng_states)
            elif isinstance(cuda_rng_states, list):
                for i, state in enumerate(cuda_rng_states[:torch.cuda.device_count()]):
                    if state is not None:
                        torch.cuda.set_rng_state(state, device=i)
    except Exception as e:
        print(f"Warning: Could not restore CUDA RNG state: {e}")

    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint['best_val_loss']
    val_MSE_list = checkpoint['val_MSE_list']
    log_dir = checkpoint.get('log_dir', None)
    print(f"Resumed from epoch {checkpoint['epoch']}, best val loss: {best_val_loss:.6f}")
    return checkpoint, start_epoch, best_val_loss, val_MSE_list, log_dir
def cleanup_old_checkpoints(path, keep_last=3):
    """Delete all but the ``keep_last`` newest ``*_epoch_*.pt`` files in ``path``.

    Files are ranked by modification time. Ties, which can occur on filesystems
    with coarse timestamps, are broken by the epoch number in the file name, so
    a higher epoch counts as newer. Files not matching the pattern (e.g.
    ``latest_checkpoint.pt``) are never touched. Removal is best-effort: files
    that already vanished are skipped, and other OS errors are printed as warnings.

    Args:
        path: Directory to clean.
        keep_last: Number of newest matching files to keep; 0 removes all of them.

    Raises:
        ValueError: If ``keep_last`` is negative.
    """
    if keep_last < 0:
        raise ValueError(f"keep_last must be non-negative, got {keep_last}")
    checkpoint_files = glob.glob(os.path.join(path, '*_epoch_*.pt'))
    if len(checkpoint_files) <= keep_last:
        return

    def _age_key(file_path):
        try:
            mtime = os.path.getmtime(file_path)
        except OSError:
            mtime = float('-inf')  # vanished since glob(); treat as oldest
        match = re.search(r'_epoch_(\d+)\.pt$', os.path.basename(file_path))
        return mtime, int(match.group(1)) if match else -1

    checkpoint_files.sort(key=_age_key)
    # Slice by count: `[:-keep_last]` would be empty for keep_last == 0.
    for file_path in checkpoint_files[:len(checkpoint_files) - keep_last]:
        try:
            os.remove(file_path)
        except FileNotFoundError:
            continue
        except OSError as e:
            print(f"Warning: could not remove old checkpoint {file_path}: {e}")
            continue
        print(f"Removed old checkpoint: {os.path.basename(file_path)}")
def train(model, train_loader, optimizer, scheduler, criterion, cfg, accelerator, scaler):
    """Run one training epoch and return the mean per-batch loss on this process.

    Each batch is a mapping whose ``'output_feat'`` entry is the regression target;
    the whole batch is passed to ``model``. With ``cfg.mixed_precision`` the forward
    pass and loss run under ``torch.autocast`` and the update goes through ``scaler``;
    otherwise ``accelerator.backward`` and a plain ``optimizer.step()`` are used.
    Gradients are clipped to ``cfg.max_grad_norm`` when it is set. A OneCycleLR
    scheduler (``cfg.scheduler == "OneCycleLR"``) is advanced after every batch.
    """
    model.train()
    running_loss = 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        target = batch['output_feat']
        if not cfg.mixed_precision:
            prediction = model(batch)
            loss = criterion(prediction, target)
            accelerator.backward(loss)
            if cfg.max_grad_norm is not None:
                accelerator.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            optimizer.step()
        else:
            with torch.autocast(device_type=accelerator.device.type):
                prediction = model(batch)
                loss = criterion(prediction, target)
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on the true gradients.
            scaler.unscale_(optimizer)
            if cfg.max_grad_norm is not None:
                accelerator.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        # Per-batch schedule; epoch-level schedulers are stepped by the caller.
        if cfg.scheduler == "OneCycleLR":
            scheduler.step()
        running_loss += loss.item()
    return running_loss / len(train_loader)
@torch.no_grad()
def val(model, val_loader, criterion, cfg, accelerator):
    """Return the validation loss averaged over all batches of all processes.

    Each process sums its per-batch losses, computed in FP32. When more than one
    process is running, the ``(loss_sum, batch_count)`` pairs are combined with
    ``accelerator.gather``. Every rank therefore returns the same value, which
    keeps ReduceLROnPlateau and best-model selection in sync across ranks.

    Args:
        model: Model called with each batch mapping.
        val_loader: Iterable of batch mappings with an ``'output_feat'`` target.
        criterion: Loss function ``criterion(prediction, target)``.
        cfg: Unused; kept for interface compatibility.
        accelerator: Provides ``num_processes``, ``device`` and ``gather``.

    Returns:
        Mean batch loss as a float, or ``nan`` if no process saw any batch.
    """
    model.eval()
    loss_sum = 0.0
    num_batches = 0
    for data in val_loader:
        targets = data['output_feat']
        out = model(data)
        # Loss computation in FP32 for maximum stability.
        loss_sum += criterion(out.float(), targets.float()).item()
        num_batches += 1

    totals = torch.tensor([loss_sum, float(num_batches)], dtype=torch.float64)
    if accelerator.num_processes > 1:
        gathered = accelerator.gather(totals.to(accelerator.device))
        totals = gathered.reshape(-1, 2).sum(dim=0).cpu()
    total_loss, total_batches = totals.tolist()
    if total_batches == 0:
        return float('nan')
    return total_loss / total_batches
def _predict_chunked(model, data, chunk_size):
    """Run ``model`` over ``data['input_pos']`` in chunks of ``chunk_size`` points along dim 1.

    Every model call sees exactly ``chunk_size`` points; the slicing assumes
    ``input_pos`` is (B, N, C). A short final chunk is filled by wrapping around
    to the start of the cloud, repeatedly if N is smaller than the pad needed.
    Predictions for the filler points are dropped. The caller's ``data`` is not
    modified: each call gets a shallow copy with ``'input_pos'`` replaced.
    """
    import copy

    input_pos = data['input_pos']
    num_points = input_pos.shape[1]
    pieces = []
    for start in range(0, num_points, chunk_size):
        chunk = input_pos[:, start:start + chunk_size, :]
        n_valid = chunk.shape[1]
        if n_valid < chunk_size:
            wrap_idx = torch.arange(chunk_size - n_valid, device=input_pos.device) % num_points
            chunk = torch.cat([chunk, input_pos[:, wrap_idx, :]], dim=1)
        chunk_data = copy.copy(data)
        chunk_data['input_pos'] = chunk
        # Keep only the predictions that correspond to real points.
        pieces.append(model(chunk_data)[:, :n_valid, :])
    return torch.cat(pieces, dim=1)


def _denormalize_pressure(values, mean, std):
    """Return a copy of ``values`` with channel 0 mapped back to physical units (``x * std + mean``)."""
    physical = values.clone()
    physical[..., 0] = physical[..., 0] * std + mean
    return physical


def _batch_metrics(outputs, targets):
    """Per-batch metrics as floats: (mse, mae, huber, rel_l2, rel_l1, r_squared)."""
    mse = F.mse_loss(outputs, targets).item()
    mae = F.l1_loss(outputs, targets).item()
    huber = F.huber_loss(outputs, targets, delta=1.0).item()
    out_sq = outputs.squeeze(-1)
    tgt_sq = targets.squeeze(-1)
    diff = out_sq - tgt_sq
    # Relative errors: mean over batch of ||error|| / ||target|| along the last dim.
    rel_l2 = torch.mean(torch.norm(diff, p=2, dim=-1) / torch.norm(tgt_sq, p=2, dim=-1)).item()
    rel_l1 = torch.mean(torch.norm(diff, p=1, dim=-1) / torch.norm(tgt_sq, p=1, dim=-1)).item()
    ss_res = torch.sum((targets - outputs) ** 2)
    ss_tot = torch.sum((targets - torch.mean(targets)) ** 2)
    r_squared = (1 - ss_res / ss_tot).item() if ss_tot > 0 else 0.0
    return mse, mae, huber, rel_l2, rel_l1, r_squared


def test_model(model, test_dataloader, criterion, path, cfg, accelerator):
    """Evaluate ``model`` on ``test_dataloader`` and report aggregate test metrics.

    Per batch this computes MSE, MAE, Huber, relative L1/L2 and R² in FP32 and
    writes a VTK file through ``vtk_writer``. Channel 0 (pressure) is converted to
    physical units with the mean/std from ``cfg.json_file`` for the metrics when
    ``cfg.physical_scale_for_test`` is set, and for the VTK output otherwise.
    Per-batch values go to ``path/test_metrics_list.txt`` (main process). Totals
    are combined across processes, and the averages go to ``path/test_metrics.txt``.
    The pooled R² is built from streamed sums (Σy, Σy², Σ(y-ŷ)², count) instead
    of keeping every prediction in memory.

    ``criterion`` is unused; it is kept for interface compatibility.

    Returns:
        ``(avg_mse, avg_mae, avg_huber, avg_rel_l2, avg_rel_l1, r_squared,
        avg_inference_time)``. All are 0.0 when no process saw any data.

    Raises:
        ValueError: If ``cfg.physical_scale_for_test`` is set but
            ``cfg.normalization`` is not ``"std_norm"``, so no pressure statistics
            are available.
    """
    # NOTE(review): open question left by the original author:
    # "You reported test models everywhere? Complete evaluation?"
    model.eval()
    physical_scale = bool(cfg.physical_scale_for_test)
    pressure_mean = pressure_std = None
    if cfg.normalization == "std_norm":
        with open(cfg.json_file, 'r') as f:
            json_data = json.load(f)
        pressure_mean = torch.tensor(json_data["scalars"]["pressure"]["mean"], device=accelerator.device)
        pressure_std = torch.tensor(json_data["scalars"]["pressure"]["std"], device=accelerator.device)
    if physical_scale and pressure_mean is None:
        raise ValueError("physical_scale_for_test requires normalization == 'std_norm' "
                         "(pressure statistics are read from cfg.json_file)")

    # Key order matches the tuple returned by _batch_metrics.
    per_batch = {
        "total_mse_list": [],
        "total_mae_list": [],
        "total_huber_list": [],
        "total_rel_l2_list": [],
        "total_rel_l1_list": [],
        "r_2_squared_list": [],
    }
    metric_sums = [0.0] * 5  # mse, mae, huber, rel_l2, rel_l1
    num_batches = 0
    total_inference_time = 0.0
    sum_y = sum_y2 = sum_sq_err = 0.0  # sufficient statistics for the pooled R²
    num_elements = 0
    path_vtk = os.path.join(path, "vtk_files")
    on_cuda = accelerator.device.type == "cuda"

    with torch.no_grad():
        for data in tqdm(test_dataloader, desc="[Testing]", disable=not accelerator.is_local_main_process):
            start_time = time.perf_counter()
            targets = data['output_feat']
            if cfg.chunked_eval:
                outputs = _predict_chunked(model, data, cfg.num_points)
            else:
                outputs = model(data)
            # Metric computations in FP32; `.float()` may alias the batch tensor,
            # so anything below that rescales works on copies.
            targets = targets.float()
            outputs = outputs.float()
            if physical_scale:
                targets = _denormalize_pressure(targets, pressure_mean, pressure_std)
                outputs = _denormalize_pressure(outputs, pressure_mean, pressure_std)
            if on_cuda:
                # Kernels run asynchronously; wait so the timing covers the real work.
                torch.cuda.synchronize(accelerator.device)
            total_inference_time += time.perf_counter() - start_time

            batch_values = _batch_metrics(outputs, targets)
            for key, value in zip(per_batch, batch_values):
                per_batch[key].append(value)
            for i in range(len(metric_sums)):
                metric_sums[i] += batch_values[i]
            num_batches += 1

            # Accumulate the R² statistics on CPU in float64.
            tgt64 = targets.detach().to('cpu', torch.float64)
            out64 = outputs.detach().to('cpu', torch.float64)
            sum_y += tgt64.sum().item()
            sum_y2 += tgt64.square().sum().item()
            sum_sq_err += (tgt64 - out64).square().sum().item()
            num_elements += tgt64.numel()

            coords = data['physical_coordinates'].cpu()
            vtk_outputs, vtk_targets = outputs, targets
            if not physical_scale and pressure_mean is not None:
                # VTK files are always written in physical units when statistics exist.
                vtk_outputs = _denormalize_pressure(outputs, pressure_mean, pressure_std)
                vtk_targets = _denormalize_pressure(targets, pressure_mean, pressure_std)
            try:
                vtk_writer(vtk_outputs, vtk_targets, coords, path_vtk, prefix=data["data_id"][0], config_json_path=cfg.json_file)
            except Exception as e:
                print(f"[Warning] Could not save VTK files: {e}")
            del outputs, targets, vtk_outputs, vtk_targets, tgt64, out64

    # Per-batch metrics for later analysis (main process only, to avoid racing writers).
    if accelerator.is_main_process:
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, 'test_metrics_list.txt'), 'w') as f:
            json.dump(per_batch, f, indent=2)

    totals = torch.tensor(
        [*metric_sums, float(num_batches), total_inference_time,
         sum_y, sum_y2, sum_sq_err, float(num_elements)],
        dtype=torch.float64)
    if accelerator.num_processes > 1:
        gathered = accelerator.gather(totals.to(accelerator.device))
        totals = gathered.reshape(-1, totals.numel()).sum(dim=0).cpu()
    (sum_mse, sum_mae, sum_huber, sum_rel_l2, sum_rel_l1, total_batches,
     total_inference_time, sum_y, sum_y2, sum_sq_err, total_elements) = totals.tolist()
    total_batches = int(round(total_batches))

    if total_batches > 0:
        avg_mse = sum_mse / total_batches
        avg_mae = sum_mae / total_batches
        avg_huber = sum_huber / total_batches
        avg_rel_l2 = sum_rel_l2 / total_batches
        avg_rel_l1 = sum_rel_l1 / total_batches
        avg_inference_time = total_inference_time / total_batches
        # Pooled R² over every element: SS_tot = Σy² - (Σy)²/n.
        ss_tot = sum_y2 - sum_y ** 2 / total_elements if total_elements > 0 else 0.0
        r_squared = 1 - sum_sq_err / ss_tot if ss_tot > 0 else 0.0
        if accelerator.is_main_process:
            print(f"Test MSE: {avg_mse:.6f}, Test MAE: {avg_mae:.6f}, Test Huber: {avg_huber:.6f}, R²: {r_squared:.4f}")
            print(f"Relative L2 Error: {avg_rel_l2:.6f}, Relative L1 Error: {avg_rel_l1:.6f}")
            print(f"Average inference time per batch: {avg_inference_time:.4f}s")
            print(f"Total inference time: {total_inference_time:.2f}s for {total_batches} batches")
            with open(os.path.join(path, 'test_metrics.txt'), 'w') as f:
                f.write(f"Test MSE: {avg_mse:.6f}\n")
                f.write(f"Test MAE: {avg_mae:.6f}\n")
                f.write(f"Test Huber: {avg_huber:.6f}\n")
                f.write(f"R2 Score: {r_squared:.6f}\n")
                f.write(f"Relative L2 Error: {avg_rel_l2:.6f}\n")
                f.write(f"Relative L1 Error: {avg_rel_l1:.6f}\n")
                f.write(f"Average inference time per batch: {avg_inference_time:.4f}s\n")
                f.write(f"Total inference time: {total_inference_time:.2f}s for {total_batches} batches\n")
    else:
        print("Warning: No data in test_dataloader")
        avg_mse = avg_mae = avg_huber = avg_rel_l2 = avg_rel_l1 = r_squared = 0.0
        avg_inference_time = 0.0
    # Clear GPU cache after all testing
    torch.cuda.empty_cache()
    return avg_mse, avg_mae, avg_huber, avg_rel_l2, avg_rel_l1, r_squared, avg_inference_time
# Schedulers that the training loop steps once per epoch or validation, not per batch.
_EPOCH_LEVEL_SCHEDULERS = ("ReduceLROnPlateau", "LinearWarmupCosineAnnealingLR")


def _build_optimizer(cfg, params):
    """Create the optimizer selected by ``cfg.optimizer.type`` ('Adam', 'AdamW' or 'LION')."""
    opt_type = cfg.optimizer.type
    if opt_type == 'Adam':
        return torch.optim.Adam(params, lr=cfg.lr, weight_decay=1e-4)
    if opt_type == 'AdamW':
        return torch.optim.AdamW(params, lr=cfg.lr, weight_decay=0.05)
    if opt_type == 'LION':
        return Lion(params, lr=cfg.lr, weight_decay=0.05)
    raise ValueError(f"Unknown optimizer type: {opt_type}")


def _build_scheduler(cfg, optimizer, steps_per_epoch):
    """Create the LR scheduler selected by ``cfg.scheduler``.

    'ReduceLROnPlateau' and 'LinearWarmupCosineAnnealingLR' are epoch-level
    schedules. Any other name gets a OneCycleLR over ``steps_per_epoch * cfg.epochs``
    steps. ``train`` only advances it when ``cfg.scheduler == "OneCycleLR"``, so
    other names trigger a warning.
    """
    if cfg.scheduler == "ReduceLROnPlateau":
        return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.1, verbose=True)
    if cfg.scheduler == "LinearWarmupCosineAnnealingLR":
        # At least one warmup epoch: with total_iters=0, LinearLR would leave the
        # LR at start_factor * lr instead of warming it up.
        warmup_epochs = max(1, int(cfg.epochs * 0.05))
        warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=1e-6,  # start almost at zero
            end_factor=1.0,     # end at the base lr
            total_iters=warmup_epochs,
        )
        cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=max(1, cfg.epochs - warmup_epochs),  # remaining epochs
            eta_min=1e-6,
        )
        return torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[warmup_epochs],
        )
    if cfg.scheduler != "OneCycleLR":
        print(f"Warning: unknown scheduler '{cfg.scheduler}'; creating OneCycleLR, "
              "but train() only steps it when cfg.scheduler == 'OneCycleLR'")
    return torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        pct_start=0.05,
        max_lr=cfg.lr,
        total_steps=steps_per_epoch * cfg.epochs,
    )


def _build_criterion(loss_type):
    """Map ``cfg.loss_type`` ('mse', 'mae' or 'huber') to a loss module."""
    if loss_type == "mse":
        return nn.MSELoss()
    if loss_type == "mae":
        return nn.L1Loss()
    if loss_type == "huber":
        return nn.HuberLoss(delta=1.0)
    raise ValueError(f"Unknown loss_type: {loss_type}")


def _find_best_model_file(best_dir):
    """Return ``(epoch, file_path)`` of the highest-epoch ``best_model_epoch_<N>.pt`` in ``best_dir``, or None."""
    best = None
    for file_path in glob.glob(os.path.join(best_dir, 'best_model_epoch_*.pt')):
        match = re.fullmatch(r'best_model_epoch_(\d+)\.pt', os.path.basename(file_path))
        if match and (best is None or int(match.group(1)) > best[0]):
            best = (int(match.group(1)), file_path)
    return best


def _peak_gpu_mem_gb():
    """Peak CUDA memory allocated on the current device, in GiB (0.0 without CUDA)."""
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024)


def _reset_gpu_mem_stats():
    """Reset CUDA peak-memory counters and release cached blocks (no-op without CUDA)."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()


def _format_test_metrics(results, time_elapsed, epochs_run, total_params, trainable_params,
                         train_mem_gb, test_mem_gb):
    """Render a ``test_model`` result tuple plus run statistics as the metrics report text."""
    mse, mae, huber, rel_l2, rel_l1, r2, inf_time = results
    return (
        f"Test MSE: {mse:.6f}\n"
        f"Test MAE: {mae:.6f}\n"
        f"Test Huber: {huber:.6f}\n"
        f"Test RelL1: {rel_l1:.6f}\n"
        f"Test RelL2: {rel_l2:.6f}\n"
        f"Test R2: {r2:.6f}\n"
        f"Inference time: {inf_time:.6f}s\n"
        f"Total training time: {time_elapsed:.2f}s\n"
        f"Average epoch time: {time_elapsed / epochs_run:.2f}s\n"
        f"Total parameters: {total_params}\n"
        f"Trainable parameters: {trainable_params}\n"
        "Peak GPU memory usage:\n"
        f" - During training: {train_mem_gb:.1f} GB\n"
        f" - During testing: {test_mem_gb:.1f} GB\n"
    )


def _log_per_batch_metrics(writer, metrics_dir, tag):
    """Replay ``metrics_dir/test_metrics_list.txt`` (per-batch JSON) into TensorBoard under ``tag``."""
    metrics_list_file = os.path.join(metrics_dir, 'test_metrics_list.txt')
    if not os.path.exists(metrics_list_file):
        return
    with open(metrics_list_file, 'r') as f:
        metrics_list = json.load(f)
    for metric_name, values in metrics_list.items():
        for i, v in enumerate(values):
            writer.add_scalar(f'per_batch_test_metrics/{tag}/{metric_name}', v, i)


def train_driveaerpp_main(model, path, cfg, accelerator):
    """Train a DrivAer++ model, or only evaluate it when ``cfg.eval`` is set.

    Training resumes from ``path/latest_checkpoint.pt`` when present, logs to
    TensorBoard on the main process, and keeps the best-validation weights under
    ``path/best_case``. At the end, both the last-epoch model and the best-validation
    model are evaluated on the test split.
    """
    from copy import deepcopy

    train_loader, val_loader, test_loader = get_dataloaders(cfg)
    if accelerator.is_main_process:
        print(
            f"Data loaded: {len(train_loader)} training batches, "
            f"{len(val_loader)} validation batches, "
            f"{len(test_loader)} test batches")

    optimizer = _build_optimizer(cfg, model.parameters())
    scheduler = _build_scheduler(cfg, optimizer, len(train_loader))
    criterion = _build_criterion(cfg.loss_type)
    scaler = GradScaler()

    if cfg.scheduler in _EPOCH_LEVEL_SCHEDULERS:
        # accelerate's scheduler wrapper advances the scheduler once per process on
        # every step() call (default step_scheduler_with_optimizer=True). That suits
        # per-batch OneCycleLR, but would run epoch-level schedules num_processes
        # times too fast, so those stay unwrapped.
        model, optimizer, train_loader, val_loader, test_loader, scaler = accelerator.prepare(
            model, optimizer, train_loader, val_loader, test_loader, scaler)
    else:
        model, optimizer, train_loader, val_loader, test_loader, scheduler, scaler = accelerator.prepare(
            model, optimizer, train_loader, val_loader, test_loader, scheduler, scaler)

    # Try to load a checkpoint before evaluation or training.
    checkpoint, start_epoch, best_val_loss, val_MSE_list, resumed_log_dir = load_checkpoint(
        path, model, optimizer, scheduler, accelerator)

    if cfg.eval:
        if cfg.train_ckpt_load:
            print("Using model from checkpoint for evaluation")
        else:
            load_path = f'metrics/{cfg.project_name}/{cfg.model}_{cfg.test_name}'
            best_dir = os.path.join(os.getcwd(), load_path.lstrip('/'), 'best_case')
            found = _find_best_model_file(best_dir)
            if found is None:
                raise FileNotFoundError(f"No best_model_epoch_*.pt files found in {os.path.join(load_path, 'best_case')}")
            last_best_epoch, best_file = found
            state_dict = torch.load(best_file, map_location='cpu')
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.load_state_dict(state_dict)
            model = accelerator.prepare(unwrapped_model)
            path = os.path.join(path, "best_case")  # results go next to the best model
            print("Using best model for evaluation at epoch", last_best_epoch)
        test_model(model, test_loader, criterion, path, cfg, accelerator)
        return

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    _reset_gpu_mem_stats()
    start = time.time()

    # TensorBoard only on the main process.
    writer = None
    log_dir = None
    if accelerator.is_main_process:
        if resumed_log_dir is not None:
            log_dir = resumed_log_dir
            print(f"Resuming tensorboard logging to: {log_dir}")
        else:
            run_name = f"{cfg.model}_{cfg.test_name}_{time.strftime('%Y%m%d_%H%M%S')}"
            project_name = os.path.join("tensorboard_logs", f"{cfg.project_name}")
            log_dir = os.path.join(project_name, run_name)
            print(f"Starting new tensorboard logging to: {log_dir}")
        writer = SummaryWriter(log_dir)
        if checkpoint is None:  # full config only for fresh runs
            config_text = "```yaml\n" + OmegaConf.to_yaml(cfg) + "```"
            writer.add_text('hyperparameters/full_config', config_text)
    pbar_train = tqdm(range(start_epoch, cfg.epochs), position=0,
                      disable=not accelerator.is_main_process)
    pbar_train.set_description(f"Training (resumed from epoch {start_epoch})" if checkpoint else "Training")

    # `or 10` also covers configs where the key resolves to None or 0.
    checkpoint_freq = getattr(cfg, 'checkpoint_freq', None) or 10
    best_dir = os.path.join(path, 'best_case')
    for epoch in pbar_train:
        train_loss = train(model, train_loader, optimizer, scheduler, criterion, cfg, accelerator, scaler)
        if cfg.scheduler == "LinearWarmupCosineAnnealingLR":
            # Epoch-based schedule: advance every epoch, not only on validation epochs.
            scheduler.step()
        if cfg.val_iter is not None and (epoch == cfg.epochs - 1 or epoch % cfg.val_iter == 0):
            val_loss_MSE = val(model, val_loader, criterion, cfg, accelerator)
            if cfg.scheduler == "ReduceLROnPlateau":
                scheduler.step(val_loss_MSE)
            val_MSE_list.append(val_loss_MSE)
            if accelerator.is_main_process:
                peak_mem_gb = _peak_gpu_mem_gb()
                # Read the LR from the optimizer: works for every scheduler type.
                current_lr = optimizer.param_groups[0]['lr']
                writer.add_scalar('Loss/train_MSE', train_loss, epoch)
                writer.add_scalar('Loss/val_MSE', val_loss_MSE, epoch)
                writer.add_scalar('Learning_rate', current_lr, epoch)
                writer.add_scalar('Memory/GPU', peak_mem_gb, epoch)
                with open(os.path.join(path, 'MSE.json'), 'w') as f:
                    json.dump(val_MSE_list, f, indent=2)
                pbar_train.set_postfix({
                    'train_loss': train_loss,
                    'val_loss': val_loss_MSE,
                    'lr': current_lr,
                    'mem_gb': f'{peak_mem_gb:.1f}'
                })
                if val_loss_MSE < best_val_loss:
                    best_val_loss = val_loss_MSE
                    os.makedirs(best_dir, exist_ok=True)
                    cleanup_old_checkpoints(best_dir, keep_last=1)
                    torch.save(accelerator.unwrap_model(model).state_dict(),
                               os.path.join(best_dir, f'best_model_epoch_{epoch}.pt'))
                    print("saving best model at epoch", epoch)
        elif accelerator.is_main_process:
            pbar_train.set_postfix({
                'train_loss': train_loss,
                'mem_gb': f'{_peak_gpu_mem_gb():.1f}'
            })
        if accelerator.is_main_process and (epoch % checkpoint_freq == 0 or epoch == cfg.epochs - 1):
            save_checkpoint(model, optimizer, scheduler, epoch, best_val_loss, val_MSE_list,
                            cfg, path, accelerator, log_dir)
            cleanup_old_checkpoints(path, keep_last=3)

    time_elapsed = time.time() - start
    # Epochs actually run in this session (a resumed run starts at start_epoch).
    epochs_run = max(1, cfg.epochs - start_epoch)
    peak_mem_gb = _peak_gpu_mem_gb()
    _reset_gpu_mem_stats()

    # Save the final checkpoint before any evaluation.
    if accelerator.is_main_process:
        save_checkpoint(model, optimizer, scheduler, cfg.epochs - 1, best_val_loss, val_MSE_list,
                        cfg, path, accelerator, log_dir)

    # Test the final (last-epoch) model.
    final_results = test_model(model, test_loader, criterion, path, cfg, accelerator)
    test_peak_mem_gb = _peak_gpu_mem_gb()
    if accelerator.is_main_process:
        metrics_text = _format_test_metrics(final_results, time_elapsed, epochs_run, total_params,
                                            trainable_params, peak_mem_gb, test_peak_mem_gb)
        with open(os.path.join(path, 'final_test_metrics.txt'), 'w') as f:
            f.write(metrics_text)
        # Two trailing spaces make a markdown line break in TensorBoard's text tab.
        writer.add_text('metrics/final_metrics', metrics_text.replace('\n', '  \n'))
        _log_per_batch_metrics(writer, path, 'final')

    # Only the main process writes best models. Every rank must wait for those
    # writes before looking on disk; the file is found on disk because best_epoch
    # is not tracked on other ranks and is lost when a run is resumed.
    accelerator.wait_for_everyone()
    best_found = _find_best_model_file(best_dir)
    best_results = None
    if best_found is not None:
        best_epoch, best_file = best_found
        if accelerator.is_main_process:
            print(f"Evaluating best model from epoch {best_epoch}")
        best_model = deepcopy(accelerator.unwrap_model(model))
        best_model.load_state_dict(torch.load(best_file, map_location='cpu'))
        best_model = accelerator.prepare(best_model)
        best_results = test_model(best_model, test_loader, criterion, best_dir, cfg, accelerator)
        if accelerator.is_main_process:
            metrics_text = _format_test_metrics(best_results, time_elapsed, epochs_run, total_params,
                                                trainable_params, peak_mem_gb, test_peak_mem_gb)
            with open(os.path.join(best_dir, 'best_test_metrics.txt'), 'w') as f:
                f.write(metrics_text)
            writer.add_text('metrics/best_metrics', metrics_text.replace('\n', '  \n'))
            _log_per_batch_metrics(writer, best_dir, 'best')

    if accelerator.is_main_process:
        final_mse, final_mae, final_huber, _, _, final_r2, _ = final_results
        print(f"\nFinal model metrics - MSE: {final_mse:.6f}, MAE: {final_mae:.6f}, huber: {final_huber:.6f}, R²: {final_r2:.4f}")
        if best_results is not None:
            best_mse, best_mae, best_huber, _, _, best_r2, _ = best_results
            print(f"Best model metrics - MSE: {best_mse:.6f}, MAE: {best_mae:.6f}, huber: {best_huber:.6f}, R²: {best_r2:.4f}")
        else:
            print(f"Warning: no best_model_epoch_*.pt found in {best_dir}; skipped best-model evaluation")
        writer.close()