File size: 4,573 Bytes
67fb03c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132

import os
import argparse
from omegaconf import OmegaConf
# Dataset whose config folder supplies the default --config_path.
# Other config folders: DrivAerML, deepjeb, driveaerpp, shapenet_car_pv,
# DriveAerNet, elasticity, plane_engine_test.
data_name = "cadillac"

parser = argparse.ArgumentParser()
parser.add_argument(
    '--config_path',
    type=str,
    default=f'./configs/{data_name}/config.yaml',
)
args = parser.parse_args()

# Load the run configuration with OmegaConf and restrict CUDA to the GPU it
# names. This happens before torch / accelerate are imported further down,
# so the restriction is in place when they start.
cfg = OmegaConf.load(args.config_path)
visible_gpu = str(cfg.gpu_id)
os.environ["CUDA_VISIBLE_DEVICES"] = visible_gpu
print("running on gpu ", visible_gpu)

import models
import trainers
import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
import numpy as np
import random
from train import train_main 
# Setup DDP kwargs for distributed training.
# find_unused_parameters=True lets DDP tolerate parameters that receive no
# gradient in a given forward pass, at some extra per-iteration cost.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

# flash_sdp_enabled() reports whether the FlashAttention backend of
# scaled_dot_product_attention is currently *enabled* (a runtime toggle);
# it does not check hardware/build support, so the label says "enabled"
# rather than "available". Printed once, on the main process, like the
# other status messages in this script.
if accelerator.is_main_process:
    print("FlashAttention SDP backend enabled:", torch.backends.cuda.flash_sdp_enabled())

# dataset names:
#      DrivAerML:
#           model names: Transolver, Transformer, New
#           chunked_eval and resume model capabilities available
#           trained model checkpoints for inference available in metrics folder
#      shapenet_car_pv:
#           model names: Transolver, Transformer, New
#      DriveAerNet:
#           model names: Transolver, Transformer, New
#      elasticity:
#           model names: Transolver, Transformer, New
#      (This list is not exhaustive: `trainer_map` at the end of this file
#       lists every dataset_name that has a trainer.)

# set seed
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(0)  

def count_parameters(model):
    """Return the number of elements in ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted. The total is
    also printed, but only on the main Accelerate process.
    """
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if accelerator.is_main_process:
        print(f"Total Trainable Params: {trainable}")
    return trainable

def create_model(cfg):
    """Instantiate the model class named by ``cfg.model`` and move it to CUDA.

    The class is looked up as an attribute of the ``models`` package and
    constructed with the full config object.

    Raises:
        ValueError: if ``models`` has no attribute with that name; the
            message lists the package's public (non-underscore) attributes.
    """
    model_name = f"{cfg.model}"
    if not hasattr(models, model_name):
        available = [attr for attr in dir(models) if not attr.startswith('_')]
        raise ValueError(f"Model '{model_name}' not found. Available models: {available}")
    model_class = getattr(models, model_name)
    return model_class(cfg).cuda()

if accelerator.is_main_process:
    print(OmegaConf.to_yaml(cfg))

# Directory for this run's metrics, saved config and source snapshot:
# metrics/<project_name>/<model>_<test_name>
path = os.path.join('metrics', f'{cfg.project_name}', f'{cfg.model}_{cfg.test_name}')
if accelerator.is_main_process:
    # exist_ok=True replaces the racy "check exists, then create" pattern:
    # the directory may appear between the check and makedirs, which would
    # raise FileExistsError. Re-running into an existing folder stays fine.
    os.makedirs(path, exist_ok=True)

    # Create src folder for source code backup
    src_path = os.path.join(path, 'src')
    os.makedirs(src_path, exist_ok=True)

    # Save config.yaml
    config_file = os.path.join(path, 'config.yaml')
    print(f'Saving config to {config_file}')
    OmegaConf.save(cfg, config_file)

    # Snapshot selected source files. Paths are relative to the current
    # working directory; a missing file only produces a warning.
    import shutil
    source_files = {
        'main.py': 'main.py',
        'train.py': 'train.py',
        'dataset_loader.py': 'dataset_loader.py',
        'models/ansysLPFMs.py': 'models/ansysLPFMs.py',
    }

    for src_file, dst_name in source_files.items():
        if os.path.exists(src_file):
            dst_file = os.path.join(src_path, dst_name)
            # Create the destination subdirectory if needed (e.g. src/models/
            # for models/ansysLPFMs.py).
            os.makedirs(os.path.dirname(dst_file), exist_ok=True)
            shutil.copy2(src_file, dst_file)
            print(f'Saved source file: {src_file} -> {dst_file}')
        else:
            print(f'Warning: Source file not found: {src_file}')

# Training function mapping: dataset_name -> trainer entry point, each
# invoked as trainer(model, path, cfg, accelerator).
trainer_map = {
    'artery': trainers.train_artery_main,
    'shapenet_car_pv': trainers.train_shapenet_car_pv_main,
    'DriveAerNet': trainers.train_DriveAerNet_main,
    'DrivAerML': trainers.train_DrivAerML_main,
    'elasticity': trainers.train_elasticity_main,
    'plane_engine_test': trainers.train_plane_engine_test_main,
    'driveaerpp': trainers.train_driveaerpp_main,
    'deepjeb': trainers.train_deepjeb_main,
    'plane_transonic': train_main,
    'cadillac': train_main,
}

# Resolve the trainer before building the model, so an unsupported
# dataset_name fails immediately rather than after the model has been
# constructed and moved to the GPU by create_model().
if cfg.dataset_name not in trainer_map:
    raise ValueError(f"No trainer found for dataset: {cfg.dataset_name}")
train_fn = trainer_map[cfg.dataset_name]

# Create model using factory function
model = create_model(cfg)
total_parameters = count_parameters(model)

# Train the model
train_fn(model, path, cfg, accelerator)