Coverage for pybmc/data.py: 76%
162 statements
coverage.py v7.10.0, created at 2025-07-27 15:48 +0000
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split  # type: ignore
import os


class Dataset:
8 """
9 Manages datasets for Bayesian model combination workflows.
11 Supports loading data from HDF5 and CSV files, splitting data, and filtering.
13 Attributes:
14 data_source (str): Path to data file.
15 data (dict[str, pandas.DataFrame]): Dictionary of loaded data by property.
16 domain_keys (list[str]): Domain columns used for data alignment.
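
    Example:
        A hypothetical workflow (the file, model, and column names are
        illustrative, not shipped with the module):

        >>> ds = Dataset('data.h5')
        >>> data = ds.load_data(models=['model1', 'model2'],
        ...                     keys=['BE'], domain_keys=['Z', 'N'])
        >>> data['BE'].head()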
17 """
    def __init__(self, data_source=None):
        """
        Initializes the Dataset instance.

        Args:
            data_source (str, optional): Path to data file (.h5 or .csv).
        """
        self.data_source = data_source
        self.data = {}  # Dictionary mapping property names to DataFrames
        self.domain_keys = ["X1", "X2"]  # Default domain keys

    def load_data(self, models, keys=None, domain_keys=None, model_column="model"):
        """
        Loads data for multiple properties and models.

        Args:
            models (list[str]): Model names to load.
            keys (list[str]): Property names to extract.
            domain_keys (list[str], optional): Domain columns. Defaults to the
                instance's current `domain_keys`.
            model_column (str, optional): CSV column identifying models (default: 'model').

        Returns:
            dict[str, pandas.DataFrame]: Dictionary of DataFrames keyed by property name.

        Raises:
            ValueError: If `data_source` is not specified or `keys` is missing.
            FileNotFoundError: If `data_source` doesn't exist.

        Example:
            >>> dataset = Dataset('data.h5')
            >>> data = dataset.load_data(
            ...     models=['model1', 'model2'],
            ...     keys=['BE', 'Rad'],
            ...     domain_keys=['Z', 'N']
            ... )
        """
        if domain_keys is None:
            domain_keys = self.domain_keys  # Fall back to the instance default
        self.domain_keys = domain_keys

        if self.data_source is None:
            raise ValueError("Data source must be specified.")
        if not os.path.exists(self.data_source):
            raise FileNotFoundError(f"Data source '{self.data_source}' not found.")
        if keys is None:
            raise ValueError("You must specify which properties to extract via 'keys'.")

        result = {}

        for prop in keys:
            dfs = []
            skipped_models = []

            if self.data_source.endswith(".h5"):
                for model in models:
                    df = pd.read_hdf(self.data_source, key=model)
                    # Skip models missing the required domain or property columns
                    missing_cols = [
                        col for col in domain_keys + [prop] if col not in df.columns
                    ]
                    if missing_cols:
                        print(
                            f"[Skipped] Model '{model}' missing columns {missing_cols} for property '{prop}'."
                        )
                        skipped_models.append(model)
                        continue
                    temp = df[domain_keys + [prop]].copy()
                    temp.rename(columns={prop: model}, inplace=True)  # type: ignore
                    dfs.append(temp)
            elif self.data_source.endswith(".csv"):
                df = pd.read_csv(self.data_source)
                if model_column not in df.columns:
                    raise ValueError(
                        f"Expected column '{model_column}' not found in CSV."
                    )
                for model in models:
                    model_df = df[df[model_column] == model]
                    missing_cols = [
                        col
                        for col in domain_keys + [prop]
                        if col not in model_df.columns
                    ]
                    if missing_cols:
                        print(
                            f"[Skipped] Model '{model}' missing columns {missing_cols} for property '{prop}'."
                        )
                        skipped_models.append(model)
                        continue
                    temp = model_df[domain_keys + [prop]].copy()
                    temp.rename(columns={prop: model}, inplace=True)
                    dfs.append(temp)
            else:
                raise ValueError(
                    "Unsupported file format. Only .h5 and .csv are supported."
                )

            if not dfs:
                print(
                    f"[Warning] No models with property '{prop}'. Resulting DataFrame will be empty."
                )
                result[prop] = pd.DataFrame(
                    columns=domain_keys + [m for m in models if m not in skipped_models]
                )
                continue

            # Intersect the domain across models for this property
            common_df = dfs[0]
            for other_df in dfs[1:]:
                common_df = pd.merge(common_df, other_df, on=domain_keys, how="inner")

            result[prop] = common_df

        self.data = result
        return result

    def view_data(self, property_name=None, model_name=None):
        """
        Provides flexible data viewing options.

        Args:
            property_name (str, optional): Specific property to view.
            model_name (str, optional): Specific model to view.

        Returns:
            Union[dict[str, Union[pandas.DataFrame, str]], pandas.DataFrame, pandas.Series]:
                - If no args: dict of available properties/models.
                - If only `model_name`: dict mapping each property to a DataFrame
                  (domain columns plus the model column), or a placeholder string
                  if the model is absent for that property.
                - If only `property_name`: full DataFrame for that property.
                - If both: Series of the model's values for that property.

        Raises:
            RuntimeError: If no data loaded.
            KeyError: If property or model not found.
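
        Example:
            A hypothetical session (the file, property, and model names are
            illustrative):

            >>> ds = Dataset('data.h5')
            >>> _ = ds.load_data(models=['model1', 'model2'],
            ...                  keys=['BE'], domain_keys=['Z', 'N'])
            >>> ds.view_data()
            {'available_properties': ['BE'], 'available_models': ['model1', 'model2']}
            >>> ds.view_data(property_name='BE', model_name='model1')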
149 """

        if not self.data:
            raise RuntimeError("No data loaded. Run `load_data(...)` first.")

        if property_name is None and model_name is None:
            props = list(self.data.keys())
            models = sorted(
                set(
                    col
                    for prop_df in self.data.values()
                    for col in prop_df.columns
                    if col not in self.domain_keys
                )
            )
            return {"available_properties": props, "available_models": models}

        if model_name is not None and property_name is None:
            # Return a dictionary: {property: DataFrame of domain columns + the model's values}
            result = {}
            for prop, df in self.data.items():
                if model_name in df.columns:
                    cols = self.domain_keys + [model_name]
                    result[prop] = df[cols]
                else:
                    result[prop] = f"[Model '{model_name}' not available]"
            return result

        if property_name is not None:
            if property_name not in self.data:
                raise KeyError(f"Property '{property_name}' not found.")

            df = self.data[property_name]

            if model_name is None:
                return df  # Full property DataFrame

            if model_name not in df.columns:
                raise KeyError(
                    f"Model '{model_name}' not found in property '{property_name}'."
                )

            return df[model_name]

    def separate_points_distance_allSets(self, list1, list2, distance1, distance2):
        """
        Separates points into groups based on proximity thresholds.

        Args:
            list1 (list[tuple[float, float]]): Points to classify as (x, y) tuples.
            list2 (list[tuple[float, float]]): Reference points as (x, y) tuples.
            distance1 (float): First proximity threshold.
            distance2 (float): Second proximity threshold.

        Returns:
            tuple[list[int], list[int], list[int]]: Three lists of indices from `list1`:
                - Within `distance1` of any point in `list2`.
                - Within `distance2` but not `distance1`.
                - Beyond `distance2`.
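
        Example:
            A small self-contained case (points and thresholds are made up):

            >>> ds = Dataset()
            >>> near, mid, far = ds.separate_points_distance_allSets(
            ...     [(0, 0), (3, 0), (10, 0)], [(0, 0)], 1.0, 5.0
            ... )
            >>> near, mid, far
            ([0], [1], [2])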
209 """
        train_list_coordinates = []
        validation_list_coordinates = []
        test_list_coordinates = []

        for i, point1 in enumerate(list1):
            p1 = np.array(point1)
            # Within distance1 of any reference point -> first group
            if any(np.linalg.norm(p1 - np.array(p2)) <= distance1 for p2 in list2):
                train_list_coordinates.append(i)
            # Otherwise, within distance2 of any reference point -> second group
            elif any(np.linalg.norm(p1 - np.array(p2)) <= distance2 for p2 in list2):
                validation_list_coordinates.append(i)
            else:
                test_list_coordinates.append(i)

        return (
            train_list_coordinates,
            validation_list_coordinates,
            test_list_coordinates,
        )

    def split_data(
        self, data_dict, property_name, splitting_algorithm="random", **kwargs
    ):
        """
        Splits data into training, validation, and test sets.

        Args:
            data_dict (dict[str, pandas.DataFrame]): Output from `load_data()`.
            property_name (str): Property to use for splitting.
            splitting_algorithm (str): 'random' or 'inside_to_outside'.
            **kwargs: Algorithm-specific parameters:
                - `random`: `train_size` (float), `val_size` (float), `test_size` (float).
                - `inside_to_outside`: `stable_points` (list[tuple[float, float]]), `distance1` (float), `distance2` (float).

        Returns:
            tuple[pandas.DataFrame, pandas.DataFrame, pandas.DataFrame]: (train, validation, test) DataFrames.

        Raises:
            ValueError: For an invalid algorithm or missing parameters.
            TypeError: If the selected property is not a pandas DataFrame.
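
        Example:
            A sketch of a 60/20/20 random split (the file, names, and sizes
            are illustrative):

            >>> ds = Dataset('data.h5')
            >>> data = ds.load_data(models=['model1', 'model2'],
            ...                     keys=['BE'], domain_keys=['Z', 'N'])
            >>> train, val, test = ds.split_data(
            ...     data, 'BE', splitting_algorithm='random',
            ...     train_size=0.6, val_size=0.2, test_size=0.2
            ... )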
266 """
        if property_name not in data_dict:
            raise ValueError(
                f"Property '{property_name}' not found in the provided data dictionary."
            )

        data = data_dict[property_name]

        if isinstance(data, pd.DataFrame):
            indexable_data = data.reset_index(drop=True)
            # Compare only the domain coordinates against `stable_points`
            point_list = list(
                indexable_data[self.domain_keys].itertuples(index=False, name=None)
            )
        else:
            raise TypeError(
                "Data for the specified property must be a pandas DataFrame."
            )

        if splitting_algorithm == "random":
            required = ["train_size", "val_size", "test_size"]
            if not all(k in kwargs for k in required):
                raise ValueError(f"Missing required kwargs for 'random': {required}")

            train_size = kwargs["train_size"]
            val_size = kwargs["val_size"]
            test_size = kwargs["test_size"]

            if not np.isclose(train_size + val_size + test_size, 1.0):
                raise ValueError("train_size + val_size + test_size must equal 1.0")

            # Random split using indexes
            train_idx, temp_idx = train_test_split(
                indexable_data.index, train_size=train_size, random_state=1
            )
            val_rel = val_size / (val_size + test_size)
            val_idx, test_idx = train_test_split(
                temp_idx, test_size=1 - val_rel, random_state=1
            )

        elif splitting_algorithm == "inside_to_outside":
            required = ["stable_points", "distance1", "distance2"]
            if not all(k in kwargs for k in required):
                raise ValueError(
                    f"Missing required kwargs for 'inside_to_outside': {required}"
                )

            stable_points = kwargs["stable_points"]
            distance1 = kwargs["distance1"]
            distance2 = kwargs["distance2"]

            train_idx, val_idx, test_idx = self.separate_points_distance_allSets(
                point_list, stable_points, distance1, distance2
            )
        else:
            raise ValueError(
                "splitting_algorithm must be either 'random' or 'inside_to_outside'"
            )

        train_data = indexable_data.iloc[train_idx]
        val_data = indexable_data.iloc[val_idx]
        test_data = indexable_data.iloc[test_idx]

        return train_data, val_data, test_data

    def get_subset(self, property_name, filters=None, models_to_include=None):
        """
        Returns a filtered subset of data for a property.

        Args:
            property_name (str): Property to filter.
            filters (dict, optional): Column-to-condition filtering rules. A
                condition may be a callable (applied row-wise when the key is
                'multi', element-wise otherwise), a (min, max) tuple, a list of
                allowed values, or a scalar to match exactly.
            models_to_include (list[str], optional): Models to include.

        Returns:
            pandas.DataFrame: Filtered DataFrame.

        Raises:
            ValueError: If property not found.
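
        Example:
            An illustrative filter (property name, column names, and bounds
            are hypothetical):

            >>> ds.get_subset(
            ...     'BE',
            ...     filters={'Z': (20, 28), 'N': [20, 28, 50]},
            ...     models_to_include=['model1'],
            ... )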
346 """
        if property_name not in self.data:
            raise ValueError(f"Property '{property_name}' not found in dataset.")

        df = self.data[property_name].copy()

        # Apply row-level filters (domain-based)
        if filters:
            for column, condition in filters.items():
                if column == "multi" and callable(condition):
                    df = df[df.apply(condition, axis=1)]
                elif callable(condition):
                    df = df[condition(df[column])]
                elif isinstance(condition, tuple) and len(condition) == 2:
                    df = df[(df[column] >= condition[0]) & (df[column] <= condition[1])]
                elif isinstance(condition, list):
                    df = df[df[column].isin(condition)]
                else:
                    df = df[df[column] == condition]

        # Optionally restrict to a subset of models
        if models_to_include is not None:
            domain_cols = [col for col in self.domain_keys if col in df.columns]
            allowed_cols = domain_cols + [
                m for m in models_to_include if m in df.columns
            ]
            df = df[allowed_cols]

        return df