
[ENH] rework TimeSeriesDataSet using LightningDataModule - experimental #1766

Open
fkiraly opened this issue Feb 10, 2025 · 9 comments · May be fixed by #1770
Labels: API design (API design & software architecture), enhancement (New feature or request)

Comments

fkiraly (Collaborator) commented Feb 10, 2025

Umbrella issue for pytorch-forecasting 2.0 design: #1736

In sktime/enhancement-proposals#39, @phoeenniixx suggested a LightningDataModule-based design for the end state of dsipts and pytorch-forecasting 2.0.

As a work item, this implies a rework of TimeSeriesDataSet using LightningDataModule, covering layers D1 and D2 (referencing the EP), with the D1 layer based on a refined design following #1757 (but simpler).

@phoeenniixx agreed to give this a go as part of an experimental PR.

phoeenniixx commented

> with the D1 layer based on a refined design following #1757 (but simpler).

Should we use the "dict" implementation for metadata, or keep it separate as it is now?

fkiraly (Collaborator, Author) commented Feb 11, 2025

For D1, I think your dict idea is better than the list-based format I suggested, because it allows us to keep the number of metadata fields small. That is, I would go for your original suggestion, but with my subsequent modification of having even fewer dicts and metadata fields.

To start with, though, we can assume everything is float and future-known; I estimate that otherwise there would be a lot of boring boilerplate in handling the different column types etc.

The reason is that I think we should get to an end-to-end design quickly and see how it looks and how/whether it works, because we might modify or even abandon it; the work on the boilerplate would then be lost.

Whereas, if this proves to be the way to go, it is still easy to add that on top.
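
For illustration, a minimal sketch of what such a D1 metadata dict could look like under the simplifying assumptions above; all field names here are hypothetical, not a settled design:

# hypothetical D1 metadata dict - field names are illustrative only
metadata = {
    "cols": {
        "y": ["target"],            # target column(s)
        "x": ["feat_1", "feat_2"],  # feature column(s)
        "st": [],                   # static column(s)
    },
    # per-column dtype and known/unknown flags could be added later;
    # for now, everything is assumed float and future-known
}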

phoeenniixx commented

A few problems I found with this approach:

  • the torch DataLoader expects a Dataset class, not a LightningDataModule
  • a LightningDataModule is meant to do just the data "handling"; to pass the data to a DataLoader, we need to wrap it in a Dataset

The to_dataloader function of TimeSeriesDataSet passes self to the DataLoader; if we use just a LightningDataModule, that is not possible.

Proposed solutions:

  1. We can add one more layer after the LightningDataModule, i.e., a Dataset layer:
from torch.utils.data import Dataset

class TimeSeriesDataset(Dataset):
    def __init__(self, datamodule: "DecoderEncoderDataModule"):
        self.datamodule = datamodule
        self.tsd = datamodule.tsd  # preprocessed TimeSeries (D1) data

    def __len__(self):
        return len(self.tsd)

    def __getitem__(self, idx):
        # fetch raw sample from the TimeSeries (D1) layer
        batch = self.tsd[idx]

        # apply all transformations defined inside the datamodule
        transformed_batch = self.datamodule.transformation(batch)

        return transformed_batch
  2. We can remove the LightningDataModule layer completely and use the other proposed approach of passing the metadata in the __init__ of the D2 layer:
from torch.utils.data import Dataset

class DecoderEncoderData(Dataset):
    def __init__(self, tsd: "PandasTSDataSet", **params):
        self.tsd = tsd  # store the D1 dataset reference

        # access metadata from the D1 dataset
        self.metadata = tsd.get_metadata()

    def __getitem__(self, idx):
        sample = self.tsd[idx]
        # other required additions to ``sample``
        return sample

fkiraly (Collaborator, Author) commented Feb 11, 2025

Could you double check how the LightningDataModule interacts with a LightningModule in the vanilla vignette? Is the additional layer between them really required? That would surprise me.

I feel the LightningDataModule is needed to satisfy the requirement of dsipts to be able to specify train and validation splits.

phoeenniixx commented Feb 11, 2025

In this tutorial, the dataloaders in the datamodule receive only the Dataset class, not the datamodule itself.

A LightningDataModule only defines how to train, validate, and test the model. It does not handle dataset indexing, transformations, or batching; that is the job of the DataLoader.

The DataLoader requires a Dataset object to properly iterate over the data, since a Dataset has __getitem__ and __len__ implementations.
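
For reference, the minimal protocol the DataLoader relies on is just these two methods; a generic sketch, not pytorch-forecasting code:

from torch.utils.data import Dataset

class MinimalDataset(Dataset):
    """The smallest thing a torch DataLoader can iterate over."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)  # number of samples

    def __getitem__(self, idx):
        return self.samples[idx]  # one sample per index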

A clearer, end-to-end example:

import os

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        # download only
        MNIST(os.getcwd(), train=True, download=True)
        MNIST(os.getcwd(), train=False, download=True)

    def setup(self, stage=None):
        # transform
        transform = transforms.Compose([transforms.ToTensor()])
        mnist_full = MNIST(os.getcwd(), train=True, download=False, transform=transform)
        mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transform)

        # train/val split
        mnist_train, mnist_val = random_split(mnist_full, [55000, 5000])

        # assign to use in dataloaders
        self.train_dataset = mnist_train
        self.val_dataset = mnist_val
        self.test_dataset = mnist_test

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

Source: https://pytorch-lightning.readthedocs.io/en/0.10.0/introduction_guide.html#the-engineering

Here you can see that they pass the datasets, not the module itself.

DataModules are useful for training and testing, as you then just pass the model and the module and everything is handled behind the curtain.

dm = MNISTDataModule()
model = LitMNIST()
trainer = Trainer(tpu_cores=8)
trainer.fit(model, dm)

Here, LitMNIST is a LightningModule.

[Image: difference between using and not using Lightning DataModules (Source)]

fkiraly (Collaborator, Author) commented Feb 12, 2025

OK, that is what I thought. What you just posted means that there is a direct layer interaction between the LightningDataModule and the LightningModule (via the Trainer); the DataLoader creation etc. happens under the hood.

So there is no need to introduce another DataLoader or Dataset layer.

Can you hence explain in which respect these would be issues?
(quoting you from above)

> the torch DataLoader expects a Dataset class, not a LightningDataModule
> a LightningDataModule is meant to do just the data "handling"; to pass the data to a DataLoader, we need to wrap it in a Dataset

I think this wrapping needs to happen, but it happens "under the hood" in the LightningDataModule.
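
Concretely, with the MNISTDataModule dm from the example above, what happens under the hood is roughly the following; a simplified sketch of Lightning's behaviour, not its actual source:

# simplified sketch of what trainer.fit(model, dm) does with a datamodule
dm.prepare_data()                     # download / prepare once
dm.setup(stage="fit")                 # splits and transforms
train_loader = dm.train_dataloader()  # DataLoader built inside the datamodule
val_loader = dm.val_dataloader()
# ...the training loop then iterates over train_loader and val_loader...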

phoeenniixx commented Feb 12, 2025

What I meant was: we discussed that we could implement the D2 layer as a LightningDataModule, but that would mean we do all the normalization and related processing in this data module. How, then, can we pass this "processed" data to the dataloader in the data module? Since we cannot pass a dict of tensors, we need to pass a Dataset class to the dataloader, right?

So if we do the preprocessing in the data module, how do we use that processed data? We would need to turn the processed data back into a Dataset class, right?

So I meant we could do the processing in a Dataset class and then pass that to the data module?

fkiraly (Collaborator, Author) commented Feb 14, 2025

Hm, I see. So you are saying we will need an "internal" Dataset in the LightningDataModule, to pass to the internal DataLoader? Which would cause us to have two Datasets, although one would be entirely private to the LightningDataModule?

I think that would be fine, because that one is a private module.

phoeenniixx commented

Thanks! That's a good idea. I will then create a private Dataset class in the data module.
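
A minimal sketch of this agreed design, with the Dataset kept private to the data module; all names here (_TimeSeriesDataset, EncoderDecoderDataModule, _transform) are hypothetical, not the final API:

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset


class _TimeSeriesDataset(Dataset):
    """Private Dataset, visible only to the data module."""

    def __init__(self, tsd, transform):
        self.tsd = tsd              # D1 layer dataset
        self.transform = transform  # preprocessing from the data module

    def __len__(self):
        return len(self.tsd)

    def __getitem__(self, idx):
        return self.transform(self.tsd[idx])


class EncoderDecoderDataModule(LightningDataModule):
    """D2 layer: splitting and preprocessing live here."""

    def __init__(self, tsd, batch_size=32):
        super().__init__()
        self.tsd = tsd
        self.batch_size = batch_size

    def setup(self, stage=None):
        # train/val splitting and fitting of normalizers would go here
        self.train_dataset = _TimeSeriesDataset(self.tsd, self._transform)

    def train_dataloader(self):
        # the Dataset wrapping happens here, "under the hood"
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def _transform(self, sample):
        # placeholder for normalization etc.
        return sample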
