Skip to content

Commit 37747cd

Browse files
Commit directly to the main branch
1 parent a9139c5 commit 37747cd

File tree

6 files changed

+1525
-0
lines changed

6 files changed

+1525
-0
lines changed

RegDGCNN_SurfaceFields/data_loader.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# data_loader.py
2+
"""
3+
@author: Mohamed Elrefaie, [email protected]
4+
5+
Data loading utilities for the DrivAerNet++ dataset.
6+
7+
This module provides functionality for loading and preprocessing point cloud data
8+
with pressure field information from the DrivAerNet++ dataset.
9+
"""
10+
11+
import os
12+
import numpy as np
13+
import torch
14+
from torch.utils.data import Dataset, Subset, DataLoader
15+
import torch.distributed as dist
16+
import pyvista as pv
17+
import logging
18+
19+
20+
class SurfacePressureDataset(Dataset):
21+
"""
22+
Dataset class for loading and preprocessing surface pressure data from DrivAerNet++ VTK files.
23+
24+
This dataset handles loading surface meshes with pressure field data,
25+
sampling points, and caching processed data for faster loading.
26+
"""
27+
28+
def __init__(self, root_dir: str, num_points: int, preprocess=False, cache_dir=None):
29+
"""
30+
Initializes the SurfacePressureDataset instance.
31+
32+
Args:
33+
root_dir: Directory containing the VTK files for the car surface meshes.
34+
num_points: Fixed number of points to sample from each 3D model.
35+
preprocess: Flag to indicate if preprocessing should occur or not.
36+
cache_dir: Directory where the preprocessed files (NPZ) are stored.
37+
"""
38+
self.root_dir = root_dir
39+
self.vtk_files = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.vtk')]
40+
self.num_points = num_points
41+
self.preprocess = preprocess
42+
self.cache_dir = cache_dir if cache_dir else os.path.join(root_dir, "processed_data")
43+
44+
if not os.path.exists(self.cache_dir):
45+
os.makedirs(self.cache_dir)
46+
47+
def __len__(self):
48+
return len(self.vtk_files)
49+
50+
def _get_cache_path(self, vtk_file_path):
51+
"""Get the corresponding .npz file path for a given .vtk file."""
52+
base_name = os.path.basename(vtk_file_path).replace('.vtk', '.npz')
53+
return os.path.join(self.cache_dir, base_name)
54+
55+
def _save_to_cache(self, cache_path, point_cloud, pressures):
56+
"""Save preprocessed point cloud and pressure data into an npz file."""
57+
np.savez_compressed(cache_path, points=point_cloud.points, pressures=pressures)
58+
59+
def _load_from_cache(self, cache_path):
60+
"""Load preprocessed point cloud and pressure data from an npz file."""
61+
data = np.load(cache_path)
62+
point_cloud = pv.PolyData(data['points'])
63+
pressures = data['pressures']
64+
return point_cloud, pressures
65+
66+
def sample_point_cloud_with_pressure(self, mesh, n_points=5000):
67+
"""
68+
Sample n_points from the surface mesh and get corresponding pressure values.
69+
70+
Args:
71+
mesh: PyVista mesh object with pressure data stored in point_data.
72+
n_points: Number of points to sample.
73+
74+
Returns:
75+
A tuple containing the sampled point cloud and corresponding pressures.
76+
"""
77+
if mesh.n_points > n_points:
78+
indices = np.random.choice(mesh.n_points, n_points, replace=False)
79+
else:
80+
indices = np.arange(mesh.n_points)
81+
logging.info(f"Mesh has only {mesh.n_points} points. Using all available points.")
82+
83+
sampled_points = mesh.points[indices]
84+
sampled_pressures = mesh.point_data['p'][indices] # Assuming pressure data is stored under key 'p'
85+
sampled_pressures = sampled_pressures.flatten() # Ensure it's a flat array
86+
87+
return pv.PolyData(sampled_points), sampled_pressures
88+
89+
def __getitem__(self, idx):
90+
vtk_file_path = self.vtk_files[idx]
91+
cache_path = self._get_cache_path(vtk_file_path)
92+
93+
# Check if the data is already cached
94+
if os.path.exists(cache_path):
95+
logging.info(f"Loading cached data from {cache_path}")
96+
point_cloud, pressures = self._load_from_cache(cache_path)
97+
else:
98+
if self.preprocess:
99+
logging.info(f"Preprocessing and caching data for {vtk_file_path}")
100+
try:
101+
mesh = pv.read(vtk_file_path)
102+
except Exception as e:
103+
logging.error(f"Failed to load VTK file: {vtk_file_path}. Error: {e}")
104+
return None, None # Skip the file and return None
105+
106+
point_cloud, pressures = self.sample_point_cloud_with_pressure(mesh, self.num_points)
107+
108+
# Cache the sampled data to a new file
109+
self._save_to_cache(cache_path, point_cloud, pressures)
110+
else:
111+
logging.error(f"Cache file not found for {vtk_file_path} and preprocessing is disabled.")
112+
return None, None # Return None if preprocessing is disabled and cache doesn't exist
113+
114+
point_cloud_np = np.array(point_cloud.points)
115+
point_cloud_tensor = torch.tensor(point_cloud_np.T[np.newaxis, :, :], dtype=torch.float32)
116+
pressures_tensor = torch.tensor(pressures[np.newaxis, :], dtype=torch.float32)
117+
118+
return point_cloud_tensor, pressures_tensor
119+
120+
121+
def create_subset(dataset, ids_file):
122+
"""
123+
Create a subset of the dataset based on design IDs from a file.
124+
125+
Args:
126+
dataset: The full dataset
127+
ids_file: Path to a file containing design IDs, one per line
128+
129+
Returns:
130+
A Subset of the dataset containing only the specified designs
131+
"""
132+
try:
133+
with open(ids_file, 'r') as file:
134+
subset_ids = [id_.strip() for id_ in file.readlines()]
135+
subset_files = [f for f in dataset.vtk_files if any(id_ in f for id_ in subset_ids)]
136+
subset_indices = [dataset.vtk_files.index(f) for f in subset_files]
137+
if not subset_indices:
138+
logging.error(f"No matching VTK files found for IDs in {ids_file}.")
139+
return Subset(dataset, subset_indices)
140+
except FileNotFoundError as e:
141+
logging.error(f"Error loading subset file {ids_file}: {e}")
142+
return None
143+
144+
145+
def get_dataloaders(dataset_path: str, subset_dir: str, num_points: int, batch_size: int,
146+
world_size: int, rank: int, cache_dir: str = None, num_workers: int = 4) -> tuple:
147+
"""
148+
Prepare and return the training, validation, and test DataLoader objects.
149+
150+
Args:
151+
dataset_path: Path to the directory containing VTK files
152+
subset_dir: Directory containing train/val/test split files
153+
num_points: Number of points to sample from each mesh
154+
batch_size: Batch size for dataloaders
155+
world_size: Total number of processes for distributed training
156+
rank: Current process rank
157+
cache_dir: Directory to store processed data
158+
num_workers: Number of workers for data loading
159+
160+
Returns:
161+
A tuple of (train_dataloader, val_dataloader, test_dataloader)
162+
"""
163+
full_dataset = SurfacePressureDataset(
164+
root_dir=dataset_path,
165+
num_points=num_points,
166+
preprocess=True,
167+
cache_dir=cache_dir
168+
)
169+
170+
train_dataset = create_subset(full_dataset, os.path.join(subset_dir, 'train_design_ids.txt'))
171+
val_dataset = create_subset(full_dataset, os.path.join(subset_dir, 'val_design_ids.txt'))
172+
test_dataset = create_subset(full_dataset, os.path.join(subset_dir, 'test_design_ids.txt'))
173+
174+
# Distributed samplers for DDP
175+
train_sampler = torch.utils.data.distributed.DistributedSampler(
176+
train_dataset, num_replicas=world_size, rank=rank
177+
)
178+
val_sampler = torch.utils.data.distributed.DistributedSampler(
179+
val_dataset, num_replicas=world_size, rank=rank
180+
)
181+
test_sampler = torch.utils.data.distributed.DistributedSampler(
182+
test_dataset, num_replicas=world_size, rank=rank
183+
)
184+
185+
train_dataloader = DataLoader(
186+
train_dataset, batch_size=batch_size, sampler=train_sampler,
187+
drop_last=True, num_workers=num_workers
188+
)
189+
val_dataloader = DataLoader(
190+
val_dataset, batch_size=batch_size, sampler=val_sampler,
191+
drop_last=True, num_workers=num_workers
192+
)
193+
test_dataloader = DataLoader(
194+
test_dataset, batch_size=batch_size, sampler=test_sampler,
195+
drop_last=True, num_workers=num_workers
196+
)
197+
198+
return train_dataloader, val_dataloader, test_dataloader
199+
200+
201+
# Constants for normalization
202+
PRESSURE_MEAN = -94.5
203+
PRESSURE_STD = 117.25

0 commit comments

Comments
 (0)