Coverage for pybmc/bmc.py: 85%
119 statements
« prev ^ index » next coverage.py v7.10.0, created at 2025-07-27 15:48 +0000
1import numpy as np
2import pandas as pd
3from .inference_utils import (
4 gibbs_sampler,
5 gibbs_sampler_simplex,
6 USVt_hat_extraction,
7)
8from .sampling_utils import coverage, rndm_m_random_calculator
class BayesianModelCombination:
    """
    Implements Bayesian Model Combination (BMC) for aggregating predictions from multiple models.

    This class performs orthogonalization (SVD) of model predictions, trains the model
    combination using Gibbs sampling, and provides methods for prediction and evaluation.
    Call order: ``orthogonalize()`` -> ``train()`` -> ``predict()`` / ``predict2()`` /
    ``evaluate()``.

    Args:
        models_list (list[str]): List of model names to combine.
        data_dict (dict[str, pandas.DataFrame]): Dictionary from `load_data()` where keys are
            property names and values are DataFrames.
        truth_column_name (str): Name of the column containing ground truth values.
        weights (list[float], optional): Initial weights for models. Stored on the instance
            but not read by any method of this class.

    Attributes:
        models_list (list[str]): Model names exactly as passed by the caller.
        models (list[str]): Model columns actually combined: ``models_list`` without the
            truth column (and without a literal ``"truth"`` entry).
        data_dict (dict[str, pandas.DataFrame]): Loaded data dictionary.
        truth_column_name (str): Ground truth column name.
        weights (list[float] | None): Weights passed at construction (unused).
        samples (numpy.ndarray | None): Posterior samples from `train()`; None until trained.
        current_property (str | None): Property passed to the last `orthogonalize()` call.
        selected_models_dataset (pandas.DataFrame | None): Copy of
            ``data_dict[current_property]`` taken by `orthogonalize()`.
        centered_experiment_train (numpy.ndarray | None): Per training row, truth minus the
            mean model prediction.
        U_hat (numpy.ndarray | None): Retained left singular vectors from SVD.
        S_hat (numpy.ndarray | None): Retained singular values.
        Vt_hat (numpy.ndarray | None): Retained right singular vectors as returned by
            `USVt_hat_extraction` (used for training and prediction).
        Vt_hat_normalized (numpy.ndarray | None): Second right-singular-vector array returned
            by `USVt_hat_extraction`; see that helper for how it differs from ``Vt_hat``.
        _predictions_mean_train (numpy.ndarray | None): Mean prediction across models per
            training row.

    Example:
        >>> bmc = BayesianModelCombination(
        ...     models_list=["model1", "model2"],
        ...     data_dict=data,
        ...     truth_column_name="truth",
        ... )
    """

    def __init__(self, models_list, data_dict, truth_column_name, weights=None):
        """
        Initializes the BMC instance.

        Args:
            models_list (list[str]): List of model names to combine.
            data_dict (dict[str, pandas.DataFrame]): Dictionary of DataFrames from
                Dataset.load_data().
            truth_column_name (str): Name of column containing ground truth values.
            weights (list[float], optional): Initial model weights. Defaults to None.

        Raises:
            ValueError: If `models_list` is not a list of strings or `data_dict` is not a
                dict whose values are all DataFrames.
        """
        if not isinstance(models_list, list) or not all(
            isinstance(model, str) for model in models_list
        ):
            raise ValueError(
                "The 'models' should be a list of model names (strings) for Bayesian Combination."
            )
        if not isinstance(data_dict, dict) or not all(
            isinstance(df, pd.DataFrame) for df in data_dict.values()
        ):
            raise ValueError(
                "The 'data_dict' should be a dictionary of pandas DataFrames, one per property."
            )

        self.data_dict = data_dict
        self.models_list = models_list
        # The ground-truth column is never a model to combine. A literal "truth"
        # entry is also dropped, preserving the previous hard-coded behaviour.
        excluded = ("truth", truth_column_name)
        self.models = [m for m in models_list if m not in excluded]
        self.weights = weights
        self.truth_column_name = truth_column_name

        # State filled in by orthogonalize() / train(). Initialised here so the
        # "not fitted yet" checks raise ValueError rather than AttributeError.
        self.current_property = None
        self.selected_models_dataset = None
        self.centered_experiment_train = None
        self.U_hat = None
        self.S_hat = None
        self.Vt_hat = None
        self.Vt_hat_normalized = None
        self._predictions_mean_train = None
        self.samples = None

    def _check_fitted(self):
        """Raise ValueError unless both `orthogonalize()` and `train()` have been run."""
        # getattr keeps this safe for instances whose __init__ was bypassed/overridden.
        if getattr(self, "samples", None) is None or getattr(self, "Vt_hat", None) is None:
            raise ValueError(
                "Must call `orthogonalize()` and `train()` before predicting."
            )

    @staticmethod
    def _build_interval_frames(domain_df, lower, median, upper):
        """Return (lower_df, median_df, upper_df): copies of `domain_df` with one prediction column each."""
        lower_df = domain_df.copy()
        lower_df["Predicted_Lower"] = lower

        median_df = domain_df.copy()
        median_df["Predicted_Median"] = median

        upper_df = domain_df.copy()
        upper_df["Predicted_Upper"] = upper

        return lower_df, median_df, upper_df

    def orthogonalize(self, property, train_df, components_kept):
        """
        Performs orthogonalization of model predictions using SVD.

        Centers each training row's model predictions on their row mean, performs an SVD
        of the centered matrix, and keeps `components_kept` components (via
        `USVt_hat_extraction`) for subsequent training.

        Args:
            property (str): Nuclear property to orthogonalize (e.g., 'BE'); must be a key
                of `data_dict`.
            train_df (pandas.DataFrame): Training data from Dataset.split_data(); must contain
                every column in `self.models` plus the truth column.
            components_kept (int): Number of SVD components to retain.

        Note:
            This method must be called before training. Results are stored in instance
            attributes.
        """
        self.current_property = property

        # Copy of the full property table, kept for later use (KeyError if unknown).
        self.selected_models_dataset = self.data_dict[property].copy()

        # (n_rows, n_models) matrix of model predictions on the training rows.
        model_predictions_train = train_df[self.models].values

        # Mean prediction across models, one value per training row.
        predictions_mean_train = np.mean(model_predictions_train, axis=1)

        # Residual of the truth with respect to the model mean: the fit target.
        centered_experiment_train = (
            train_df[self.truth_column_name].values - predictions_mean_train
        )

        # Center each row's model predictions on that row's mean.
        model_predictions_train_centered = (
            model_predictions_train - predictions_mean_train[:, None]
        )

        # Thin SVD: the full (n_rows x n_rows) U from full_matrices=True is never
        # needed, because only the leading components are retained below.
        # NOTE(review): assumes USVt_hat_extraction reads only the first
        # `components_kept` singular triplets — confirm in inference_utils.
        U, S, Vt = np.linalg.svd(model_predictions_train_centered, full_matrices=False)

        U_hat, S_hat, Vt_hat, Vt_hat_normalized = USVt_hat_extraction(U, S, Vt, components_kept)  # type: ignore

        # Save for training / prediction.
        self.centered_experiment_train = centered_experiment_train
        self.U_hat = U_hat
        self.Vt_hat = Vt_hat
        self.S_hat = S_hat
        self.Vt_hat_normalized = Vt_hat_normalized
        self._predictions_mean_train = predictions_mean_train

    def train(self, training_options=None):
        """
        Trains the model combination using Gibbs sampling.

        Args:
            training_options (dict, optional): Training configuration. Options:
                - iterations (int): Number of Gibbs iterations (default: 50000).
                - sampler (str): 'gibbs_sampling' or 'simplex' (default: 'gibbs_sampling');
                  any value other than 'simplex' selects the plain Gibbs sampler.
                - burn (int): Burn-in iterations, simplex sampler only (default: 10000).
                - stepsize (float): Proposal step size, simplex sampler only (default: 0.001).
                - b_mean_prior (numpy.ndarray): Prior mean vector, Gibbs sampler only
                  (default: zeros).
                - b_mean_cov (numpy.ndarray): Prior covariance matrix, Gibbs sampler only
                  (default: diag(S_hat²)).
                - nu0_chosen (float): Degrees of freedom for variance prior (default: 1.0).
                - sigma20_chosen (float): Prior variance (default: 0.02).
                Every option the chosen sampler uses but the caller omitted is reported
                with an "[INFO]" print.

        Raises:
            ValueError: If `orthogonalize()` has not been called.

        Note:
            Stores posterior samples in `self.samples`.
        """
        if getattr(self, "U_hat", None) is None or getattr(self, "S_hat", None) is None:
            raise ValueError("Must call `orthogonalize()` before training.")

        if training_options is None:
            training_options = {}

        # Report every default that is actually used, so runs are reproducible.
        def get_option(key, default):
            if key not in training_options:
                print(f"[INFO] Using default value for '{key}': {default}")
            return training_options.get(key, default)

        iterations = get_option("iterations", 50000)
        sampler = get_option("sampler", "gibbs_sampling")
        nu0_chosen = get_option("nu0_chosen", 1.0)
        sigma20_chosen = get_option("sigma20_chosen", 0.02)

        if sampler == "simplex":
            # The simplex sampler takes no b_mean_prior / b_mean_cov.
            burn = get_option("burn", 10000)
            stepsize = get_option("stepsize", 0.001)
            self.samples = gibbs_sampler_simplex(
                self.centered_experiment_train,
                self.U_hat,
                self.Vt_hat,
                self.S_hat,
                iterations,
                [nu0_chosen, sigma20_chosen],
                burn=burn,
                stepsize=stepsize,
            )
        else:
            num_components = self.U_hat.shape[1]
            b_mean_prior = get_option("b_mean_prior", np.zeros(num_components))
            b_mean_cov = get_option("b_mean_cov", np.diag(self.S_hat**2))
            self.samples = gibbs_sampler(
                self.centered_experiment_train,
                self.U_hat,
                iterations,
                [b_mean_prior, b_mean_cov, nu0_chosen, sigma20_chosen],
            )

    def predict(self, X):
        """
        Predicts values using the trained model combination with uncertainty quantification.

        Args:
            X (pandas.DataFrame): Input data containing a column for every model in
                `self.models`; all other columns (including a truth column, if present) are
                treated as domain keys and copied into the outputs.

        Returns:
            tuple[numpy.ndarray, pandas.DataFrame, pandas.DataFrame, pandas.DataFrame]: Contains:
                - rndm_m (numpy.ndarray): Full posterior draws (n_samples, n_points).
                - lower_df (pandas.DataFrame): Lower bounds (2.5th percentile) with domain keys.
                - median_df (pandas.DataFrame): Median predictions with domain keys.
                - upper_df (pandas.DataFrame): Upper bounds (97.5th percentile) with domain keys.

        Raises:
            ValueError: If `orthogonalize()` and `train()` haven't been called, or if X is not
                a DataFrame.
        """
        self._check_fitted()

        if not isinstance(X, pd.DataFrame):
            raise ValueError(
                "X must be a pandas DataFrame containing model predictions and domain info."
            )

        model_cols = self.models
        domain_keys = [col for col in X.columns if col not in model_cols]

        model_preds = X[model_cols].values
        rndm_m, (lower, median, upper) = rndm_m_random_calculator(
            model_preds, self.samples, self.Vt_hat
        )

        domain_df = X[domain_keys].reset_index(drop=True)
        lower_df, median_df, upper_df = self._build_interval_frames(
            domain_df, lower, median, upper
        )
        return rndm_m, lower_df, median_df, upper_df

    def predict2(self, property):
        """
        Predicts values for a specific property using the trained model combination.

        This version reads the property's DataFrame from `data_dict` instead of taking a
        DataFrame argument. Models missing from that DataFrame are dropped (with a printed
        warning) by selecting only the matching columns of `Vt_hat`.

        Args:
            property (str): Property name to predict (e.g., 'ChRad').

        Returns:
            tuple[numpy.ndarray, pandas.DataFrame, pandas.DataFrame, pandas.DataFrame]: Contains:
                - rndm_m (numpy.ndarray): Full posterior draws (n_samples, n_points).
                - lower_df (pandas.DataFrame): Lower bounds (2.5th percentile) with domain keys.
                - median_df (pandas.DataFrame): Median predictions with domain keys.
                - upper_df (pandas.DataFrame): Upper bounds (97.5th percentile) with domain keys.

        Raises:
            ValueError: If `orthogonalize()` and `train()` haven't been called, or no trained
                model is present in the property's DataFrame.
            KeyError: If property not found in `data_dict`.
        """
        self._check_fitted()

        if property not in self.data_dict:
            raise KeyError(f"Property '{property}' not found in data_dict.")

        df = self.data_dict[property].copy()

        # Domain keys: everything that is neither a trained model nor the truth column.
        full_model_cols = self.models
        domain_keys = [
            col
            for col in df.columns
            if col not in full_model_cols and col != self.truth_column_name
        ]

        # Trained models present in this property's table, in the table's column order.
        available_models = [m for m in df.columns if m in self.models]

        trained_models_set = set(self.models)
        available_models_set = set(available_models)

        missing_models = trained_models_set - available_models_set
        # NOTE(review): available_models is already filtered by self.models, so this
        # set is always empty and the check below cannot fire as written.
        extra_models = available_models_set - trained_models_set
        print(f"Available models: {available_models_set}")
        print(f"Trained models: {trained_models_set}")

        if extra_models:
            raise ValueError(
                f"ERROR: Property '{property}' contains extra models not present during training: {list(extra_models)}. "
                "You must retrain if using a larger model space."
            )

        if missing_models:
            print(
                f"WARNING: Predicting on property '{property}' with missing models: {list(missing_models)}"
            )
            print(
                " The trained model weights include these models — prediction will proceed, but results may not be statistically accurate."
            )

        if not available_models:
            raise ValueError(
                "No available trained models are present in prediction DataFrame."
            )

        model_preds = df[available_models].values
        domain_df = df[domain_keys].reset_index(drop=True)

        # Select the Vt_hat columns of the available models, in the same order as model_preds.
        model_indices = [self.models.index(m) for m in available_models]
        Vt_hat_reduced = self.Vt_hat[:, model_indices]

        rndm_m, (lower, median, upper) = rndm_m_random_calculator(
            model_preds, self.samples, Vt_hat_reduced
        )

        lower_df, median_df, upper_df = self._build_interval_frames(
            domain_df, lower, median, upper
        )
        return rndm_m, lower_df, median_df, upper_df

    def evaluate(self, domain_filter=None):
        """
        Evaluates model performance using coverage calculation.

        Uses the full table ``data_dict[current_property]`` (not only the training split),
        optionally filtered row-wise by `domain_filter`.

        Args:
            domain_filter (dict, optional): Filtering rules applied in order, per item:
                - key "multi" with a callable: keep rows where ``cond(row)`` is truthy
                  (``DataFrame.apply(..., axis=1)``);
                - callable: keep rows where ``cond(df[col])`` is True;
                - 2-tuple ``(lo, hi)``: keep rows with ``lo <= df[col] <= hi``;
                - list: keep rows whose ``df[col]`` is in the list;
                - anything else: keep rows with ``df[col] == cond``.
                Example: {"Z": (20, 30), "N": (20, 40)}.

        Returns:
            list[float]: Coverage percentages, as returned by `coverage`, for each percentile
            in [0, 5, 10, ..., 100].

        Raises:
            ValueError: If `orthogonalize()` and `train()` haven't been called.
        """
        self._check_fitted()
        df = self.data_dict[self.current_property]

        if domain_filter:
            for col, cond in domain_filter.items():
                if col == "multi" and callable(cond):
                    df = df[df.apply(cond, axis=1)]
                elif callable(cond):
                    df = df[cond(df[col])]
                elif isinstance(cond, tuple) and len(cond) == 2:
                    df = df[df[col].between(*cond)]
                elif isinstance(cond, list):
                    df = df[df[col].isin(cond)]
                else:
                    df = df[df[col] == cond]

        preds = df[self.models].to_numpy()
        rndm_m, _ = rndm_m_random_calculator(preds, self.samples, self.Vt_hat)

        return coverage(
            np.arange(0, 101, 5),
            rndm_m,
            df,
            truth_column=self.truth_column_name,
        )