"""Training entry point.

Loads an OmegaConf config, pins the visible GPU(s), builds the configured
model, backs up the config and key source files into the metrics folder, and
dispatches to the trainer registered for ``cfg.dataset_name``.
"""
import os
import argparse
import random
import shutil

from omegaconf import OmegaConf

# Default dataset whose config is loaded when --config_path is not given.
# Other configs: DrivAerML, deepjeb, driveaerpp, shapenet_car_pv, DriveAerNet,
# elasticity, plane_engine_test
data_name = "cadillac"

parser = argparse.ArgumentParser()
parser.add_argument('--config_path', type=str, default='./configs/' + data_name + '/config.yaml')
args = parser.parse_args()

# Load config with OmegaConf
cfg = OmegaConf.load(args.config_path)

# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is
# initialised, which is why torch and the project modules are imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_id)
print("running on gpu ", cfg.gpu_id)

import models  # noqa: E402
import trainers  # noqa: E402
import torch  # noqa: E402
from accelerate import Accelerator  # noqa: E402
from accelerate.utils import DistributedDataParallelKwargs  # noqa: E402
import numpy as np  # noqa: E402
from train import train_main  # noqa: E402

# Setup DDP kwargs for distributed training. find_unused_parameters=True lets
# DDP tolerate parameters that receive no gradient in a step.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
# NOTE: flash_sdp_enabled() reports whether the flash scaled-dot-product
# attention backend is *enabled*, not whether the hardware supports it.
print("FlashAttention available:", torch.backends.cuda.flash_sdp_enabled())

# dataset names:
# DrivAerML:
#     model names: Transolver, Transformer, New
#     chunked_eval and resume model capabilities available
#     trained model checkpoints for inference available in metrics folder
# shapenet_car_pv:
#     model names: Transolver, Transformer, New
# DriveAerNet:
#     model names: Transolver, Transformer, New
# elasticity:
#     model names: Transolver, Transformer, New


def set_seed(seed):
    """Seed torch (CPU and all CUDA devices), numpy and ``random``.

    Also forces deterministic cuDNN kernels and disables cuDNN autotuning so
    runs are reproducible.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(0)


def count_parameters(model):
    """Return the number of trainable (``requires_grad``) parameters of ``model``.

    The total is printed on the main process only.
    """
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if accelerator.is_main_process:
        print(f"Total Trainable Params: {total_params}")
    return total_params


def create_model(cfg):
    """Instantiate ``models.<cfg.model>`` with ``cfg`` and move it to the GPU.

    Raises:
        ValueError: if ``models`` has no attribute named ``cfg.model``.
    """
    model_name = f"{cfg.model}"
    model_class = getattr(models, model_name, None)
    if model_class is None:
        raise ValueError(f"Model '{model_name}' not found. Available models: {[attr for attr in dir(models) if not attr.startswith('_')]}")
    return model_class(cfg).cuda()


# Datasets trained by the generic loop in ``train.train_main``.
GENERIC_TRAINER_DATASETS = {'plane_transonic', 'cadillac'}

# Datasets with a dedicated entry point in the ``trainers`` module. Stored by
# attribute name and resolved lazily, so a missing function only fails when
# that dataset is actually selected.
DATASET_TRAINERS = {
    'artery': 'train_artery_main',
    'shapenet_car_pv': 'train_shapenet_car_pv_main',
    'DriveAerNet': 'train_DriveAerNet_main',
    'DrivAerML': 'train_DrivAerML_main',
    'elasticity': 'train_elasticity_main',
    'plane_engine_test': 'train_plane_engine_test_main',
    'driveaerpp': 'train_driveaerpp_main',
    'deepjeb': 'train_deepjeb_main',
}


def get_trainer(dataset_name):
    """Return the training function for ``dataset_name``.

    Raises:
        ValueError: if no trainer is registered for the dataset.
    """
    if dataset_name in GENERIC_TRAINER_DATASETS:
        return train_main
    attr_name = DATASET_TRAINERS.get(dataset_name)
    if attr_name is None:
        raise ValueError(f"No trainer found for dataset: {dataset_name}")
    return getattr(trainers, attr_name)


if accelerator.is_main_process:
    print(OmegaConf.to_yaml(cfg))

# Resolve the trainer before writing any files or allocating GPU memory so an
# unsupported dataset fails fast.
trainer_fn = get_trainer(cfg.dataset_name)

# Set path for saving metrics
path = os.path.join('metrics', f'{cfg.project_name}', f'{cfg.model}_{cfg.test_name}')
if accelerator.is_main_process:
    # Create src folder for source code backup; this also creates ``path``.
    # exist_ok avoids the check-then-create race.
    src_path = os.path.join(path, 'src')
    os.makedirs(src_path, exist_ok=True)

    # Save config.yaml
    config_file = os.path.join(path, 'config.yaml')
    print(f'Saving config to {config_file}')
    OmegaConf.save(cfg, config_file)

    # Save source files (source path relative to cwd -> name under src/)
    source_files = {
        'main.py': 'main.py',
        'train.py': 'train.py',
        'dataset_loader.py': 'dataset_loader.py',
        'models/ansysLPFMs.py': 'models/ansysLPFMs.py',
    }
    for src_file, dst_name in source_files.items():
        if not os.path.exists(src_file):
            print(f'Warning: Source file not found: {src_file}')
            continue
        dst_file = os.path.join(src_path, dst_name)
        # Create nested directories when needed (e.g. models/ansysLPFMs.py)
        os.makedirs(os.path.dirname(dst_file), exist_ok=True)
        shutil.copy2(src_file, dst_file)
        print(f'Saved source file: {src_file} -> {dst_file}')

# Other ranks must not proceed until the main process has created the
# output directory. This is a no-op in single-process runs.
accelerator.wait_for_everyone()

# Create model using factory function
model = create_model(cfg)
total_parameters = count_parameters(model)

# Train the model
trainer_fn(model, path, cfg, accelerator)