# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from torch.nn.parallel import DistributedDataParallel
from torch.optim import AdamW

from physicsnemo.distributed import DistributedManager
from physicsnemo.launch.logging import LaunchLogger, PythonLogger
from physicsnemo.launch.utils import load_checkpoint
from physicsnemo.sym.hydra import to_absolute_path

from dataloaders import Dedalus2DDataset, MHDDataloaderVecPot
from losses import LossMHDVecPot_PhysicsNeMo
from tfno import TFNO
from utils.plot_utils import plot_predictions_mhd, plot_predictions_mhd_plotly

dtype = torch.float
torch.set_default_dtype(dtype)


@hydra.main(
    version_base="1.3", config_path="config", config_name="eval_mhd_vec_pot_tfno.yaml"
)
def main(cfg: DictConfig) -> None:
    # Initialize the distributed environment. Only call this once in the
    # entire script; the DistributedManager singleton can be re-instantiated
    # wherever it is needed afterwards.
    DistributedManager.initialize()
    dist = DistributedManager()

    cfg = OmegaConf.to_container(cfg, resolve=True)

    # Initialize monitoring
    log = PythonLogger(name="mhd_pino")
    log.file_logging()

    # Load config file parameters
    model_params = cfg["model_params"]
    dataset_params = cfg["dataset_params"]
    test_loader_params = cfg["test_loader_params"]
    loss_params = cfg["loss_params"]
    optimizer_params = cfg["optimizer_params"]
    output_dir = cfg["output_dir"]
    test_params = cfg["test"]
    # Named `load_ckpt` so the flag does not shadow the imported
    # `load_checkpoint` function used below.
    load_ckpt = cfg.get("load_ckpt", False)

    output_dir = to_absolute_path(output_dir)
    os.makedirs(output_dir, exist_ok=True)
    data_dir = dataset_params["data_dir"]

    # Construct dataloaders
    dataset_test = Dedalus2DDataset(
        data_dir,
        output_names=dataset_params["output_names"],
        field_names=dataset_params["field_names"],
        num_train=dataset_params["num_train"],
        num_test=dataset_params["num_test"],
        num=dataset_params["num"],
        use_train=False,
    )
    mhd_dataloader_test = MHDDataloaderVecPot(
        dataset_test,
        sub_x=dataset_params["sub_x"],
        sub_t=dataset_params["sub_t"],
        ind_x=dataset_params["ind_x"],
        ind_t=dataset_params["ind_t"],
    )
    dataloader_test, sampler_test = mhd_dataloader_test.create_dataloader(
        batch_size=test_loader_params["batch_size"],
        shuffle=test_loader_params["shuffle"],
        num_workers=test_loader_params["num_workers"],
        pin_memory=test_loader_params["pin_memory"],
        distributed=dist.distributed,
    )

    # Define FNO model
    model = TFNO(
        in_channels=model_params["in_dim"],
        out_channels=model_params["out_dim"],
        decoder_layers=model_params["decoder_layers"],
        decoder_layer_size=model_params["fc_dim"],
        dimension=model_params["dimension"],
        latent_channels=model_params["layers"],
        num_fno_layers=model_params["num_fno_layers"],
        num_fno_modes=model_params["modes"],
        padding=[model_params["pad_z"], model_params["pad_y"], model_params["pad_x"]],
        rank=model_params["rank"],
        factorization=model_params["factorization"],
        fixed_rank_modes=model_params["fixed_rank_modes"],
    ).to(dist.device)

    # Set up DistributedDataParallel if using more than a single process.
    # The `distributed` property of DistributedManager can be used to
    # check this.
    if dist.distributed:
        ddps = torch.cuda.Stream()
        with torch.cuda.stream(ddps):
            model = DistributedDataParallel(
                model,
                device_ids=[dist.local_rank],  # Set device_ids to the local
                # rank of this process on this node.
                output_device=dist.device,
                broadcast_buffers=dist.broadcast_buffers,
                find_unused_parameters=dist.find_unused_parameters,
            )
        torch.cuda.current_stream().wait_stream(ddps)

    # Construct optimizer and scheduler
    optimizer = AdamW(
        model.parameters(),
        betas=optimizer_params["betas"],
        lr=optimizer_params["lr"],
        weight_decay=0.1,
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=optimizer_params["milestones"],
        gamma=optimizer_params["gamma"],
    )

    # Construct loss class
    mhd_loss = LossMHDVecPot_PhysicsNeMo(**loss_params)

    # Load model from checkpoint (if one exists)
    if load_ckpt:
        _ = load_checkpoint(
            test_params["ckpt_path"], model, optimizer, scheduler, device=dist.device
        )

    # Eval loop
    names = dataset_params["fields"]
    input_norm = torch.tensor(model_params["input_norm"]).to(dist.device)
    output_norm = torch.tensor(model_params["output_norm"]).to(dist.device)
    with LaunchLogger("test") as log:
        model.eval()
        plot_count = 0
        with torch.no_grad():
            for i, (inputs, outputs) in enumerate(dataloader_test):
                inputs = inputs.type(dtype).to(dist.device)
                outputs = outputs.type(dtype).to(dist.device)

                start_time = time.time()
                # Compute predictions: normalize the inputs, permute to
                # channels-first for the model, then permute back to
                # channels-last and rescale the outputs.
                pred = (
                    model((inputs / input_norm).permute(0, 4, 1, 2, 3)).permute(
                        0, 2, 3, 4, 1
                    )
                    * output_norm
                )
                end_time = time.time()
                print(f"Inference Time: {end_time - start_time}")

                # Compute loss
                loss, loss_dict = mhd_loss(
                    pred, outputs, inputs, return_loss_dict=True
                )
                log.log_minibatch(loss_dict)

                # Get interactive prediction plots for each sample
                for j, _ in enumerate(pred):
                    # Make plots for each field
                    for index, name in enumerate(names):
                        # Generate figure
                        _ = plot_predictions_mhd_plotly(
                            pred[j].cpu(),
                            outputs[j].cpu(),
                            inputs[j].cpu(),
                            index=index,
                            name=name,
                        )
                        plot_count += 1

                # Get prediction plots and save images locally
                for j, _ in enumerate(pred):
                    # Generate figure
                    plot_predictions_mhd(
                        pred[j].cpu(),
                        outputs[j].cpu(),
                        inputs[j].cpu(),
                        names=names,
                        save_path=os.path.join(
                            output_dir,
                            "MHD_eval_" + str(dist.rank),
                        ),
                        save_suffix=i,
                    )


if __name__ == "__main__":
    main()
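
# Example invocation (a sketch; the script filename is assumed here to match
# the Hydra config name declared above). For a single process:
#   python eval_mhd_vec_pot_tfno.py
# For multi-GPU evaluation, launch one process per GPU, e.g.:
#   torchrun --nproc_per_node=2 eval_mhd_vec_pot_tfno.py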