Allow me to access loss values for GAN-based synthesizers #1671

Closed
npatki opened this issue Nov 7, 2023 · 0 comments · Fixed by #1681
Labels: feature request


npatki commented Nov 7, 2023

Problem Description

The GAN-based synthesizers train over many epochs, optimizing a loss function as they go. Currently, all of our core ML models store a DataFrame with the per-epoch loss values in an attribute called loss_values.

See:

Now that the underlying ML models are storing loss values, they can be exposed through the SDV synthesizers using a simple function.

Expected behavior

Each of the GAN-based synthesizers should include a function get_loss_values() that returns a copy of the loss values. If the synthesizer has not yet been fit, then this function could raise an error.

# available for CTGAN, TVAE, CopulaGAN and PAR
synthesizer.fit(data)
loss_values_df = synthesizer.get_loss_values()

# if a synthesizer has not yet been fit
unfit_synthesizer.get_loss_values()
SynthesizerProcessingError: Loss values are not available yet. Please fit your synthesizer first.
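
The single-table behavior above could be sketched roughly as follows. This is a minimal illustration, not the actual SDV implementation: `SketchSynthesizer`, `_FakeModel`, and the `_fitted` flag are hypothetical names, the error class is defined locally as a stand-in, and a plain list of dicts stands in for the loss DataFrame so the sketch has no dependencies.

```python
import copy


class SynthesizerProcessingError(Exception):
    """Local stand-in for the error named in this issue."""


class _FakeModel:
    """Hypothetical placeholder for an underlying ML model that records losses."""

    def __init__(self):
        self.loss_values = None

    def fit(self, data, epochs=3):
        # A real model would compute these during training; the numbers are made up.
        self.loss_values = [
            {'epoch': epoch, 'loss': 1.0 / (epoch + 1)} for epoch in range(epochs)
        ]


class SketchSynthesizer:
    """Hypothetical wrapper that exposes the model's loss values."""

    def __init__(self):
        self._model = _FakeModel()
        self._fitted = False

    def fit(self, data):
        self._model.fit(data)
        self._fitted = True

    def get_loss_values(self):
        if not self._fitted:
            raise SynthesizerProcessingError(
                'Loss values are not available yet. Please fit your synthesizer first.'
            )
        # Return a copy so callers cannot modify the model's internal record.
        return copy.deepcopy(self._model.loss_values)
```

Returning a copy rather than the stored object means a user who edits the result does not change what the model recorded.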

For multi-table synthesizers (both HSA and Independent), it's possible to use a GAN-based model for individual tables. These synthesizers should then be able to return the loss values for a given table name.

# for HSA and Independent, it's possible to use GAN-based synthesizers for individual table modeling
hsa_synthesizer.set_table_parameters(
  table_name='users',
  table_synthesizer='CTGANSynthesizer')

# error if you haven't fit yet
users_loss = hsa_synthesizer.get_loss_values(table_name='users')
SynthesizerProcessingError: Loss values are not available yet. Please fit your synthesizer first.

# after fitting, it should be possible to access the loss values for those tables
hsa_synthesizer.fit(data)
users_loss = hsa_synthesizer.get_loss_values(table_name='users')

# if you try to get loss values from something that's not GAN-based, then there should be an error
txns_loss = hsa_synthesizer.get_loss_values(table_name='transactions')
SynthesizerInputError: Loss values are not available for table 'transactions' because the table does not use a GAN-based model.

Additional context

To implement this issue, we'll likely need to make a CTGAN release and a DeepEcho release first. (The PRs are merged, but the releases haven't been made yet, so the SDV won't be able to access the loss values.)
