Coverage for pybmc/data.py: 81%

178 statements  

coverage.py v7.10.0, created at 2025-10-14 21:12 +0000

import logging
import os

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


class Dataset:
    """
    A general-purpose dataset class for loading and managing model data
    for Bayesian model combination workflows.

    Supports .h5 and .csv files, and provides data splitting functionality.
    """

    def __init__(self, data_source=None, verbose=True):
        """
        Initialize the Dataset object.

        :param data_source: Path to the data file (.h5 or .csv).
        :param verbose: If True, show informational messages (warnings are always shown). Default is True.
        """
        self.data_source = data_source
        self.data = {}  # Maps property name -> synchronized DataFrame (filled by load_data)
        self.domain_keys = None  # Set by load_data
        self.verbose = verbose
        self.logger = logging.getLogger(__name__)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter('%(message)s'))
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO if verbose else logging.WARNING)
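
    # A minimal constructor sketch ("masses.h5" is a hypothetical file name;
    # nothing is read until load_data is called):
    #
    #   >>> ds = Dataset("masses.h5", verbose=False)  # suppress INFO messages
    #   >>> ds.data
    #   {}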

    def load_data(self, models, keys=None, domain_keys=None, model_column='model', truth_column_name=None):
        """
        Load data for each property and return a dictionary of synchronized DataFrames.
        Each DataFrame has columns: domain_keys + one column per model for that property.

        Parameters:
            models (list): List of model names (HDF5 keys, or values used to filter the CSV).
            keys (list): List of property names to extract (each becomes a key in the output dict).
            domain_keys (list, optional): Columns that define the common domain (default ['N', 'Z']).
            model_column (str, optional): Name of the CSV column identifying which model each row
                belongs to. Only used for CSV files; ignored for HDF5 files.
            truth_column_name (str, optional): Name of the truth model. If provided, the truth data
                is left-joined to the common domain of the other models, so the truth data may
                cover a smaller domain than the models.

        Returns:
            dict: Maps each property name to a DataFrame with columns
                domain_keys + one column per model for that property.
                The DataFrames are synchronized to the intersection of the domains of all models.
                If truth_column_name is provided, the truth column is left-joined (and may contain NaN).

        Supports both .h5 and .csv files.
        """
        if self.data_source is None:
            raise ValueError("Data source must be specified.")
        if not os.path.exists(self.data_source):
            raise FileNotFoundError(f"Data source '{self.data_source}' not found.")
        if keys is None:
            raise ValueError("You must specify which properties to extract via 'keys'.")
        if domain_keys is None:
            domain_keys = ['N', 'Z']  # Documented default; previously left as None, which raised TypeError below.
        self.domain_keys = domain_keys

        result = {}

        for prop in keys:
            dfs = []
            truth_df = None
            skipped_models = []

            if self.data_source.endswith('.h5'):
                for model in models:
                    df = pd.read_hdf(self.data_source, key=model)
                    # Skip models that lack the domain columns or the requested property.
                    missing_cols = [col for col in domain_keys + [prop] if col not in df.columns]
                    if missing_cols:
                        self.logger.info(f"[Skipped] Model '{model}' missing columns {missing_cols} for property '{prop}'.")
                        skipped_models.append(model)
                        continue
                    temp = df[domain_keys + [prop]].copy()
                    temp.rename(columns={prop: model}, inplace=True)  # type: ignore

                    # Hold truth data aside so it can be left-joined after the
                    # regular models have been intersected.
                    if truth_column_name and model == truth_column_name:
                        truth_df = temp
                    else:
                        dfs.append(temp)

            elif self.data_source.endswith('.csv'):
                df = pd.read_csv(self.data_source)
                if model_column not in df.columns:
                    raise ValueError(f"Expected column '{model_column}' not found in CSV.")
                for model in models:
                    model_df = df[df[model_column] == model]
                    missing_cols = [col for col in domain_keys + [prop] if col not in model_df.columns]
                    if missing_cols:
                        self.logger.info(f"[Skipped] Model '{model}' missing columns {missing_cols} for property '{prop}'.")
                        skipped_models.append(model)
                        continue
                    temp = model_df[domain_keys + [prop]].copy()
                    temp.rename(columns={prop: model}, inplace=True)

                    if truth_column_name and model == truth_column_name:
                        truth_df = temp
                    else:
                        dfs.append(temp)
            else:
                raise ValueError("Unsupported file format. Only .h5 and .csv are supported.")

            if not dfs:
                self.logger.info(f"[Warning] No models with property '{prop}'. Resulting DataFrame will be empty.")
                result[prop] = pd.DataFrame(columns=domain_keys + [m for m in models if m not in skipped_models])
                continue

            # Intersect the domains of the regular (non-truth) models.
            common_df = dfs[0]
            for other_df in dfs[1:]:
                common_df = pd.merge(common_df, other_df, on=domain_keys, how="inner")
            # Drop rows with NaN in any required column (for regular models).
            common_df = common_df.dropna()

            # Left-join the truth data, if any, onto the common domain.
            if truth_df is not None:
                common_df = pd.merge(common_df, truth_df, on=domain_keys, how="left")

            result[prop] = common_df

        self.data = result
        return result
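
    # Example (a sketch; the file name, model keys, and "exp" truth column are
    # hypothetical):
    #
    #   >>> ds = Dataset("masses.h5")
    #   >>> tables = ds.load_data(models=["skm", "unedf1", "exp"],
    #   ...                       keys=["BE"], domain_keys=["N", "Z"],
    #   ...                       truth_column_name="exp")
    #   >>> list(tables["BE"].columns)
    #   ['N', 'Z', 'skm', 'unedf1', 'exp']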

    def view_data(self, property_name=None, model_name=None):
        """
        View loaded data flexibly based on the input parameters.

        - No arguments: returns the available property names and model names.
        - property_name only: returns the full DataFrame for that property.
        - model_name only: returns, per property, a DataFrame of the domain columns
          plus that model's values (or a message if the model is unavailable).
        - property_name + model_name: returns a Series of values for the model.

        :param property_name: Optional property name.
        :param model_name: Optional model name.
        :return: dict, DataFrame, or Series depending on input.
        """
        if not self.data:
            raise RuntimeError("No data loaded. Run `load_data(...)` first.")

        if property_name is None and model_name is None:
            props = list(self.data.keys())
            models = sorted(set(
                col
                for prop_df in self.data.values()
                for col in prop_df.columns
                if col not in self.domain_keys
            ))
            return {
                "available_properties": props,
                "available_models": models,
            }

        if model_name is not None and property_name is None:
            # Return {property: DataFrame of domain columns + model values}.
            result = {}
            for prop, df in self.data.items():
                if model_name in df.columns:
                    cols = self.domain_keys + [model_name]
                    result[prop] = df[cols]
                else:
                    result[prop] = f"[Model '{model_name}' not available]"
            return result

        if property_name not in self.data:
            raise KeyError(f"Property '{property_name}' not found.")

        df = self.data[property_name]

        if model_name is None:
            return df  # Full property DataFrame

        if model_name not in df.columns:
            raise KeyError(f"Model '{model_name}' not found in property '{property_name}'.")

        return df[model_name]
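
    # Example (continuing the hypothetical sketch above):
    #
    #   >>> ds.view_data()
    #   {'available_properties': ['BE'], 'available_models': ['exp', 'skm', 'unedf1']}
    #   >>> ds.view_data("BE", "skm")        # Series of skm values for BE
    #   >>> ds.view_data(model_name="skm")   # {'BE': DataFrame[N, Z, skm]}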

    def separate_points_distance_allSets(self, list1, list2, distance1, distance2):
        """
        Separate the points in list1 into three groups based on their proximity
        to any point in list2: within distance1 -> train; otherwise within
        distance2 -> validation; otherwise -> test.

        :param list1: List of (x, y) tuples to classify.
        :param list2: List of (x, y) reference tuples (e.g., stable points).
        :param distance1: Threshold distance for the training set.
        :param distance2: Threshold distance for the validation set.
        :return: Three lists of indices into list1 - train, validation, and test.
        """
        train_list_coordinates = []
        validation_list_coordinates = []
        test_list_coordinates = []

        # Convert the reference points once, instead of on every comparison.
        reference_points = [np.array(p) for p in list2]

        for i, point1 in enumerate(list1):
            point1 = np.array(point1)
            # Within distance1 of any reference point -> train.
            if any(np.linalg.norm(point1 - p) <= distance1 for p in reference_points):
                train_list_coordinates.append(i)
            # Otherwise, within distance2 of any reference point -> validation.
            elif any(np.linalg.norm(point1 - p) <= distance2 for p in reference_points):
                validation_list_coordinates.append(i)
            else:
                test_list_coordinates.append(i)

        return train_list_coordinates, validation_list_coordinates, test_list_coordinates
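
    # Example (toy numbers, for illustration only):
    #
    #   >>> ds = Dataset()
    #   >>> pts = [(0, 0), (3, 0), (10, 0)]
    #   >>> ds.separate_points_distance_allSets(pts, [(0, 0)], 1.0, 5.0)
    #   ([0], [1], [2])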

    def split_data(self, data_dict, property_name, splitting_algorithm="random", **kwargs):
        """
        Split data into training, validation, and testing sets using either a
        random split or inside-to-outside (distance-based) logic.

        :param data_dict: Dictionary output of `load_data`, mapping property names to DataFrames.
        :param property_name: The key in `data_dict` selecting which DataFrame to split.
        :param splitting_algorithm: 'random' (default) or 'inside_to_outside'.
        :param kwargs: Additional arguments depending on the chosen algorithm.
                       For 'random': train_size, val_size, test_size (fractions summing to 1.0).
                       For 'inside_to_outside': stable_points (list of (x, y)), distance1, distance2.
        :return: Tuple of train, validation, and test DataFrames.
        """
        if property_name not in data_dict:
            raise ValueError(f"Property '{property_name}' not found in the provided data dictionary.")

        data = data_dict[property_name]

        if not isinstance(data, pd.DataFrame):
            raise TypeError("Data for the specified property must be a pandas DataFrame.")
        indexable_data = data.reset_index(drop=True)

        if splitting_algorithm == "random":
            required = ['train_size', 'val_size', 'test_size']
            if not all(k in kwargs for k in required):
                raise ValueError(f"Missing required kwargs for 'random': {required}")

            train_size = kwargs['train_size']
            val_size = kwargs['val_size']
            test_size = kwargs['test_size']

            if not np.isclose(train_size + val_size + test_size, 1.0):
                raise ValueError("train_size + val_size + test_size must equal 1.0")

            # Random split on row positions.
            train_idx, temp_idx = train_test_split(indexable_data.index, train_size=train_size, random_state=1)
            val_rel = val_size / (val_size + test_size)
            val_idx, test_idx = train_test_split(temp_idx, test_size=1 - val_rel, random_state=1)

        elif splitting_algorithm == "inside_to_outside":
            required = ['stable_points', 'distance1', 'distance2']
            if not all(k in kwargs for k in required):
                raise ValueError(f"Missing required kwargs for 'inside_to_outside': {required}")

            stable_points = kwargs['stable_points']
            distance1 = kwargs['distance1']
            distance2 = kwargs['distance2']

            # Distances are measured in the domain plane (e.g., (N, Z)), so compare
            # only the domain columns against the stable points; previously the
            # tuples included the model columns, whose dimension does not match
            # the (x, y) stable points.
            if getattr(self, 'domain_keys', None):
                point_columns = [c for c in self.domain_keys if c in indexable_data.columns]
            else:
                point_columns = list(indexable_data.columns)
            point_list = list(indexable_data[point_columns].itertuples(index=False, name=None))

            train_idx, val_idx, test_idx = self.separate_points_distance_allSets(
                point_list, stable_points, distance1, distance2
            )
        else:
            raise ValueError("splitting_algorithm must be either 'random' or 'inside_to_outside'")

        train_data = indexable_data.iloc[train_idx]
        val_data = indexable_data.iloc[val_idx]
        test_data = indexable_data.iloc[test_idx]

        return train_data, val_data, test_data
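
    # Example (random split of the hypothetical "BE" table from load_data):
    #
    #   >>> train, val, test = ds.split_data(tables, "BE",
    #   ...                                  splitting_algorithm="random",
    #   ...                                  train_size=0.6, val_size=0.2,
    #   ...                                  test_size=0.2)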

    def get_subset(self, property_name, filters=None, models_to_include=None):
        """
        Return a filtered, wide-format DataFrame for a given property.

        :param property_name: Name of the property (e.g., "BE", "ChRad").
        :param filters: Dictionary of filtering rules applied to the domain columns
                        (e.g., {"Z": (26, 28)}). A value may be a (min, max) tuple, a list
                        of allowed values, a scalar, or a callable; the special key 'multi'
                        takes a callable applied row-wise.
        :param models_to_include: Optional list of model names to retain in the output.
                                  If None, all model columns are retained.
        :return: Filtered wide-format DataFrame with columns: domain keys + model columns.
        """
        if property_name not in self.data:
            raise ValueError(f"Property '{property_name}' not found in dataset.")

        df = self.data[property_name].copy()

        # Apply row-level filters (domain-based).
        if filters:
            for column, condition in filters.items():
                if column == 'multi' and callable(condition):
                    df = df[df.apply(condition, axis=1)]
                elif callable(condition):
                    df = df[condition(df[column])]
                elif isinstance(condition, tuple) and len(condition) == 2:
                    df = df[(df[column] >= condition[0]) & (df[column] <= condition[1])]
                elif isinstance(condition, list):
                    df = df[df[column].isin(condition)]
                else:
                    df = df[df[column] == condition]

        # Optionally restrict to a subset of models. Use the domain keys from
        # load_data when available, rather than the hardcoded ['N', 'Z'].
        if models_to_include is not None:
            domain_keys = getattr(self, 'domain_keys', None) or ['N', 'Z']
            domain_cols = [col for col in domain_keys if col in df.columns]
            allowed_cols = domain_cols + [m for m in models_to_include if m in df.columns]
            df = df[allowed_cols]

        return df
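

# ---------------------------------------------------------------------------
# End-to-end smoke-test sketch (illustrative only; the model names, values,
# and temporary CSV below are invented, not part of the package or its data).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    # Two toy "models" plus an "exp" truth column over a small (N, Z) domain.
    rows = [
        {"model": m, "N": n, "Z": z, "BE": n + z + off}
        for m, off in [("modelA", 0.0), ("modelB", 0.5), ("exp", 0.2)]
        for n, z in [(20, 20), (28, 28), (50, 40), (82, 50)]
    ]
    with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as f:
        csv_path = f.name
    pd.DataFrame(rows).to_csv(csv_path, index=False)

    ds = Dataset(csv_path)
    tables = ds.load_data(
        models=["modelA", "modelB", "exp"],
        keys=["BE"],
        domain_keys=["N", "Z"],
        model_column="model",
        truth_column_name="exp",
    )
    print(ds.view_data())
    print(ds.get_subset("BE", filters={"Z": (20, 40)}, models_to_include=["modelA"]))
    train, val, test = ds.split_data(
        tables, "BE", splitting_algorithm="random",
        train_size=0.5, val_size=0.25, test_size=0.25,
    )
    print(len(train), len(val), len(test))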