Coverage for pybmc/bmc.py: 88%

106 statements  

« prev     ^ index     » next       coverage.py v7.10.0, created at 2026-02-19 17:54 +0000

1import numpy as np 

2import pandas as pd 

3import matplotlib.pyplot as plt 

4from sklearn.model_selection import train_test_split 

5import os 

6from .inference_utils import gibbs_sampler, gibbs_sampler_simplex, USVt_hat_extraction 

7from .sampling_utils import coverage, rndm_m_random_calculator 

8 

9 

class BayesianModelCombination:
    """
    Perform Bayesian model combination (BMM) on a set of models chosen
    from the dataset class.

    Responsibilities:
      + Orthogonalization step (SVD of centered model predictions).
      + Bayesian inference on the training data extracted from the Dataset class.
      + Predictions for certain isotopes.
    """

    # Supported weight-constraint modes; validated in __init__ and train().
    VALID_CONSTRAINTS = ("unconstrained", "simplex")

20 

21 def __init__(self, models_list, data_dict, truth_column_name, weights=None, constraint="unconstrained"): 

22 """  

23 Initialize the BayesianModelCombination class. 

24 

25 :param models_list: List of model names 

26 :param data_dict: Dictionary from `load_data()` where each key is a model name and each value is a DataFrame of properties 

27 :param truth_column_name: Name of the column containing the truth values. 

28 :param weights: Optional initial weights for the models. 

29 :param constraint: Weight constraint mode. Options: 

30 - ``"unconstrained"`` (default): No constraints on model weights. 

31 - ``"simplex"``: Forces weights to lie on the probability simplex 

32 (each weight between 0 and 1, weights sum to 1). Uses a 

33 Metropolis-within-Gibbs sampler to enforce the constraint. 

34 """ 

35 

36 if not isinstance(models_list, list) or not all(isinstance(model, str) for model in models_list): 

37 raise ValueError("The 'models' should be a list of model names (strings) for Bayesian Combination.") 

38 if not isinstance(data_dict, dict) or not all(isinstance(df, pd.DataFrame) for df in data_dict.values()): 

39 raise ValueError("The 'data_dict' should be a dictionary of pandas DataFrames, one per property.") 

40 if constraint not in self.VALID_CONSTRAINTS: 

41 raise ValueError( 

42 f"Invalid constraint '{constraint}'. " 

43 f"Must be one of {self.VALID_CONSTRAINTS}." 

44 ) 

45 

46 self.data_dict = data_dict 

47 self.models_list = models_list 

48 self.models = [m for m in models_list if m != 'truth'] 

49 self.weights = weights if weights is not None else None 

50 self.truth_column_name = truth_column_name 

51 self.constraint = constraint 

52 self.samples = None 

53 self.Vt_hat = None 

54 

55 

56 def orthogonalize(self, property, train_df, components_kept): 

57 """ 

58 Perform orthogonalization for the specified property using training data. 

59 

60 :param property: The nuclear property to orthogonalize on (e.g., 'BE'). 

61 :param train_index: Training data from split_data 

62 :param components_kept: Number of SVD components to retain. 

63 """ 

64 # Store selected property 

65 self.current_property = property 

66 

67 # Extract the relevant DataFrame for that property 

68 df = self.data_dict[property].copy() 

69 self.selected_models_dataset = df # Store for train() and predict() 

70 

71 # Extract model outputs (only the model columns) 

72 models_output_train = train_df[self.models] 

73 model_predictions_train = models_output_train.values 

74 

75 # Mean prediction across models (per nucleus) 

76 predictions_mean_train = np.mean(model_predictions_train, axis=1) 

77 

78 # Experimental truth values for the property 

79 centered_experiment_train = train_df[self.truth_column_name].values - predictions_mean_train 

80 

81 # Center model predictions 

82 model_predictions_train_centered = model_predictions_train - predictions_mean_train[:, None] 

83 

84 # Perform SVD 

85 U, S, Vt = np.linalg.svd(model_predictions_train_centered) 

86 

87 # Dimensionality reduction 

88 U_hat, S_hat, Vt_hat, Vt_hat_normalized = USVt_hat_extraction(U, S, Vt, components_kept) #type: ignore 

89 

90 # Save for training 

91 self.centered_experiment_train = centered_experiment_train 

92 self.U_hat = U_hat 

93 self.Vt_hat = Vt_hat 

94 self.S_hat = S_hat 

95 self.Vt_hat_normalized = Vt_hat_normalized 

96 self._predictions_mean_train = predictions_mean_train 

97 

98 

99 def train(self, training_options=None): 

100 """ 

101 Train the model combination using training data and optional training parameters. 

102 

103 :param training_options: Dictionary of training options. Keys: 

104 - 'iterations': (int) Number of Gibbs iterations (default 50000) 

105 - 'sampler': (str) Override the constraint mode for this training run. 

106 ``"unconstrained"`` or ``"simplex"``. If not provided, uses the 

107 instance-level ``self.constraint`` set at initialization. 

108 - 'b_mean_prior': (np.ndarray) Prior mean vector (default zeros) 

109 *(unconstrained sampler only)* 

110 - 'b_mean_cov': (np.ndarray) Prior covariance matrix (default diag(S_hat²)) 

111 *(unconstrained sampler only)* 

112 - 'nu0_chosen': (float) Degrees of freedom for variance prior (default 1.0) 

113 - 'sigma20_chosen': (float) Prior variance (default 0.02) 

114 - 'burn': (int) Burn-in iterations (default 10000) 

115 *(simplex sampler only)* 

116 - 'stepsize': (float) Proposal step size (default 0.001) 

117 *(simplex sampler only)* 

118 """ 

119 if training_options is None: 

120 training_options = {} 

121 

122 # Determine which sampler to use: training_options overrides instance default 

123 sampler_mode = training_options.get('sampler', self.constraint) 

124 if sampler_mode not in self.VALID_CONSTRAINTS: 

125 raise ValueError( 

126 f"Invalid sampler '{sampler_mode}'. " 

127 f"Must be one of {self.VALID_CONSTRAINTS}." 

128 ) 

129 

130 iterations = training_options.get('iterations', 50000) 

131 num_components = self.U_hat.shape[1] 

132 S_hat = self.S_hat 

133 nu0_chosen = training_options.get('nu0_chosen', 1.0) 

134 sigma20_chosen = training_options.get('sigma20_chosen', 0.02) 

135 

136 if sampler_mode == "simplex": 

137 burn = training_options.get('burn', 10000) 

138 stepsize = training_options.get('stepsize', 0.001) 

139 self.samples = gibbs_sampler_simplex( 

140 self.centered_experiment_train, 

141 self.U_hat, 

142 self.Vt_hat, 

143 self.S_hat, 

144 iterations, 

145 [nu0_chosen, sigma20_chosen], 

146 burn=burn, 

147 stepsize=stepsize, 

148 ) 

149 else: 

150 b_mean_prior = training_options.get('b_mean_prior', np.zeros(num_components)) 

151 b_mean_cov = training_options.get('b_mean_cov', np.diag(S_hat**2)) 

152 self.samples = gibbs_sampler( 

153 self.centered_experiment_train, 

154 self.U_hat, 

155 iterations, 

156 [b_mean_prior, b_mean_cov, nu0_chosen, sigma20_chosen], 

157 ) 

158 

159 

160 

161 def predict(self, property): 

162 """ 

163 Predict a specified property using the model weights learned during training. 

164 

165 :param property: The property name to predict (e.g., 'ChRad'). 

166 :return: 

167 - rndm_m: array of shape (n_samples, n_points), full posterior draws 

168 - lower_df: DataFrame with columns domain_keys + ['Predicted_Lower'] 

169 - median_df: DataFrame with columns domain_keys + ['Predicted_Median'] 

170 - upper_df: DataFrame with columns domain_keys + ['Predicted_Upper'] 

171 """ 

172 if self.samples is None or self.Vt_hat is None: 

173 raise ValueError("Must call `orthogonalize()` and `train()` before predicting.") 

174 

175 if property not in self.data_dict: 

176 raise KeyError(f"Property '{property}' not found in data_dict.") 

177 

178 df = self.data_dict[property].copy() 

179 

180 # Infer domain and model columns 

181 full_model_cols = self.models 

182 domain_keys = [col for col in df.columns if col not in full_model_cols and col != self.truth_column_name] 

183 

184 # Determine which models are present 

185 available_models = [m for m in full_model_cols if m in df.columns] 

186 

187 if len(available_models) == 0: 

188 raise ValueError("No available trained models are present in prediction DataFrame.") 

189 

190 # Filter predictions and model weights 

191 model_preds = df[available_models].values 

192 domain_df = df[domain_keys].reset_index(drop=True) 

193 

194 rndm_m, (lower, median, upper) = rndm_m_random_calculator(model_preds, self.samples, self.Vt_hat) 

195 

196 # Build output DataFrames 

197 lower_df = domain_df.copy() 

198 

199 lower_df["Predicted_Lower"] = lower 

200 

201 median_df = domain_df.copy() 

202 median_df["Predicted_Median"] = median 

203 

204 upper_df = domain_df.copy() 

205 upper_df["Predicted_Upper"] = upper 

206 

207 return rndm_m, lower_df, median_df, upper_df 

208 

209 def evaluate(self, domain_filter=None): 

210 """ 

211 Evaluate the model combination using coverage calculation. 

212 

213 :param domain_filter: dict with optional domain key ranges, e.g., {"Z": (20, 30), "N": (20, 40)} 

214 :return: coverage list for each percentile 

215 """ 

216 df = self.data_dict[self.current_property] 

217 

218 if domain_filter: 

219 # Inline optimized filtering 

220 for col, cond in domain_filter.items(): 

221 if col == 'multi' and callable(cond): 

222 df = df[df.apply(cond, axis=1)] 

223 elif callable(cond): 

224 df = df[cond(df[col])] 

225 elif isinstance(cond, tuple) and len(cond) == 2: 

226 df = df[df[col].between(*cond)] 

227 elif isinstance(cond, list): 

228 df = df[df[col].isin(cond)] 

229 else: 

230 df = df[df[col] == cond] 

231 

232 preds = df[self.models].to_numpy() 

233 rndm_m, (lower, median, upper) = rndm_m_random_calculator(preds, self.samples, self.Vt_hat) 

234 

235 return coverage(np.arange(0, 101, 5), rndm_m, df, truth_column=self.truth_column_name) 

236 

237 def get_weights(self, summary=True): 

238 """ 

239 Compute model weights from posterior samples. 

240 

241 Converts the sampled coefficient vectors (beta) into model weights 

242 using the transformation ``omega = beta @ Vt_hat + 1/M``, where M is 

243 the number of models. In simplex-constrained mode, all weights are 

244 guaranteed to be non-negative and sum to 1. 

245 

246 :param summary: If True (default), return a dictionary with 

247 ``'mean'``, ``'std'``, ``'median'`` arrays keyed by statistic. 

248 If False, return the full ``(n_samples, n_models)`` weight matrix. 

249 :return: Weight summary dict or full weight matrix. 

250 :raises ValueError: If ``train()`` has not been called. 

251 """ 

252 if self.samples is None or self.Vt_hat is None: 

253 raise ValueError("Must call `orthogonalize()` and `train()` before getting weights.") 

254 

255 betas = self.samples[:, :-1] 

256 n_models = self.Vt_hat.shape[1] 

257 default_weights = np.full(n_models, 1.0 / n_models) 

258 weight_matrix = betas @ self.Vt_hat + default_weights 

259 

260 if summary: 

261 return { 

262 "mean": np.mean(weight_matrix, axis=0), 

263 "std": np.std(weight_matrix, axis=0), 

264 "median": np.median(weight_matrix, axis=0), 

265 "models": self.models, 

266 } 

267 return weight_matrix 

268 

269 

270 

271