Skip to content

Commit

Permalink
Allow partial column mappings
Browse files Browse the repository at this point in the history
I.e. no longer require {"sentence": "text", "label": "label"}, you can just do {"sentence": "text"}
  • Loading branch information
tomaarsen committed Nov 29, 2023
1 parent d85d537 commit 62f7eea
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/setfit/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,15 @@ def _validate_column_mapping(self, dataset: "Dataset") -> None:
"Either make sure these columns are present, or specify which columns to use with column_mapping in Trainer."
)
if self.column_mapping is not None:
missing_columns = self._REQUIRED_COLUMNS.difference(self.column_mapping.values())
missing_columns = set(self._REQUIRED_COLUMNS)
# Remove columns that will be provided via the column mapping
missing_columns -= set(self.column_mapping.values())
# Remove columns that will be provided because they are in the dataset & not mapped away
missing_columns -= set(dataset.column_names) - set(self.column_mapping.keys())
if missing_columns:
raise ValueError(
f"The following columns are missing from the column mapping: {missing_columns}. Please provide a mapping for all required columns."
f"The following columns are missing from the column mapping: {missing_columns}. "
"Please provide a mapping for all required columns."
)
if not set(self.column_mapping.keys()).issubset(column_names):
raise ValueError(
Expand Down
13 changes: 13 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ def test_trainer_works_with_column_mapping(self):
metrics = trainer.evaluate()
self.assertEqual(metrics["accuracy"], 1.0)

def test_trainer_works_with_partial_column_mapping(self):
dataset = Dataset.from_dict({"text_new": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]})
trainer = Trainer(
model=self.model,
args=self.args,
train_dataset=dataset,
eval_dataset=dataset,
column_mapping={"text_new": "text"},
)
trainer.train()
metrics = trainer.evaluate()
self.assertEqual(metrics["accuracy"], 1.0)

def test_trainer_works_with_default_columns(self):
dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]})
trainer = Trainer(model=self.model, args=self.args, train_dataset=dataset, eval_dataset=dataset)
Expand Down

0 comments on commit 62f7eea

Please sign in to comment.