_split_train_data error #255

Open
alexanderchang1 opened this issue Sep 26, 2024 · 3 comments

Comments

@alexanderchang1

Hi,

Is ContextualizedBayesianNetworks able to handle a Y target? The documentation for the fit function implies yes.

from contextualized.easy import ContextualizedBayesianNetworks

cbn = ContextualizedBayesianNetworks(
    encoder_type='mlp', num_archetypes=16,
    n_bootstraps=2, archetype_dag_loss_type="DAGMA", archetype_alpha=0.,
    sample_specific_dag_loss_type="DAGMA", sample_specific_alpha=1e-1,
    learning_rate=1e-3)
cbn.fit(C, X, Y, max_epochs=10, es_verbose=True)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[29], line 8
      1 from contextualized.easy import ContextualizedBayesianNetworks
      3 cbn = ContextualizedBayesianNetworks(
      4     encoder_type='mlp', num_archetypes=16,
      5     n_bootstraps=2, archetype_dag_loss_type="DAGMA", archetype_alpha=0.,
      6     sample_specific_dag_loss_type="DAGMA", sample_specific_alpha=1e-1,
      7     learning_rate=1e-3)
----> 8 cbn.fit(C, X, Y, max_epochs=10, es_verbose=True)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/easy/wrappers/SKLearnWrapper.py:516, in SKLearnWrapper.fit(self, *args, **kwargs)
    514 for bootstrap in range(self.n_bootstraps):
    515     model = self.base_constructor(**organized_kwargs["model"])
--> 516     train_data, val_data = self._split_train_data(
    517         *args, **organized_kwargs["data"]
    518     )
    519     train_dataloader, val_dataloader = self._build_dataloaders(
    520         model,
    521         train_data,
    522         val_data,
    523         **organized_kwargs["data"],
    524     )
    525     # Makes a new trainer for each bootstrap fit - bad practice, but necessary here.

TypeError: _split_train_data() takes 3 positional arguments but 4 were given
@cnellington
Owner

Thanks @alexanderchang1, I agree this is unclear right now. Contextualized BNs do not use Y, so they do not accept Y as an argument.

If you feel inspired and want to become a Contextualized contributor, maybe you could read our Contributing.md instructions and change the docstring to make the fix? If not, I will get to this soon.
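The traceback message "_split_train_data() takes 3 positional arguments but 4 were given" follows from the explanation above. Here is a minimal sketch using a hypothetical stand-in class (BNWrapperSketch is not the library's code). Based only on the traceback, it assumes that fit forwards every positional argument to a method shaped like _split_train_data(self, C, X, ...). Passing Y adds a fourth positional argument once self is counted. Presumably, then, calling cbn.fit(C, X, max_epochs=10, es_verbose=True) without Y avoids this particular error.

```python
# Hypothetical stand-in; NOT the contextualized library's implementation.
# It mirrors only the shape suggested by the traceback: a bound method that
# accepts self plus two positional data arrays (assumed here to be C and X).
class BNWrapperSketch:
    def _split_train_data(self, C, X, **kwargs):
        # Placeholder split: return the same data as train and validation.
        return (C, X), (C, X)

    def fit(self, *args, **kwargs):
        # Forward every positional argument, as SKLearnWrapper.fit appears to do.
        return self._split_train_data(*args, **kwargs)


def try_fit(*args):
    """Call fit and return the TypeError message, or None on success."""
    try:
        BNWrapperSketch().fit(*args)
    except TypeError as err:
        return str(err)
    return None


C, X, Y = [[0.0]], [[1.0]], [[2.0]]
print(try_fit(C, X, Y))  # self + C + X + Y = 4 positional args: raises TypeError
print(try_fit(C, X))     # self + C + X = 3 positional args: succeeds, prints None
```

The "3 positional arguments" in the message counts self, which is why a signature with only two data parameters rejects a third array.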

@alexanderchang1
Author

I'll take a look in the morning and submit a PR. Could I send you an email about a specific dataset and problem I'm trying to model? Maybe you could point me to the best Contextualized model for it.

@cnellington
Owner

Sure! Keeping this open while the ambiguity is still up, just so we keep track of it.
