Coverage for pybmc/data.py: 81% (178 statements)
coverage.py v7.10.0, created at 2025-10-14 21:12 +0000
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import os
import logging


class Dataset:
    """
    A general-purpose dataset class for loading and managing model data
    for Bayesian model combination workflows.

    Supports .h5 and .csv files, and provides data splitting functionality.
    """

    def __init__(self, data_source=None, verbose=True):
        """
        Initialize the Dataset object.

        :param data_source: Path to the data file (.h5 or .csv).
        :param verbose: If True, display warnings and informational messages. Default is True.
        """
        self.data_source = data_source
        self.data = {}  # Dictionary mapping property names to DataFrames
        self.verbose = verbose
        self.logger = logging.getLogger(__name__)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter('%(message)s'))
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO if verbose else logging.WARNING)
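
    # Construction sketch ("models.h5" is a hypothetical file path):
    #   ds = Dataset("models.h5")                       # INFO messages shown
    #   ds_quiet = Dataset("models.h5", verbose=False)  # warnings only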

    def load_data(self, models, keys=None, domain_keys=None, model_column='model', truth_column_name=None):
        """
        Load data for each property and return a dictionary of synchronized DataFrames.
        Each DataFrame has columns: domain_keys + one column per model for that property.

        Parameters:
            models (list): List of model names (HDF5 keys, or values used to filter the CSV).
            keys (list): List of property names to extract (each becomes a key in the output dict).
            domain_keys (list, optional): Columns that define the common domain (default ['N', 'Z']).
            model_column (str, optional): Name of the CSV column identifying which model each row
                belongs to. Only used for CSV files; ignored for HDF5 files.
            truth_column_name (str, optional): Name of the truth model. If provided, the truth data
                is left-joined to the common domain of the other models, so the truth data may
                cover a smaller domain than the models.

        Returns:
            dict: Dictionary where each key is a property name and each value is a DataFrame with
                columns: domain_keys + one column per model for that property. The DataFrames are
                synchronized to the intersection of the domains of all models. If truth_column_name
                is provided, the truth column is left-joined and may contain NaN values.

        Supports both .h5 and .csv files.
        """
        if domain_keys is None:
            domain_keys = ['N', 'Z']
        self.domain_keys = domain_keys

        if self.data_source is None:
            raise ValueError("Data source must be specified.")
        if not os.path.exists(self.data_source):
            raise FileNotFoundError(f"Data source '{self.data_source}' not found.")
        if keys is None:
            raise ValueError("You must specify which properties to extract via 'keys'.")

        result = {}

        for prop in keys:
            dfs = []
            truth_df = None
            skipped_models = []

            if self.data_source.endswith('.h5'):
                for model in models:
                    df = pd.read_hdf(self.data_source, key=model)
                    # Check required columns
                    missing_cols = [col for col in domain_keys + [prop] if col not in df.columns]
                    if missing_cols:
                        self.logger.info(f"[Skipped] Model '{model}' missing columns {missing_cols} for property '{prop}'.")
                        skipped_models.append(model)
                        continue
                    temp = df[domain_keys + [prop]].copy()
                    temp.rename(columns={prop: model}, inplace=True)  # type: ignore

                    # Store truth data separately if truth_column_name is provided
                    if truth_column_name and model == truth_column_name:
                        truth_df = temp
                    else:
                        dfs.append(temp)

            elif self.data_source.endswith('.csv'):
                df = pd.read_csv(self.data_source)
                for model in models:
                    if model_column not in df.columns:
                        raise ValueError(f"Expected column '{model_column}' not found in CSV.")
                    model_df = df[df[model_column] == model]
                    missing_cols = [col for col in domain_keys + [prop] if col not in model_df.columns]
                    if missing_cols:
                        self.logger.info(f"[Skipped] Model '{model}' missing columns {missing_cols} for property '{prop}'.")
                        skipped_models.append(model)
                        continue
                    temp = model_df[domain_keys + [prop]].copy()
                    temp.rename(columns={prop: model}, inplace=True)

                    # Store truth data separately if truth_column_name is provided
                    if truth_column_name and model == truth_column_name:
                        truth_df = temp
                    else:
                        dfs.append(temp)
            else:
                raise ValueError("Unsupported file format. Only .h5 and .csv are supported.")

            if not dfs:
                self.logger.info(f"[Warning] No models with property '{prop}'. Resulting DataFrame will be empty.")
                result[prop] = pd.DataFrame(columns=domain_keys + [m for m in models if m not in skipped_models])
                continue

            # Intersect the domains of the regular models
            common_df = dfs[0]
            for other_df in dfs[1:]:
                common_df = pd.merge(common_df, other_df, on=domain_keys, how="inner")
            # Drop rows with NaN in any required column (for regular models)
            common_df = common_df.dropna()

            # Left-join truth data if it exists and was specified
            if truth_df is not None:
                common_df = pd.merge(common_df, truth_df, on=domain_keys, how="left")

            result[prop] = common_df

        self.data = result
        return result
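
    # Usage sketch (file, model, and property names are hypothetical):
    #   ds = Dataset("models.h5")
    #   tables = ds.load_data(models=["modelA", "modelB", "exp"], keys=["BE"],
    #                         domain_keys=["N", "Z"], truth_column_name="exp")
    #   tables["BE"]  ->  columns: N, Z, modelA, modelB, exp (exp may contain NaN)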

    def view_data(self, property_name=None, model_name=None):
        """
        View data flexibly based on the input parameters.

        - No arguments: returns available property names and model names.
        - property_name only: returns the full DataFrame for that property.
        - model_name only: returns that model's values across all properties.
        - property_name + model_name: returns a Series of values for the model.

        :param property_name: Optional property name.
        :param model_name: Optional model name.
        :return: dict, DataFrame, or Series depending on input.
        """

        if not self.data:
            raise RuntimeError("No data loaded. Run `load_data(...)` first.")

        if property_name is None and model_name is None:
            props = list(self.data.keys())
            models = sorted(set(
                col for prop_df in self.data.values()
                for col in prop_df.columns if col not in self.domain_keys
            ))

            return {
                "available_properties": props,
                "available_models": models
            }

        if model_name is not None and property_name is None:
            # Return a dictionary: {property: DataFrame of domain keys + model values}
            result = {}
            for prop, df in self.data.items():
                if model_name in df.columns:
                    cols = self.domain_keys + [model_name]
                    result[prop] = df[cols]
                else:
                    result[prop] = f"[Model '{model_name}' not available]"
            return result

        if property_name is not None:
            if property_name not in self.data:
                raise KeyError(f"Property '{property_name}' not found.")

            df = self.data[property_name]

            if model_name is None:
                return df  # Full property DataFrame

            if model_name not in df.columns:
                raise KeyError(f"Model '{model_name}' not found in property '{property_name}'.")

            return df[model_name]
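
    # Usage sketch (property and model names are hypothetical):
    #   ds.view_data()                           -> {"available_properties": [...], "available_models": [...]}
    #   ds.view_data("BE")                       -> full DataFrame for "BE"
    #   ds.view_data("BE", model_name="modelA")  -> Series of modelA values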

    def separate_points_distance_allSets(self, list1, list2, distance1, distance2):
        """
        Separate the points in list1 into three groups based on their proximity
        to any point in list2.

        :param list1: List of (x, y) tuples to classify.
        :param list2: List of reference (x, y) tuples.
        :param distance1: Threshold distance for the training set.
        :param distance2: Threshold distance for the validation set.
        :return: Three lists of indices into list1 - train, validation, and test.
        """

        train = []
        validation = []
        test = []

        train_list_coordinates = []
        validation_list_coordinates = []
        test_list_coordinates = []

        for i in range(len(list1)):
            point1 = list1[i]
            close = False
            for point2 in list2:
                if np.linalg.norm(np.array(point1) - np.array(point2)) <= distance1:
                    close = True
                    break
            if close:
                train.append(point1)
                train_list_coordinates.append(i)
            else:
                close2 = False
                for point2 in list2:
                    if np.linalg.norm(np.array(point1) - np.array(point2)) <= distance2:
                        close2 = True
                        break
                if close2:
                    validation.append(point1)
                    validation_list_coordinates.append(i)
                else:
                    test.append(point1)
                    test_list_coordinates.append(i)

        return train_list_coordinates, validation_list_coordinates, test_list_coordinates
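
    # Toy sketch of the three-way split (points are hypothetical):
    #   pts  = [(0, 0), (3, 0), (10, 0)]
    #   refs = [(0, 0)]
    #   ds.separate_points_distance_allSets(pts, refs, 1.0, 5.0)
    #   -> ([0], [1], [2]): index 0 is within 1.0 of a reference (train),
    #      index 1 is within 5.0 (validation), index 2 is beyond both (test)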

    def split_data(self, data_dict, property_name, splitting_algorithm="random", **kwargs):
        """
        Split data into training, validation, and testing sets using random or
        inside-to-outside logic.

        :param data_dict: Dictionary output from `load_data`, where keys are property names
                          and values are DataFrames.
        :param property_name: The key in `data_dict` specifying which DataFrame to use for splitting.
        :param splitting_algorithm: 'random' (default) or 'inside_to_outside'.
        :param kwargs: Additional arguments depending on the chosen algorithm.
                       For 'random': train_size, val_size, test_size (fractions that must sum to 1.0).
                       For 'inside_to_outside': stable_points (list of (x, y)), distance1, distance2.
        :return: Tuple of train, validation, and test DataFrames.
        """
        if property_name not in data_dict:
            raise ValueError(f"Property '{property_name}' not found in the provided data dictionary.")

        data = data_dict[property_name]

        if isinstance(data, pd.DataFrame):
            indexable_data = data.reset_index(drop=True)
            point_list = list(indexable_data.itertuples(index=False, name=None))
        else:
            raise TypeError("Data for the specified property must be a pandas DataFrame.")

        if splitting_algorithm == "random":
            required = ['train_size', 'val_size', 'test_size']
            if not all(k in kwargs for k in required):
                raise ValueError(f"Missing required kwargs for 'random': {required}")

            train_size = kwargs['train_size']
            val_size = kwargs['val_size']
            test_size = kwargs['test_size']

            if not np.isclose(train_size + val_size + test_size, 1.0):
                raise ValueError("train_size + val_size + test_size must equal 1.0")

            # Random split using indexes: carve off the training set, then divide
            # the remainder between validation and test
            train_idx, temp_idx = train_test_split(indexable_data.index, train_size=train_size, random_state=1)
            val_rel = val_size / (val_size + test_size)
            val_idx, test_idx = train_test_split(temp_idx, test_size=1 - val_rel, random_state=1)

        elif splitting_algorithm == "inside_to_outside":
            required = ['stable_points', 'distance1', 'distance2']
            if not all(k in kwargs for k in required):
                raise ValueError(f"Missing required kwargs for 'inside_to_outside': {required}")

            stable_points = kwargs['stable_points']
            distance1 = kwargs['distance1']
            distance2 = kwargs['distance2']

            train_idx, val_idx, test_idx = self.separate_points_distance_allSets(
                point_list, stable_points, distance1, distance2
            )
        else:
            raise ValueError("splitting_algorithm must be either 'random' or 'inside_to_outside'")

        train_data = indexable_data.iloc[train_idx]
        val_data = indexable_data.iloc[val_idx]
        test_data = indexable_data.iloc[test_idx]

        return train_data, val_data, test_data
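
    # Usage sketch (property name "BE" and split sizes are hypothetical):
    #   train, val, test = ds.split_data(tables, "BE", splitting_algorithm="random",
    #                                    train_size=0.7, val_size=0.15, test_size=0.15)
    #   For 'inside_to_outside', pass stable_points=[(x, y), ...], distance1=..., distance2=...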

    def get_subset(self, property_name, filters=None, models_to_include=None):
        """
        Return a filtered, wide-format DataFrame for a given property.

        :param property_name: Name of the property (e.g., "BE", "ChRad").
        :param filters: Dictionary of filtering rules applied to the domain columns
                        (e.g., {"Z": (26, 28)}).
        :param models_to_include: Optional list of model names to retain in the output.
                                  If None, all model columns are retained.
        :return: Filtered wide-format DataFrame with columns: domain keys + model columns.
        """
        if property_name not in self.data:
            raise ValueError(f"Property '{property_name}' not found in dataset.")

        df = self.data[property_name].copy()

        # Apply row-level filters (domain-based)
        if filters:
            for column, condition in filters.items():
                if column == 'multi' and callable(condition):
                    # Row-wise predicate evaluated across all columns
                    df = df[df.apply(condition, axis=1)]
                elif callable(condition):
                    # Vectorized predicate on a single column
                    df = df[condition(df[column])]
                elif isinstance(condition, tuple) and len(condition) == 2:
                    # Inclusive (low, high) range
                    df = df[(df[column] >= condition[0]) & (df[column] <= condition[1])]
                elif isinstance(condition, list):
                    # Membership test
                    df = df[df[column].isin(condition)]
                else:
                    # Exact match
                    df = df[df[column] == condition]

        # Optionally restrict to a subset of models, keeping the domain columns
        if models_to_include is not None:
            domain_keys = [col for col in self.domain_keys if col in df.columns]
            allowed_cols = domain_keys + [m for m in models_to_include if m in df.columns]
            df = df[allowed_cols]

        return df
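
    # Usage sketch (column and model names are hypothetical):
    #   sub = ds.get_subset("BE", filters={"Z": (26, 28), "N": [28, 30, 32]},
    #                       models_to_include=["modelA"])
    #   -> rows with 26 <= Z <= 28 and N in {28, 30, 32}; columns: N, Z, modelA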