"""Geometric similarity of test surface meshes to the training set.

A smoothed 3D histogram is built from the (subsampled) surface points of every
training case listed in ``1_VTK_surface/train.txt``. Each case listed in
``1_VTK_surface/test.txt`` is then binned on the same edges and scored as
``1 - JS``, where ``JS`` is the value returned by
``scipy.spatial.distance.jensenshannon`` for the two flattened histograms.
All outputs (training histogram, score table, plots) are written to
``train_save_dir``.
"""
import os
import sys
import traceback

import matplotlib.pyplot as plt
import numpy as np
import pyvista as pv
from scipy.ndimage import gaussian_filter
from scipy.spatial.distance import jensenshannon

# -------------------- PyVista Config --------------------
pv.OFF_SCREEN = True
pv.set_plot_theme("document")

# -------------------- Paths --------------------
dataset = "plane_transonic"
train_folder = f'/raid/ansysai/pkakka/6-Transformers/comparePhysicsLM/Data/{dataset}/'
# Every artefact this script produces (including the training histogram) lives here.
train_save_dir = os.path.join(train_folder, f"../../metrics/{dataset}/")
os.makedirs(train_save_dir, exist_ok=True)

# Sub-directory (below a dataset folder) holding the case lists and per-case VTPs.
_SURFACE_SUBDIR = '1_VTK_surface'


# -------------------- Utility Functions --------------------
def _read_case_list(folder, list_name):
    """Return the stripped, non-blank case names in ``<folder>/1_VTK_surface/<list_name>``.

    Raises:
        ValueError: if the list file does not exist.
    """
    list_path = os.path.join(folder, _SURFACE_SUBDIR, list_name)
    if not os.path.exists(list_path):
        raise ValueError(f"{list_name} not found at {list_path}")
    with open(list_path, 'r') as f:
        return [line.strip() for line in f if line.strip()]


def _case_vtp_path(folder, name):
    """Return the expected VTP path ``<folder>/1_VTK_surface/<name>/<name>.vtp``."""
    return os.path.join(folder, _SURFACE_SUBDIR, name, f'{name}.vtp')


def get_points(file_path, max_points=5000, rng=None):
    """Read a mesh file and return its points, randomly subsampled if large.

    Args:
        file_path: path of the mesh file, read with ``pv.read``.
        max_points: keep at most this many points; larger clouds are
            subsampled uniformly without replacement.
        rng: optional ``np.random.Generator`` for reproducible subsampling.
            ``None`` (default) uses NumPy's global random state.

    Returns:
        The mesh's point array (at most ``max_points`` rows). Callers index
        columns 0..2, i.e. x/y/z coordinates.

    Raises:
        ValueError: if the file cannot be read or has no points attribute.
    """
    try:
        points = pv.read(file_path).points
    except Exception as e:
        raise ValueError(f"Error reading {file_path}: {e}") from e

    if len(points) > max_points:
        sampler = np.random if rng is None else rng
        indices = sampler.choice(len(points), max_points, replace=False)
        points = points[indices]
    return points


# -------------------- Training Histogram --------------------
def compute_training_dist(train_folder, output_file='train_dist.npz', num_bins=25, smooth_sigma=1):
    """Build and save the smoothed 3D point histogram of all training geometries.

    Points of every training VTP (each subsampled by :func:`get_points`) are
    pooled. The bin edges span the pooled bounding box, with ``num_bins``
    equal-width bins per axis. The histogram is Gaussian-smoothed with
    ``smooth_sigma``, flattened and normalised to sum to 1.

    The result is written to ``train_save_dir``. This is the directory
    :func:`compute_js_score` loads it from.

    Args:
        train_folder: dataset root containing ``1_VTK_surface/train.txt``.
        output_file: file name of the ``.npz`` (keys ``hist``, ``edges0``,
            ``edges1``, ``edges2``).
        num_bins: bins per axis.
        smooth_sigma: Gaussian smoothing width, in bins.

    Returns:
        The path passed to ``np.savez``.

    Raises:
        ValueError: if ``train.txt`` is missing or no training VTP exists.
    """
    folder_names = _read_case_list(train_folder, 'train.txt')

    train_files = []
    for name in folder_names:
        vtp_file = _case_vtp_path(train_folder, name)
        if os.path.exists(vtp_file):
            train_files.append(vtp_file)
        else:
            print(f"Warning: VTP not found: {vtp_file}")
    if not train_files:
        raise ValueError("No training VTPs found.")

    # Combine all training points into one cloud.
    train_points = np.concatenate([get_points(f) for f in train_files], axis=0)

    bin_edges = [np.histogram_bin_edges(train_points[:, i], bins=num_bins) for i in range(3)]
    train_hist, _ = np.histogramdd(train_points, bins=bin_edges, density=True)
    train_hist = gaussian_filter(train_hist, sigma=smooth_sigma)

    train_hist = train_hist.flatten()
    train_hist /= train_hist.sum()

    # Must be the same directory compute_js_score reads from.
    output_path = os.path.join(train_save_dir, output_file)
    np.savez(output_path, hist=train_hist,
             edges0=bin_edges[0], edges1=bin_edges[1], edges2=bin_edges[2])
    print(f"Training histogram saved: {output_path} ({train_points.shape[0]} points)")
    return output_path


# -------------------- JS Score for Test --------------------
def compute_js_score(test_file, train_dist_file='train_dist.npz', smooth_sigma=1):
    """Return the ``1 - JS`` similarity of a test geometry to the training histogram.

    ``JS`` comes from ``scipy.spatial.distance.jensenshannon`` with its
    default (natural-log) base. That is the Jensen-Shannon *distance*, the
    square root of the divergence, bounded by sqrt(ln 2) ~= 0.833. The score
    therefore lies in roughly [0.167, 1]; higher means closer to the
    training distribution.

    Test points are binned on the saved training edges. Points outside the
    training bounding box are therefore ignored.

    Raises:
        ValueError: if no test point falls inside the training bins. The
            histogram would otherwise be all-NaN, giving a NaN score.
    """
    # Context manager closes the underlying .npz file handle.
    with np.load(os.path.join(train_save_dir, train_dist_file)) as data:
        train_hist = data['hist']
        bin_edges = [data['edges0'], data['edges1'], data['edges2']]

    test_points = get_points(test_file)

    # histogramdd counts a point iff first_edge <= x <= last_edge on every axis.
    lo = np.array([edges[0] for edges in bin_edges])
    hi = np.array([edges[-1] for edges in bin_edges])
    inside = np.all((test_points >= lo) & (test_points <= hi), axis=1)
    if not inside.any():
        raise ValueError(f"No points of {test_file} fall inside the training histogram bounds")

    test_hist, _ = np.histogramdd(test_points, bins=bin_edges, density=True)
    test_hist = gaussian_filter(test_hist, sigma=smooth_sigma)

    test_hist = test_hist.flatten()
    test_hist /= test_hist.sum()

    return 1 - jensenshannon(train_hist, test_hist)


# -------------------- Analyze Test Folder --------------------
def analyze_and_save_scores(test_folder, train_dist_file='train_dist.npz',
                            output_file='test_js_scores.txt', num_bins=25, smooth_sigma=1):
    """Score every case in ``<test_folder>/1_VTK_surface/test.txt`` and save results.

    Writes the following into ``train_save_dir``:

    * a tab-separated ``Test_File``/``JS_Score`` table to ``output_file``;
      failing cases get an ``ERROR: ...`` entry instead of a score;
    * ``js_scores_hist.png``, only if at least one case was scored.

    Args:
        test_folder: dataset root containing ``1_VTK_surface/test.txt``.
        train_dist_file: training histogram file name inside ``train_save_dir``.
        output_file: score table file name inside ``train_save_dir``.
        num_bins: unused; bins come from the saved training edges. Kept for
            backward compatibility.
        smooth_sigma: Gaussian smoothing width applied to each test histogram.

    Returns:
        ``(names, scores)`` for the successfully scored cases.

    Raises:
        ValueError: if ``test.txt`` is missing or empty.
    """
    folder_names = _read_case_list(test_folder, 'test.txt')
    if not folder_names:
        raise ValueError("No test cases found.")

    output_path = os.path.join(train_save_dir, output_file)
    scores = []
    names = []
    with open(output_path, 'w') as f_out:
        f_out.write("Test_File\tJS_Score\n")
        for name in folder_names:
            vtp_file = _case_vtp_path(test_folder, name)
            if not os.path.exists(vtp_file):
                print(f"Warning: VTP not found: {vtp_file}")
                f_out.write(f"{name}\tERROR: VTP not found\n")
                continue
            try:
                score = compute_js_score(vtp_file, train_dist_file, smooth_sigma)
            except Exception as e:
                # Best effort: record the failure and keep scoring the remaining cases.
                print(f"Error for {name}: {e}")
                f_out.write(f"{name}\tERROR: {e}\n")
                continue
            scores.append(score)
            names.append(name)
            print(f"{name}: {score:.4f}")
            f_out.write(f"{name}\t{score:.6f}\n")

    if scores:
        print(f"\nAverage Score: {np.mean(scores):.4f} ± {np.std(scores):.4f}")
        print(f"Min/Max: {np.min(scores):.4f} / {np.max(scores):.4f}")

        fig, ax = plt.subplots(figsize=(6, 4))
        try:
            ax.hist(scores, bins=10, alpha=0.7, edgecolor='black')
            ax.set_xlabel('JS Score (Higher = Closer to Train)')
            ax.set_ylabel('Count')
            ax.set_title('Test Geometry JS Scores')
            fig.savefig(os.path.join(train_save_dir, 'js_scores_hist.png'))
        finally:
            plt.close(fig)

    return names, scores


# -------------------- Optional Visualization --------------------
def visualize_sample(test_folder, show_plot=True):
    """Save an off-screen screenshot of the first case in ``test.txt`` (optional).

    The screenshot is written to ``train_save_dir/sample_geometry_<name>.png``.
    Rendering problems are reported as warnings rather than raised.

    Raises:
        ValueError: if ``show_plot`` is true and ``test.txt`` is missing.
    """
    if not show_plot:
        print("Skipping visualization (set show_plot=True to enable)")
        return

    folder_names = _read_case_list(test_folder, 'test.txt')
    if not folder_names:
        print("Warning: No test cases found for visualization")
        return

    name = folder_names[0]
    vtp_file = _case_vtp_path(test_folder, name)
    if not os.path.exists(vtp_file):
        print(f"Warning: VTP file not found: {vtp_file}")
        return

    try:
        mesh = pv.read(vtp_file)
        plotter = pv.Plotter(off_screen=True)
        try:
            plotter.add_mesh(mesh, color='blue', show_edges=True)
            plotter.add_title(f'Sample Geometry: {name}')
            screenshot_path = os.path.join(train_save_dir, f'sample_geometry_{name}.png')
            plotter.screenshot(screenshot_path)
        finally:
            # Release the render window even if the screenshot fails.
            plotter.close()
        print(f"Sample screenshot saved: {screenshot_path}")
    except Exception as e:
        print(f"Warning: Could not visualize geometry: {e}")


# -------------------- Main --------------------
if __name__ == "__main__":
    try:
        print("Computing training histogram...")
        compute_training_dist(train_folder, num_bins=25, smooth_sigma=1)
        print("Analyzing test cases...")
        analyze_and_save_scores(train_folder, num_bins=25, smooth_sigma=1)
        visualize_sample(train_folder, show_plot=False)
        print("JS analysis completed successfully!")
    except Exception as e:
        print(f"Error during analysis: {e}")
        traceback.print_exc()
        # Signal failure to the calling shell / pipeline.
        sys.exit(1)