Skip to content

Commit

Permalink
if dataset fields is being passed in, validate there are no duplicate…
Browse files Browse the repository at this point in the history
…s and that all the field names are valid
  • Loading branch information
lucidrains committed Dec 6, 2022
1 parent adc2486 commit 94b664d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
6 changes: 6 additions & 0 deletions phenaki_pytorch/phenaki_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,11 @@ def __init__(
dl = self.accelerator.prepare(dl)
self.dl = cycle(dl)

if exists(dataset_fields):
assert not has_duplicates(dataset_fields), 'dataset fields must not have duplicate field names'
valid_dataset_fields = set(DATASET_FIELD_TYPE_CONFIG.keys())
assert len(set(dataset_fields) - valid_dataset_fields) == 0, f'dataset fields must be one of {valid_dataset_fields}'

self.dataset_fields = dataset_fields

# optimizer
Expand All @@ -290,6 +295,7 @@ def __init__(
def data_tuple_to_kwargs(self, data):
if not exists(self.dataset_fields):
self.dataset_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG)
assert not has_duplicates(self.dataset_fields), 'dataset fields must not have duplicate field names'

return dict(zip(self.dataset_fields, data))

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'phenaki-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.52',
version = '0.0.53',
license='MIT',
description = 'Phenaki - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 94b664d

Please sign in to comment.