Skip to content

Commit

Permalink
Fix bug in probability column rename logic
Browse files (browse the repository at this point in the history)
  • Loading branch information
OmegaLambda1998 committed Jan 31, 2024
1 parent a5d43d1 commit 3b2e058
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions pippin/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,25 +253,33 @@ def _run(self,):
# Loop over each version index and grab the relevant IDs and classifiers
for index in range(self.num_versions):
relevant_classifiers = [c for c in self.classifiers if c.index == index]
self.logger.debug(f"relevant_classifiers: {relevant_classifiers}")

prediction_files = [d.output["predictions_filename"] for d in relevant_classifiers]
lcfits = [d.get_fit_dependency() for d in relevant_classifiers]
self.logger.debug(f"lcfits: {lcfits}")

df = None

colnames = [self.classifier_merge[d.name] for d in relevant_classifiers]
self.logger.debug(f"colnames: {colnames}")
need_to_rename = len(colnames) != len(set(colnames))
rename_ind = []
if need_to_rename:
self.logger.info("Detected duplicate probability column names, will need to rename them")
for (i, n) in enumerate(colnames):
if len([j for j in range(len(colnames)) if colnames[j] == n]) > 1:
rename_ind.append(i)

for f, d, l in zip(prediction_files, relevant_classifiers, lcfits):
for i, (f, d, l) in enumerate(zip(prediction_files, relevant_classifiers, lcfits)):
self.logger.debug(f"l: {l}")
dataframe = self.load_prediction_file(f)
dataframe = dataframe.rename(columns={d.get_prob_column_name(): self.classifier_merge[d.name]})
dataframe = dataframe.rename(columns={dataframe.columns[0]: self.id})
dataframe[self.id] = dataframe[self.id].apply(str)
dataframe[self.id] = dataframe[self.id].str.strip()
if need_to_rename and l is not None:
lcname = l["name"]
if need_to_rename and (l is not None or l != []) and i in rename_ind:
lcname = ensure_list(l)[0]["name"]
self.logger.debug(f"Renaming column {self.classifier_merge[d.name]} to include LCFIT name {lcname}")
dataframe = dataframe.rename(columns={self.classifier_merge[d.name]: self.classifier_merge[d.name] + "_RENAMED_" + lcname})
self.logger.debug(f"Merging on column {self.id} for file {f}")
Expand Down

0 comments on commit 3b2e058

Please sign in to comment.