
Revisit NeuroTree implementation #484

Open

pat-alt opened this issue Sep 30, 2024 · 0 comments

pat-alt (Member) commented Sep 30, 2024

Glad to hear @jeremiedb :)

Thanks for clarifying!

> That being said, I'd be curious about how the Models.logits and Models.probs are used in the package. I.e., is it for single-shot inference of the models, or is it to be used throughout iterative loops while training another model?

For gradient-based counterfactual generators, Models.logits is called iteratively as we take gradient steps in the feature space, but not to train the NeuroTreeModels or any derivative model.
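
To make that concrete, here is a minimal sketch of the pattern, with a plain Flux model standing in for a NeuroTree model; `f`, `η`, the loss, and the loop length are all placeholders, not the package's actual code:

```julia
using Flux  # provides Zygote-based `gradient`

# Stand-in for the fitted classifier; in practice this would be the
# wrapped NeuroTree model (placeholder, not the package's actual setup).
f = Chain(Dense(2 => 8, relu), Dense(8 => 2))

x = Float32[0.5, -1.0]    # candidate counterfactual: features, not parameters
y = Flux.onehot(2, 1:2)   # desired target class
η = 0.1f0                 # step size (hypothetical)

for _ in 1:100
    # Differentiate the loss w.r.t. the features x; model parameters stay fixed.
    g, = Flux.gradient(x̂ -> Flux.Losses.logitcrossentropy(f(x̂), y), x)
    x .-= η .* g          # gradient step in feature space
end
```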

> In short, if it's expected that X can be a large dataset, it may pose a challenge if the NeuroTree model is very large (just like any other large Flux model). And if it's for quick inference on batch-sized data, then workflow-wise it may be advisable to ensure that the full feature matrix is first converted into Float32 before going into iterations.
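
In other words, the suggested workflow is roughly the following (a purely illustrative sketch; `X`, its shape, and the loop body are placeholders):

```julia
X = rand(Float64, 10, 1_000)  # hypothetical Float64 feature matrix
X32 = Float32.(X)             # convert the element type once, up front ...

for step in 1:100
    # ... then iterate on X32; no repeated Float64 → Float32 conversion per call
end
```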

That's good to know, thank you. I suppose that for large datasets this may cause performance bottlenecks. We do have a method to wrap Tables, but it still transforms the Table into a Matrix for downstream tasks.
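
The wrapping mentioned above amounts to roughly this (a sketch; the DataFrame, column names, and matrix orientation are assumptions, not the package's actual code):

```julia
using DataFrames, Tables

tbl = DataFrame(x1 = rand(100), x2 = rand(100))  # hypothetical input table

# Materialize the table as a plain matrix for downstream gradient computations,
# assuming the model expects a features × observations layout.
X = permutedims(Tables.matrix(tbl))
```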

The main reason we've designed things this way is that gradient-based counterfactual search involves gradient computations with respect to features (as opposed to model parameters) and matrices are easier to handle.

So to change this we'll have to revisit some core design choices. For now I'm happy with the conversion here, although I might simplify it further since we can generally assume Float32 as the element type.
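
That simplification could be as small as coercing everything to Float32 in one place (a hypothetical helper, not the package's actual code):

```julia
# Hypothetical helper: always hand the model a Matrix{Float32},
# instead of branching on the input's element type.
preprocess(X::AbstractMatrix) = convert(Matrix{Float32}, X)
```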

Thanks for your advice!

Originally posted by @pat-alt in #479 (comment)
