Skip to content

Commit

Permalink
Fixing inference data loading and conversion for regression tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
SebieF committed Oct 14, 2024
1 parent 11ebfeb commit a9056b6
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions biotrainer/inference/inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,16 +193,20 @@ def _pad_target(self, sequence: Union[List, Any], length_to_pad: int):
return sequence

def _convert_target_dict(self, target_dict: Dict[str, str]):
if self.protocol in Protocol.per_residue_protocols():
max_prediction_length = len(max(target_dict.values(), key=len))
return {seq_id: torch.tensor(self._pad_target(self._convert_class_str2int(prediction),
length_to_pad=max_prediction_length),
device=self.device)
for seq_id, prediction in target_dict.items()}
else:
return {seq_id: torch.tensor(self._convert_class_str2int(prediction),
device=self.device)
for seq_id, prediction in target_dict.items()}
if self.protocol in Protocol.classification_protocols():
if self.protocol in Protocol.per_residue_protocols():
max_prediction_length = len(max(target_dict.values(), key=len))
return {seq_id: torch.tensor(self._pad_target(self._convert_class_str2int(prediction),
length_to_pad=max_prediction_length),
device=self.device)
for seq_id, prediction in target_dict.items()}
else:
return {seq_id: torch.tensor(self._convert_class_str2int(prediction),
device=self.device)
for seq_id, prediction in target_dict.items()}
return {seq_id: torch.tensor(prediction,
device=self.device)
for seq_id, prediction in target_dict.items()}

def _load_solver_and_dataloader(self, embeddings: Union[Iterable, Dict],
split_name, targets: Optional[List] = None):
Expand All @@ -214,14 +218,13 @@ def _load_solver_and_dataloader(self, embeddings: Union[Iterable, Dict],
else:
embeddings_dict = {str(idx): embedding for idx, embedding in enumerate(embeddings)}

converted_targets = None
if targets:
converted_targets = [self._convert_class_str2int(target) for target in targets]
if targets and self.protocol in Protocol.classification_protocols():
targets = [self._convert_class_str2int(target) for target in targets]

solver, loader = self.solvers_and_loaders_by_split[split_name]
dataset = get_dataset(self.protocol, samples=[
DatasetSample(seq_id, torch.tensor(np.array(embedding)),
torch.empty(1) if not targets else torch.tensor(np.array(converted_targets[idx])))
torch.empty(1) if not targets else torch.tensor(np.array(targets[idx])))
for idx, (seq_id, embedding) in enumerate(embeddings_dict.items())
])
dataloader = loader(dataset)
Expand Down

0 comments on commit a9056b6

Please sign in to comment.