# Training entry point: load config, build the model, and dispatch to the
# dataset-specific trainer. (Scraper header residue removed.)
import os
import argparse
from omegaconf import OmegaConf
# Dataset key used only to build the default config path below.
data_name = "cadillac" #DrivAerML,deepjeb, driveaerpp, shapenet_car_pv, DriveAerNet, elasticity, plane_engine_test
parser = argparse.ArgumentParser()
parser.add_argument('--config_path', type=str, default='./configs/'+data_name+'/config.yaml')
args = parser.parse_args()
# Load config with OmegaConf
cfg = OmegaConf.load(args.config_path)
# NOTE: the device mask must be set BEFORE importing torch (done below),
# otherwise torch enumerates all GPUs regardless of cfg.gpu_id.
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_id) # if hasattr(args, 'gpu_id') else "0"
print("running on gpu ",cfg.gpu_id)
# Deliberately late imports: torch (and anything that imports it) must load
# after CUDA_VISIBLE_DEVICES is set above. Do not hoist these to the top.
import models
import trainers
import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
import numpy as np
import random
from train import train_main
# Setup DDP kwargs for distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
print("FlashAttention available:", torch.backends.cuda.flash_sdp_enabled())
# dataset names:
# DrivAerML:
# model names: Transolver, Transformer, New
# chunked_eval and resume model capabilities available
# trained model checkpoints for inference available in metrics folder
# shapenet_car_pv:
# model names: Transolver, Transformer, New
# DriveAerNet:
# model names: Transolver, Transformer, New
# elasticity:
# model names: Transolver, Transformer, New
# set seed
def set_seed(seed):
    """Seed every RNG in play (stdlib, NumPy, torch CPU and CUDA) and pin
    cuDNN to its deterministic, non-benchmarking mode for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Reproducibility over speed: disable autotuned kernels, force
    # deterministic cuDNN algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
set_seed(0)  # fix the global seed once at startup, before model construction
def count_parameters(model):
    """Return the number of trainable parameters in `model`.

    Also prints the total, but only on the main process so DDP runs log once.

    Args:
        model: a torch.nn.Module.

    Returns:
        int: total element count over all parameters with requires_grad=True.
    """
    # Idiomatic sum over a generator instead of a manual accumulate loop;
    # parameter names were never used, so parameters() suffices.
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if accelerator.is_main_process:
        print(f"Total Trainable Params: {total_params}")
    return total_params
def create_model(cfg):
    """Factory function to create models based on dataset and model name.

    Looks up the class named by `cfg.model` in the `models` package,
    instantiates it with the full config, and moves it to the CUDA device.

    Args:
        cfg: OmegaConf config; `cfg.model` names a class exported by `models`.

    Returns:
        The instantiated model on GPU.

    Raises:
        ValueError: if `models` exposes no attribute named `cfg.model`.
    """
    model_name = str(cfg.model)
    # Single getattr lookup with a sentinel instead of hasattr + getattr.
    model_class = getattr(models, model_name, None)
    if model_class is None:
        available = [attr for attr in dir(models) if not attr.startswith('_')]
        raise ValueError(f"Model '{model_name}' not found. Available models: {available}")
    return model_class(cfg).cuda()
# Log the fully-resolved config once (main process only).
if accelerator.is_main_process:
    print(OmegaConf.to_yaml(cfg))
#Set path for saving metrics
path = os.path.join('metrics', f'{cfg.project_name}', f'{cfg.model}_{cfg.test_name}')
if accelerator.is_main_process:
    # exist_ok=True replaces the racy exists()/makedirs() pair (same idiom as
    # the per-file makedirs below) so a rerun cannot crash on an existing dir.
    os.makedirs(path, exist_ok=True)
    # Create src folder for source code backup
    src_path = os.path.join(path, 'src')
    os.makedirs(src_path, exist_ok=True)
    # Save config.yaml
    config_file = os.path.join(path, 'config.yaml')
    print(f'Saving config to {config_file}')
    OmegaConf.save(cfg, config_file)
    # Snapshot the source files that define this run, for reproducibility.
    import shutil
    source_files = {
        'main.py': 'main.py',
        'train.py': 'train.py',
        'dataset_loader.py': 'dataset_loader.py',
        'models/ansysLPFMs.py': 'models/ansysLPFMs.py'
    }
    for src_file, dst_name in source_files.items():
        if os.path.exists(src_file):
            dst_file = os.path.join(src_path, dst_name)
            # Create nested destination dirs (e.g. for models/*.py)
            os.makedirs(os.path.dirname(dst_file), exist_ok=True)
            # copy2 preserves file metadata (mtime), not just contents
            shutil.copy2(src_file, dst_file)
            print(f'Saved source file: {src_file} -> {dst_file}')
        else:
            print(f'Warning: Source file not found: {src_file}')
# Create model using factory function
model = create_model(cfg)
total_parameters = count_parameters(model)
# Dispatch table: dataset name -> trainer entry point. All trainers share the
# (model, path, cfg, accelerator) signature.
trainer_map = {
    'artery': trainers.train_artery_main,
    'shapenet_car_pv': trainers.train_shapenet_car_pv_main,
    'DriveAerNet': trainers.train_DriveAerNet_main,
    'DrivAerML': trainers.train_DrivAerML_main,
    'elasticity': trainers.train_elasticity_main,
    'plane_engine_test': trainers.train_plane_engine_test_main,
    'driveaerpp': trainers.train_driveaerpp_main,
    'deepjeb': trainers.train_deepjeb_main,
    'plane_transonic': train_main,
    'cadillac': train_main,
}
# Train the model. Guard clause first; note the stray trailing '|' that made
# the original final line a syntax error has been removed.
if cfg.dataset_name not in trainer_map:
    raise ValueError(f"No trainer found for dataset: {cfg.dataset_name}")
trainer_map[cfg.dataset_name](model, path, cfg, accelerator)