diff --git a/phenaki_pytorch/phenaki_trainer.py b/phenaki_pytorch/phenaki_trainer.py index 7b8ecdb..6af6082 100644 --- a/phenaki_pytorch/phenaki_trainer.py +++ b/phenaki_pytorch/phenaki_trainer.py @@ -270,6 +270,11 @@ def __init__( dl = self.accelerator.prepare(dl) self.dl = cycle(dl) + if exists(dataset_fields): + assert not has_duplicates(dataset_fields), 'dataset fields must not have duplicate field names' + valid_dataset_fields = set(DATASET_FIELD_TYPE_CONFIG.keys()) + assert len(set(dataset_fields) - valid_dataset_fields) == 0, f'dataset fields must be one of {valid_dataset_fields}' + self.dataset_fields = dataset_fields # optimizer @@ -290,6 +295,7 @@ def __init__( def data_tuple_to_kwargs(self, data): if not exists(self.dataset_fields): self.dataset_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG) + assert not has_duplicates(self.dataset_fields), 'dataset fields must not have duplicate field names' return dict(zip(self.dataset_fields, data)) diff --git a/setup.py b/setup.py index 36b7d85..488a988 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'phenaki-pytorch', packages = find_packages(exclude=[]), - version = '0.0.52', + version = '0.0.53', license='MIT', description = 'Phenaki - Pytorch', author = 'Phil Wang',