# data_loader.py
"""
@author: Mohamed Elrefaie, [email protected]

Data loading utilities for the DrivAerNet++ dataset.

This module provides functionality for loading and preprocessing point cloud data
with pressure field information from the DrivAerNet++ dataset.
"""

import logging
import os

import numpy as np
import pyvista as pv
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset, Subset


class SurfacePressureDataset(Dataset):
    """
    Dataset class for loading and preprocessing surface pressure data from DrivAerNet++ VTK files.

    This dataset handles loading surface meshes with pressure field data,
    sampling points, and caching processed data for faster loading.
    """

    def __init__(self, root_dir: str, num_points: int, preprocess=False, cache_dir=None):
        """
        Initializes the SurfacePressureDataset instance.

        Args:
            root_dir: Directory containing the VTK files for the car surface meshes.
            num_points: Fixed number of points to sample from each 3D model.
            preprocess: Flag to indicate if preprocessing should occur or not.
            cache_dir: Directory where the preprocessed files (NPZ) are stored.
        """
        self.root_dir = root_dir
        self.vtk_files = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.vtk')]
        self.num_points = num_points
        self.preprocess = preprocess
        self.cache_dir = cache_dir if cache_dir else os.path.join(root_dir, "processed_data")

        if not os.path.exists(self.cache_dir):
            os.makedirs(self.cache_dir)

    def __len__(self):
        return len(self.vtk_files)

    def _get_cache_path(self, vtk_file_path):
        """Get the corresponding .npz file path for a given .vtk file."""
        base_name = os.path.basename(vtk_file_path).replace('.vtk', '.npz')
        return os.path.join(self.cache_dir, base_name)

    def _save_to_cache(self, cache_path, point_cloud, pressures):
        """Save preprocessed point cloud and pressure data into an .npz file."""
        np.savez_compressed(cache_path, points=point_cloud.points, pressures=pressures)

    def _load_from_cache(self, cache_path):
        """Load preprocessed point cloud and pressure data from an .npz file."""
        data = np.load(cache_path)
        point_cloud = pv.PolyData(data['points'])
        pressures = data['pressures']
        return point_cloud, pressures

    def sample_point_cloud_with_pressure(self, mesh, n_points=5000):
        """
        Sample n_points from the surface mesh and get the corresponding pressure values.

        Args:
            mesh: PyVista mesh object with pressure data stored in point_data.
            n_points: Number of points to sample.

        Returns:
            A tuple containing the sampled point cloud and corresponding pressures.
        """
        if mesh.n_points > n_points:
            indices = np.random.choice(mesh.n_points, n_points, replace=False)
        else:
            indices = np.arange(mesh.n_points)
            logging.info(f"Mesh has only {mesh.n_points} points. Using all available points.")

        sampled_points = mesh.points[indices]
        sampled_pressures = mesh.point_data['p'][indices]  # Assuming pressure data is stored under key 'p'
        sampled_pressures = sampled_pressures.flatten()  # Ensure it's a flat array

        return pv.PolyData(sampled_points), sampled_pressures

    def __getitem__(self, idx):
        vtk_file_path = self.vtk_files[idx]
        cache_path = self._get_cache_path(vtk_file_path)

        # Check if the data is already cached
        if os.path.exists(cache_path):
            logging.info(f"Loading cached data from {cache_path}")
            point_cloud, pressures = self._load_from_cache(cache_path)
        else:
            if self.preprocess:
                logging.info(f"Preprocessing and caching data for {vtk_file_path}")
                try:
                    mesh = pv.read(vtk_file_path)
                except Exception as e:
                    logging.error(f"Failed to load VTK file: {vtk_file_path}. Error: {e}")
                    return None, None  # Skip the file and return None

                point_cloud, pressures = self.sample_point_cloud_with_pressure(mesh, self.num_points)

                # Cache the sampled data to a new file
                self._save_to_cache(cache_path, point_cloud, pressures)
            else:
                logging.error(f"Cache file not found for {vtk_file_path} and preprocessing is disabled.")
                return None, None  # Return None if preprocessing is disabled and the cache doesn't exist

        point_cloud_np = np.array(point_cloud.points)
        point_cloud_tensor = torch.tensor(point_cloud_np.T[np.newaxis, :, :], dtype=torch.float32)
        pressures_tensor = torch.tensor(pressures[np.newaxis, :], dtype=torch.float32)

        return point_cloud_tensor, pressures_tensor
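

# Example usage (a minimal sketch, not part of the original interface; the path
# below is a placeholder). __getitem__ returns a (1, 3, num_points) float tensor
# of xyz coordinates and a (1, num_points) float tensor of pressures:
#
#     dataset = SurfacePressureDataset(root_dir='path/to/vtk_files',
#                                      num_points=5000, preprocess=True)
#     points, pressures = dataset[0]  # points: (1, 3, 5000), pressures: (1, 5000)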


def create_subset(dataset, ids_file):
    """
    Create a subset of the dataset based on design IDs from a file.

    Args:
        dataset: The full dataset.
        ids_file: Path to a file containing design IDs, one per line.

    Returns:
        A Subset of the dataset containing only the specified designs.
    """
    try:
        with open(ids_file, 'r') as file:
            subset_ids = [id_.strip() for id_ in file.readlines()]
        subset_files = [f for f in dataset.vtk_files if any(id_ in f for id_ in subset_ids)]
        subset_indices = [dataset.vtk_files.index(f) for f in subset_files]
        if not subset_indices:
            logging.error(f"No matching VTK files found for IDs in {ids_file}.")
        return Subset(dataset, subset_indices)
    except FileNotFoundError as e:
        logging.error(f"Error loading subset file {ids_file}: {e}")
        return None
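

# The split file is expected to contain one design ID per line; any VTK file
# whose name contains an ID is assigned to that split. A hypothetical call:
#
#     train_subset = create_subset(full_dataset, 'splits/train_design_ids.txt')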


def get_dataloaders(dataset_path: str, subset_dir: str, num_points: int, batch_size: int,
                    world_size: int, rank: int, cache_dir: str = None, num_workers: int = 4) -> tuple:
    """
    Prepare and return the training, validation, and test DataLoader objects.

    Args:
        dataset_path: Path to the directory containing VTK files.
        subset_dir: Directory containing train/val/test split files.
        num_points: Number of points to sample from each mesh.
        batch_size: Batch size for the dataloaders.
        world_size: Total number of processes for distributed training.
        rank: Current process rank.
        cache_dir: Directory to store processed data.
        num_workers: Number of workers for data loading.

    Returns:
        A tuple of (train_dataloader, val_dataloader, test_dataloader).
    """
    full_dataset = SurfacePressureDataset(
        root_dir=dataset_path,
        num_points=num_points,
        preprocess=True,
        cache_dir=cache_dir
    )

    train_dataset = create_subset(full_dataset, os.path.join(subset_dir, 'train_design_ids.txt'))
    val_dataset = create_subset(full_dataset, os.path.join(subset_dir, 'val_design_ids.txt'))
    test_dataset = create_subset(full_dataset, os.path.join(subset_dir, 'test_design_ids.txt'))

    # Distributed samplers for DDP
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=world_size, rank=rank
    )
    val_sampler = torch.utils.data.distributed.DistributedSampler(
        val_dataset, num_replicas=world_size, rank=rank
    )
    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_dataset, num_replicas=world_size, rank=rank
    )

    train_dataloader = DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        drop_last=True, num_workers=num_workers
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=batch_size, sampler=val_sampler,
        drop_last=True, num_workers=num_workers
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=batch_size, sampler=test_sampler,
        drop_last=True, num_workers=num_workers
    )

    return train_dataloader, val_dataloader, test_dataloader
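

# Example usage inside a DDP worker (a minimal sketch; the process-group setup,
# paths, and hyperparameters are assumptions, not part of this module):
#
#     dist.init_process_group("nccl", rank=rank, world_size=world_size)
#     train_dl, val_dl, test_dl = get_dataloaders(
#         dataset_path='path/to/vtk_files', subset_dir='path/to/splits',
#         num_points=5000, batch_size=8, world_size=world_size, rank=rank)
#     for epoch in range(num_epochs):
#         train_dl.sampler.set_epoch(epoch)  # reshuffle shards across replicas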


# Constants for normalization
PRESSURE_MEAN = -94.5
PRESSURE_STD = 117.25
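

# The constants above are dataset-level statistics, presumably used to
# standardize the pressure targets. A minimal sketch (these helper names are
# illustrative, not part of the original interface):

def normalize_pressures(pressures: torch.Tensor) -> torch.Tensor:
    """Standardize raw pressures to zero mean and unit variance."""
    return (pressures - PRESSURE_MEAN) / PRESSURE_STD


def denormalize_pressures(pressures: torch.Tensor) -> torch.Tensor:
    """Map standardized predictions back to physical pressure values."""
    return pressures * PRESSURE_STD + PRESSURE_MEAN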