Coverage for pybmc/bmc.py: 88%
106 statements
« prev ^ index » next — coverage.py v7.10.0, created at 2026-02-19 17:54 +0000
1import numpy as np
2import pandas as pd
3import matplotlib.pyplot as plt
4from sklearn.model_selection import train_test_split
5import os
6from .inference_utils import gibbs_sampler, gibbs_sampler_simplex, USVt_hat_extraction
7from .sampling_utils import coverage, rndm_m_random_calculator
class BayesianModelCombination:
    """
    Bayesian model combination (BMC) over a set of pre-computed model predictions.

    Typical workflow:
      1. ``orthogonalize()`` — center the model predictions on training data and
         reduce dimensionality via an SVD.
      2. ``train()`` — draw posterior samples of the combination coefficients
         with a Gibbs sampler (unconstrained or simplex-constrained).
      3. ``predict()`` / ``evaluate()`` / ``get_weights()`` — posterior
         predictions, coverage diagnostics, and model-weight summaries.
    """

    # Supported weight-constraint modes for the sampler.
    VALID_CONSTRAINTS = ("unconstrained", "simplex")

    def __init__(self, models_list, data_dict, truth_column_name, weights=None, constraint="unconstrained"):
        """
        Initialize the BayesianModelCombination class.

        :param models_list: List of model names (strings).
        :param data_dict: Dictionary from `load_data()` where each key is a
            property name and each value is a DataFrame of model predictions.
        :param truth_column_name: Name of the column containing the truth values.
        :param weights: Optional initial weights for the models (default None).
        :param constraint: Weight constraint mode. Options:
            - ``"unconstrained"`` (default): No constraints on model weights.
            - ``"simplex"``: Forces weights to lie on the probability simplex
              (each weight between 0 and 1, weights sum to 1). Uses a
              Metropolis-within-Gibbs sampler to enforce the constraint.
        :raises ValueError: If ``models_list`` is not a list of strings, if
            ``data_dict`` is not a dict of DataFrames, or if ``constraint``
            is not one of ``VALID_CONSTRAINTS``.
        """
        if not isinstance(models_list, list) or not all(isinstance(model, str) for model in models_list):
            raise ValueError("The 'models' should be a list of model names (strings) for Bayesian Combination.")
        if not isinstance(data_dict, dict) or not all(isinstance(df, pd.DataFrame) for df in data_dict.values()):
            raise ValueError("The 'data_dict' should be a dictionary of pandas DataFrames, one per property.")
        if constraint not in self.VALID_CONSTRAINTS:
            raise ValueError(
                f"Invalid constraint '{constraint}'. "
                f"Must be one of {self.VALID_CONSTRAINTS}."
            )

        self.data_dict = data_dict
        self.models_list = models_list
        # A column literally named 'truth' is never treated as a model.
        self.models = [m for m in models_list if m != 'truth']
        # Fix: the original `weights if weights is not None else None` was a
        # redundant tautology — it is exactly `weights`.
        self.weights = weights
        self.truth_column_name = truth_column_name
        self.constraint = constraint
        # Populated by orthogonalize() / train(); explicit None sentinels let
        # later methods raise clear errors instead of AttributeError.
        self.samples = None
        self.Vt_hat = None
        self.U_hat = None
        self.S_hat = None
        self.current_property = None

    def orthogonalize(self, property, train_df, components_kept):
        """
        Perform orthogonalization for the specified property using training data.

        Centers the model predictions (per nucleus) around their mean, applies
        an SVD, and keeps the leading ``components_kept`` components for the
        reduced-dimension sampler.

        :param property: The nuclear property to orthogonalize on (e.g., 'BE').
        :param train_df: Training-split DataFrame (e.g., from split_data) with
            one column per model plus the truth column.
        :param components_kept: Number of SVD components to retain.
        """
        # Remember which property we trained against (used by evaluate()).
        self.current_property = property

        # Keep the full property DataFrame around for train()/predict().
        df = self.data_dict[property].copy()
        self.selected_models_dataset = df

        # Model outputs restricted to the model columns.
        models_output_train = train_df[self.models]
        model_predictions_train = models_output_train.values

        # Mean prediction across models (per nucleus).
        predictions_mean_train = np.mean(model_predictions_train, axis=1)

        # Experimental truth values, centered by the model mean.
        centered_experiment_train = train_df[self.truth_column_name].values - predictions_mean_train

        # Center each model's predictions the same way.
        model_predictions_train_centered = model_predictions_train - predictions_mean_train[:, None]

        # SVD of the centered prediction matrix.
        U, S, Vt = np.linalg.svd(model_predictions_train_centered)

        # Dimensionality reduction to the leading components.
        U_hat, S_hat, Vt_hat, Vt_hat_normalized = USVt_hat_extraction(U, S, Vt, components_kept)  # type: ignore

        # Save everything train() needs.
        self.centered_experiment_train = centered_experiment_train
        self.U_hat = U_hat
        self.Vt_hat = Vt_hat
        self.S_hat = S_hat
        self.Vt_hat_normalized = Vt_hat_normalized
        self._predictions_mean_train = predictions_mean_train

    def train(self, training_options=None):
        """
        Train the model combination using training data and optional training parameters.

        :param training_options: Dictionary of training options. Keys:
            - 'iterations': (int) Number of Gibbs iterations (default 50000)
            - 'sampler': (str) Override the constraint mode for this training run.
              ``"unconstrained"`` or ``"simplex"``. If not provided, uses the
              instance-level ``self.constraint`` set at initialization.
            - 'b_mean_prior': (np.ndarray) Prior mean vector (default zeros)
              *(unconstrained sampler only)*
            - 'b_mean_cov': (np.ndarray) Prior covariance matrix (default diag(S_hat²))
              *(unconstrained sampler only)*
            - 'nu0_chosen': (float) Degrees of freedom for variance prior (default 1.0)
            - 'sigma20_chosen': (float) Prior variance (default 0.02)
            - 'burn': (int) Burn-in iterations (default 10000)
              *(simplex sampler only)*
            - 'stepsize': (float) Proposal step size (default 0.001)
              *(simplex sampler only)*
        :raises ValueError: If ``orthogonalize()`` has not been called, or the
            requested sampler is not one of ``VALID_CONSTRAINTS``.
        """
        # Fix: guard against calling train() before orthogonalize(), which
        # previously surfaced as a bare AttributeError on self.U_hat.
        if self.U_hat is None:
            raise ValueError("Must call `orthogonalize()` before `train()`.")

        if training_options is None:
            training_options = {}

        # training_options overrides the instance-level constraint.
        sampler_mode = training_options.get('sampler', self.constraint)
        if sampler_mode not in self.VALID_CONSTRAINTS:
            raise ValueError(
                f"Invalid sampler '{sampler_mode}'. "
                f"Must be one of {self.VALID_CONSTRAINTS}."
            )

        iterations = training_options.get('iterations', 50000)
        num_components = self.U_hat.shape[1]
        S_hat = self.S_hat
        nu0_chosen = training_options.get('nu0_chosen', 1.0)
        sigma20_chosen = training_options.get('sigma20_chosen', 0.02)

        if sampler_mode == "simplex":
            burn = training_options.get('burn', 10000)
            stepsize = training_options.get('stepsize', 0.001)
            self.samples = gibbs_sampler_simplex(
                self.centered_experiment_train,
                self.U_hat,
                self.Vt_hat,
                self.S_hat,
                iterations,
                [nu0_chosen, sigma20_chosen],
                burn=burn,
                stepsize=stepsize,
            )
        else:
            b_mean_prior = training_options.get('b_mean_prior', np.zeros(num_components))
            b_mean_cov = training_options.get('b_mean_cov', np.diag(S_hat**2))
            self.samples = gibbs_sampler(
                self.centered_experiment_train,
                self.U_hat,
                iterations,
                [b_mean_prior, b_mean_cov, nu0_chosen, sigma20_chosen],
            )

    def predict(self, property):
        """
        Predict a specified property using the posterior samples from training.

        :param property: The property name to predict (e.g., 'ChRad').
        :return:
            - rndm_m: array of shape (n_samples, n_points), full posterior draws
            - lower_df: DataFrame with columns domain_keys + ['Predicted_Lower']
            - median_df: DataFrame with columns domain_keys + ['Predicted_Median']
            - upper_df: DataFrame with columns domain_keys + ['Predicted_Upper']
        :raises ValueError: If training has not been run, or no trained model
            columns are present in the property's DataFrame.
        :raises KeyError: If ``property`` is not in ``data_dict``.
        """
        if self.samples is None or self.Vt_hat is None:
            raise ValueError("Must call `orthogonalize()` and `train()` before predicting.")

        if property not in self.data_dict:
            raise KeyError(f"Property '{property}' not found in data_dict.")

        df = self.data_dict[property].copy()

        # Domain columns are everything that is not a model output or the truth.
        full_model_cols = self.models
        domain_keys = [col for col in df.columns if col not in full_model_cols and col != self.truth_column_name]

        # Only use model columns actually present in this property's DataFrame.
        available_models = [m for m in full_model_cols if m in df.columns]
        if len(available_models) == 0:
            raise ValueError("No available trained models are present in prediction DataFrame.")

        model_preds = df[available_models].values
        domain_df = df[domain_keys].reset_index(drop=True)

        rndm_m, (lower, median, upper) = rndm_m_random_calculator(model_preds, self.samples, self.Vt_hat)

        # Assemble one DataFrame per credible-interval bound.
        lower_df = domain_df.copy()
        lower_df["Predicted_Lower"] = lower

        median_df = domain_df.copy()
        median_df["Predicted_Median"] = median

        upper_df = domain_df.copy()
        upper_df["Predicted_Upper"] = upper

        return rndm_m, lower_df, median_df, upper_df

    def evaluate(self, domain_filter=None):
        """
        Evaluate the model combination using coverage calculation.

        :param domain_filter: dict with optional domain key conditions, e.g.,
            ``{"Z": (20, 30), "N": [20, 28]}``. Each value may be a callable
            predicate, a 2-tuple (inclusive range), a list (membership), or a
            scalar (equality). The special key ``'multi'`` with a callable
            applies a row-wise predicate.
        :return: coverage list for each percentile (0 to 100 in steps of 5).
        :raises ValueError: If ``orthogonalize()``/``train()`` have not been run.
        """
        # Fix: guard against evaluation before training, which previously
        # failed with an opaque KeyError/AttributeError.
        if self.samples is None or self.Vt_hat is None or self.current_property is None:
            raise ValueError("Must call `orthogonalize()` and `train()` before evaluating.")

        df = self.data_dict[self.current_property]

        if domain_filter:
            # Apply each condition in turn, narrowing the DataFrame.
            for col, cond in domain_filter.items():
                if col == 'multi' and callable(cond):
                    df = df[df.apply(cond, axis=1)]
                elif callable(cond):
                    df = df[cond(df[col])]
                elif isinstance(cond, tuple) and len(cond) == 2:
                    df = df[df[col].between(*cond)]
                elif isinstance(cond, list):
                    df = df[df[col].isin(cond)]
                else:
                    df = df[df[col] == cond]

        preds = df[self.models].to_numpy()
        rndm_m, (lower, median, upper) = rndm_m_random_calculator(preds, self.samples, self.Vt_hat)

        return coverage(np.arange(0, 101, 5), rndm_m, df, truth_column=self.truth_column_name)

    def get_weights(self, summary=True):
        """
        Compute model weights from posterior samples.

        Converts the sampled coefficient vectors (beta) into model weights
        using the transformation ``omega = beta @ Vt_hat + 1/M``, where M is
        the number of models. In simplex-constrained mode, all weights are
        guaranteed to be non-negative and sum to 1.

        :param summary: If True (default), return a dictionary with
            ``'mean'``, ``'std'``, ``'median'`` arrays keyed by statistic,
            plus ``'models'`` listing the model names in column order.
            If False, return the full ``(n_samples, n_models)`` weight matrix.
        :return: Weight summary dict or full weight matrix.
        :raises ValueError: If ``train()`` has not been called.
        """
        if self.samples is None or self.Vt_hat is None:
            raise ValueError("Must call `orthogonalize()` and `train()` before getting weights.")

        # Last sample column is the variance draw; the rest are betas.
        betas = self.samples[:, :-1]
        n_models = self.Vt_hat.shape[1]
        default_weights = np.full(n_models, 1.0 / n_models)
        weight_matrix = betas @ self.Vt_hat + default_weights

        if summary:
            return {
                "mean": np.mean(weight_matrix, axis=0),
                "std": np.std(weight_matrix, axis=0),
                "median": np.median(weight_matrix, axis=0),
                "models": self.models,
            }
        return weight_matrix