Coverage for pybmc/data.py: 76%

162 statements  

coverage.py v7.10.0, created at 2025-07-27 15:48 +0000

import os

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split  # type: ignore

class Dataset:
    """
    Manages datasets for Bayesian model combination workflows.

    Supports loading data from HDF5 and CSV files, splitting data, and filtering.

    Attributes:
        data_source (str): Path to the data file.
        data (dict[str, pandas.DataFrame]): Loaded data keyed by property name.
        domain_keys (list[str]): Domain columns used for data alignment.
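
    Example:
        A minimal illustrative workflow; the file name, model names, and
        property name below are assumptions, not fixed by the library:

        >>> ds = Dataset('models.h5')
        >>> data = ds.load_data(models=['m1', 'm2'], keys=['BE'],
        ...                     domain_keys=['Z', 'N'])
        >>> train, val, test = ds.split_data(data, 'BE',
        ...     train_size=0.6, val_size=0.2, test_size=0.2)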

17 """ 

18 

    def __init__(self, data_source=None):
        """
        Initializes the Dataset instance.

        Args:
            data_source (str, optional): Path to the data file (.h5 or .csv).
        """
        self.data_source = data_source
        self.data = {}  # Property name -> DataFrame of aligned model values
        self.domain_keys = ["X1", "X2"]  # Default domain keys

    def load_data(self, models, keys=None, domain_keys=None, model_column="model"):
        """
        Loads data for multiple properties and models.

        Args:
            models (list[str]): Model names to load.
            keys (list[str]): Property names to extract.
            domain_keys (list[str], optional): Domain columns. Defaults to the
                instance's current `domain_keys`.
            model_column (str, optional): CSV column identifying models (default: 'model').

        Returns:
            dict[str, pandas.DataFrame]: Dictionary of DataFrames keyed by property name.

        Raises:
            ValueError: If `data_source` is not specified or `keys` is missing.
            FileNotFoundError: If `data_source` doesn't exist.

        Example:
            >>> dataset = Dataset('data.h5')
            >>> data = dataset.load_data(
            ...     models=['model1', 'model2'],
            ...     keys=['BE', 'Rad'],
            ...     domain_keys=['Z', 'N']
            ... )
        """

        # Fall back to the instance's domain keys when none are given;
        # the previous unconditional assignment clobbered the default with
        # None and crashed below on `domain_keys + [prop]`.
        if domain_keys is not None:
            self.domain_keys = domain_keys
        domain_keys = self.domain_keys

        if self.data_source is None:
            raise ValueError("Data source must be specified.")
        if not os.path.exists(self.data_source):
            raise FileNotFoundError(f"Data source '{self.data_source}' not found.")
        if keys is None:
            raise ValueError("You must specify which properties to extract via 'keys'.")

        result = {}

        for prop in keys:
            dfs = []
            skipped_models = []

            if self.data_source.endswith(".h5"):
                for model in models:
                    df = pd.read_hdf(self.data_source, key=model)
                    # Check that the domain columns and the property exist.
                    missing_cols = [
                        col for col in domain_keys + [prop] if col not in df.columns
                    ]
                    if missing_cols:
                        print(
                            f"[Skipped] Model '{model}' missing columns {missing_cols} for property '{prop}'."
                        )
                        skipped_models.append(model)
                        continue
                    temp = df[domain_keys + [prop]].copy()
                    temp.rename(columns={prop: model}, inplace=True)  # type: ignore
                    dfs.append(temp)
            elif self.data_source.endswith(".csv"):
                df = pd.read_csv(self.data_source)
                for model in models:
                    if model_column not in df.columns:
                        raise ValueError(
                            f"Expected column '{model_column}' not found in CSV."
                        )
                    model_df = df[df[model_column] == model]
                    missing_cols = [
                        col
                        for col in domain_keys + [prop]
                        if col not in model_df.columns
                    ]
                    if missing_cols:
                        print(
                            f"[Skipped] Model '{model}' missing columns {missing_cols} for property '{prop}'."
                        )
                        skipped_models.append(model)
                        continue
                    temp = model_df[domain_keys + [prop]].copy()
                    temp.rename(columns={prop: model}, inplace=True)
                    dfs.append(temp)
            else:
                raise ValueError(
                    "Unsupported file format. Only .h5 and .csv are supported."
                )

            if not dfs:
                print(
                    f"[Warning] No models with property '{prop}'. Resulting DataFrame will be empty."
                )
                result[prop] = pd.DataFrame(
                    columns=domain_keys + [m for m in models if m not in skipped_models]
                )
                continue

            # Keep only domain points shared by every model for this property.
            common_df = dfs[0]
            for other_df in dfs[1:]:
                common_df = pd.merge(common_df, other_df, on=domain_keys, how="inner")

            result[prop] = common_df

        self.data = result
        return result

    def view_data(self, property_name=None, model_name=None):
        """
        Provides flexible data viewing options.

        Args:
            property_name (str, optional): Specific property to view.
            model_name (str, optional): Specific model to view.

        Returns:
            Union[dict, pandas.DataFrame, pandas.Series]:
                - If no args: dict of available properties and models.
                - If only `model_name`: dict mapping each property to a DataFrame
                  of domain columns plus that model's values (or a message string
                  if the model is absent).
                - If only `property_name`: DataFrame for that property.
                - If both: Series of the model's values for that property.

        Raises:
            RuntimeError: If no data loaded.
            KeyError: If property or model not found.
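
        Example:
            Illustrative only; assumes data was loaded with property 'BE'
            and models 'model1' and 'model2':

            >>> dataset.view_data()
            {'available_properties': ['BE'], 'available_models': ['model1', 'model2']}
            >>> series = dataset.view_data(property_name='BE', model_name='model1')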

149 """ 

        if not self.data:
            raise RuntimeError("No data loaded. Run `load_data(...)` first.")

        if property_name is None and model_name is None:
            props = list(self.data.keys())
            models = sorted(
                set(
                    col
                    for prop_df in self.data.values()
                    for col in prop_df.columns
                    if col not in self.domain_keys
                )
            )

            return {"available_properties": props, "available_models": models}

        if model_name is not None and property_name is None:
            # Return a dictionary: {property: DataFrame of domain columns + model values}
            result = {}
            for prop, df in self.data.items():
                if model_name in df.columns:
                    cols = self.domain_keys + [model_name]
                    result[prop] = df[cols]
                else:
                    result[prop] = f"[Model '{model_name}' not available]"
            return result

        if property_name is not None:
            if property_name not in self.data:
                raise KeyError(f"Property '{property_name}' not found.")

            df = self.data[property_name]

            if model_name is None:
                return df  # Full property DataFrame

            if model_name not in df.columns:
                raise KeyError(
                    f"Model '{model_name}' not found in property '{property_name}'."
                )

            return df[model_name]

    def separate_points_distance_allSets(self, list1, list2, distance1, distance2):
        """
        Separates points into groups based on proximity thresholds.

        Args:
            list1 (list[tuple[float, float]]): Points to classify as (x, y) tuples.
            list2 (list[tuple[float, float]]): Reference points as (x, y) tuples.
            distance1 (float): First proximity threshold.
            distance2 (float): Second proximity threshold.

        Returns:
            tuple[list[int], list[int], list[int]]: Three lists of indices from `list1`:
                - Within `distance1` of any point in `list2`.
                - Within `distance2` (but not `distance1`) of any point in `list2`.
                - Beyond `distance2` of every point in `list2`.
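
        Example:
            A small illustrative call with made-up coordinates:

            >>> ds = Dataset()
            >>> ds.separate_points_distance_allSets(
            ...     [(0, 0), (3, 0), (10, 0)], [(0, 0)], 1.0, 5.0
            ... )
            ([0], [1], [2])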

        """

        train_idx = []
        validation_idx = []
        test_idx = []

        for i, point1 in enumerate(list1):
            p1 = np.array(point1)
            # Classify by the distance to the nearest reference point.
            if any(np.linalg.norm(p1 - np.array(p2)) <= distance1 for p2 in list2):
                train_idx.append(i)
            elif any(np.linalg.norm(p1 - np.array(p2)) <= distance2 for p2 in list2):
                validation_idx.append(i)
            else:
                test_idx.append(i)

        return train_idx, validation_idx, test_idx

    def split_data(
        self, data_dict, property_name, splitting_algorithm="random", **kwargs
    ):
        """
        Splits data into training, validation, and test sets.

        Args:
            data_dict (dict[str, pandas.DataFrame]): Output from `load_data()`.
            property_name (str): Property to use for splitting.
            splitting_algorithm (str): 'random' or 'inside_to_outside'.
            **kwargs: Algorithm-specific parameters:
                - `random`: `train_size` (float), `val_size` (float), `test_size` (float);
                  the three must sum to 1.0.
                - `inside_to_outside`: `stable_points` (list[tuple[float, float]]),
                  `distance1` (float), `distance2` (float).

        Returns:
            tuple[pandas.DataFrame, pandas.DataFrame, pandas.DataFrame]: (train, validation, test) DataFrames.

        Raises:
            ValueError: For an invalid algorithm or missing parameters.
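
        Example:
            Illustrative only; assumes `data` is the dictionary returned by
            `load_data` and contains a 'BE' property:

            >>> train, val, test = dataset.split_data(
            ...     data, 'BE', splitting_algorithm='random',
            ...     train_size=0.8, val_size=0.1, test_size=0.1
            ... )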

        """

        if property_name not in data_dict:
            raise ValueError(
                f"Property '{property_name}' not found in the provided data dictionary."
            )

        data = data_dict[property_name]

        if isinstance(data, pd.DataFrame):
            indexable_data = data.reset_index(drop=True)
            # Build point tuples over the domain columns only; they are consumed
            # by the 'inside_to_outside' splitter, which measures distances in
            # the domain (model-value columns would break the geometry).
            point_list = list(
                indexable_data[self.domain_keys].itertuples(index=False, name=None)
            )
        else:
            raise TypeError(
                "Data for the specified property must be a pandas DataFrame."
            )

        if splitting_algorithm == "random":
            required = ["train_size", "val_size", "test_size"]
            if not all(k in kwargs for k in required):
                raise ValueError(f"Missing required kwargs for 'random': {required}")

            train_size = kwargs["train_size"]
            val_size = kwargs["val_size"]
            test_size = kwargs["test_size"]

            if not np.isclose(train_size + val_size + test_size, 1.0):
                raise ValueError("train_size + val_size + test_size must equal 1.0")

            # Random split using positional indices.
            train_idx, temp_idx = train_test_split(
                indexable_data.index, train_size=train_size, random_state=1
            )
            val_rel = val_size / (val_size + test_size)
            val_idx, test_idx = train_test_split(
                temp_idx, test_size=1 - val_rel, random_state=1
            )

        elif splitting_algorithm == "inside_to_outside":
            required = ["stable_points", "distance1", "distance2"]
            if not all(k in kwargs for k in required):
                raise ValueError(
                    f"Missing required kwargs for 'inside_to_outside': {required}"
                )

            stable_points = kwargs["stable_points"]
            distance1 = kwargs["distance1"]
            distance2 = kwargs["distance2"]

            (
                train_idx,
                val_idx,
                test_idx,
            ) = self.separate_points_distance_allSets(
                point_list, stable_points, distance1, distance2
            )
        else:
            raise ValueError(
                "splitting_algorithm must be either 'random' or 'inside_to_outside'"
            )

        train_data = indexable_data.iloc[train_idx]
        val_data = indexable_data.iloc[val_idx]
        test_data = indexable_data.iloc[test_idx]

        return train_data, val_data, test_data

    def get_subset(self, property_name, filters=None, models_to_include=None):
        """
        Returns a filtered subset of data for a property.

        Args:
            property_name (str): Property to filter.
            filters (dict, optional): Domain filtering rules. Each key is a
                column name mapped to a scalar (equality), a list (membership),
                a `(low, high)` tuple (inclusive range), or a callable returning
                a boolean mask; the special key 'multi' takes a callable applied
                row-wise.
            models_to_include (list[str], optional): Models to include.

        Returns:
            pandas.DataFrame: Filtered DataFrame.

        Raises:
            ValueError: If property not found.
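
        Example:
            Illustrative filters; the column names 'Z' and 'N' are assumptions:

            >>> subset = dataset.get_subset(
            ...     'BE',
            ...     filters={'Z': (20, 28), 'N': [20, 28]},
            ...     models_to_include=['model1']
            ... )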

        """

        if property_name not in self.data:
            raise ValueError(f"Property '{property_name}' not found in dataset.")

        df = self.data[property_name].copy()

        # Apply row-level filters (domain-based).
        if filters:
            for column, condition in filters.items():
                if column == "multi" and callable(condition):
                    df = df[df.apply(condition, axis=1)]
                elif callable(condition):
                    df = df[condition(df[column])]
                elif isinstance(condition, tuple) and len(condition) == 2:
                    df = df[(df[column] >= condition[0]) & (df[column] <= condition[1])]
                elif isinstance(condition, list):
                    df = df[df[column].isin(condition)]
                else:
                    df = df[df[column] == condition]

        # Optionally restrict to a subset of models. Use the instance's
        # domain keys rather than hardcoded column names.
        if models_to_include is not None:
            domain_keys = [col for col in self.domain_keys if col in df.columns]
            allowed_cols = domain_keys + [
                m for m in models_to_include if m in df.columns
            ]
            df = df[allowed_cols]

        return df