udbhav
Recreate Trame_app branch with clean history
67fb03c
import os
import argparse
from omegaconf import OmegaConf
# Default dataset: its config is read from ./configs/<data_name>/config.yaml
# unless --config_path is passed. Other dataset names used in this repo:
# DrivAerML, deepjeb, driveaerpp, shapenet_car_pv, DriveAerNet, elasticity,
# plane_engine_test
data_name = "cadillac"

parser = argparse.ArgumentParser()
parser.add_argument(
    '--config_path',
    type=str,
    default=f'./configs/{data_name}/config.yaml',
)
args = parser.parse_args()

# Load the experiment config with OmegaConf.
cfg = OmegaConf.load(args.config_path)

# Select the GPU before torch / accelerate are imported further down, so the
# device mask is in place before CUDA is initialised. Keep this ordering.
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_id)
print("running on gpu ", cfg.gpu_id)
import models
import trainers
import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
import numpy as np
import random
from train import train_main
# Setup DDP kwargs for distributed training.
# find_unused_parameters=True lets DDP tolerate parameters that get no gradient
# in a step, at the cost of an extra autograd-graph traversal per iteration.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

# flash_sdp_enabled() reports whether the flash scaled-dot-product-attention
# backend is *enabled* (a user-togglable flag), not whether the hardware or
# build supports it, so label it accordingly. Print once, from the main
# process only, like the other diagnostics in this script.
if accelerator.is_main_process:
    print("FlashAttention (flash SDP backend) enabled:", torch.backends.cuda.flash_sdp_enabled())
# dataset names:
# DrivAerML:
# model names: Transolver, Transformer, New
# chunked_eval and model-resume capabilities are available
# trained model checkpoints for inference are available in the metrics folder
# shapenet_car_pv:
# model names: Transolver, Transformer, New
# DriveAerNet:
# model names: Transolver, Transformer, New
# elasticity:
# model names: Transolver, Transformer, New
# set seed
def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
set_seed(0)
def count_parameters(model):
    """Return the number of trainable parameters in ``model``.

    Only parameters with ``requires_grad=True`` are counted. The total is also
    printed, from the main process only, using the module-level ``accelerator``.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        int: sum of ``numel()`` over all trainable parameters (0 if none).
    """
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if accelerator.is_main_process:
        print(f"Total Trainable Params: {total_params}")
    return total_params
def create_model(cfg):
    """Build the model named by ``cfg.model`` and move it to CUDA.

    The class is looked up by name as an attribute of the ``models`` package
    and constructed with the full ``cfg``.

    Raises:
        ValueError: if ``models`` has no attribute named ``cfg.model``; the
            message lists every public (non-underscore) name in ``models``.
    """
    model_name = f"{cfg.model}"
    if not hasattr(models, model_name):
        public_names = [name for name in dir(models) if not name.startswith('_')]
        raise ValueError(f"Model '{model_name}' not found. Available models: {public_names}")
    model_class = getattr(models, model_name)
    model = model_class(cfg)
    return model.cuda()
if accelerator.is_main_process:
    print(OmegaConf.to_yaml(cfg))

# Output directory for this run: metrics/<project_name>/<model>_<test_name>.
# Every process computes `path` (it is handed to the trainer below), but only
# the main process creates it and writes the backups.
path = os.path.join('metrics', f'{cfg.project_name}', f'{cfg.model}_{cfg.test_name}')
if accelerator.is_main_process:
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(path, exist_ok=True)
    # Folder holding a snapshot of the source code used for this run.
    src_path = os.path.join(path, 'src')
    os.makedirs(src_path, exist_ok=True)

    # Save the config next to the run's outputs.
    config_file = os.path.join(path, 'config.yaml')
    print(f'Saving config to {config_file}')
    OmegaConf.save(cfg, config_file)

    # Copy key source files (paths relative to the current working directory)
    # into src/, keeping their relative layout.
    import shutil
    source_files = {
        'main.py': 'main.py',
        'train.py': 'train.py',
        'dataset_loader.py': 'dataset_loader.py',
        'models/ansysLPFMs.py': 'models/ansysLPFMs.py',
    }
    for src_file, dst_name in source_files.items():
        if not os.path.exists(src_file):
            print(f'Warning: Source file not found: {src_file}')
            continue
        dst_file = os.path.join(src_path, dst_name)
        # Create nested destination dirs (e.g. src/models/) as needed.
        os.makedirs(os.path.dirname(dst_file), exist_ok=True)
        shutil.copy2(src_file, dst_file)
        print(f'Saved source file: {src_file} -> {dst_file}')
# NOTE(review): non-main processes do not wait for `path` to exist; if the
# trainer writes into it from other ranks, consider accelerator.wait_for_everyone().
# Map cfg.dataset_name -> training entry point; each is called as
# trainer(model, path, cfg, accelerator).
trainer_map = {
    'artery': trainers.train_artery_main,
    'shapenet_car_pv': trainers.train_shapenet_car_pv_main,
    'DriveAerNet': trainers.train_DriveAerNet_main,
    'DrivAerML': trainers.train_DrivAerML_main,
    'elasticity': trainers.train_elasticity_main,
    'plane_engine_test': trainers.train_plane_engine_test_main,
    'driveaerpp': trainers.train_driveaerpp_main,
    'deepjeb': trainers.train_deepjeb_main,
    'plane_transonic': train_main,
    'cadillac': train_main,
}

# Resolve the trainer before building the model, so an unknown dataset fails
# fast instead of after the model has been allocated on the GPU.
train_fn = trainer_map.get(cfg.dataset_name)
if train_fn is None:
    raise ValueError(f"No trainer found for dataset: {cfg.dataset_name}")

# Create model using factory function
model = create_model(cfg)
total_parameters = count_parameters(model)

# Train the model
train_fn(model, path, cfg, accelerator)