Validation loss #1864

Merged: 78 commits merged into kohya-ss:val-loss on Jan 26, 2025

Conversation

@rockerBOO (Contributor) commented Jan 3, 2025

Related: #1856, #1858, #1165, #914

Original implementation by @rockerBOO
Timestep validation implementation by @gesen2egee
Updated implementation for sd3/flux by @hinablue

I went through and tried to merge the different PRs together. I may have broken some things in the process.

One thing to note: process_batch was written to limit code duplication between validation and training, to keep the two consistent. I implemented the timestep processing so it works for both. I noticed the other PRs applied only debiased_estimation, but I don't know why it was done that way.

I have not yet updated train_db.py to fit my goal of a unified process_batch, as I don't have a good way to test it. I will try to get it into an acceptable state so we can refine it.

I'm posting this a little early so others can review it and give feedback. I am still working through some issues with the code, so let me know before you dive in to fix anything. I'm open to commits to this PR; they can be pushed to this branch on my fork.

Testing

  • Test training code is actually training
  • Test validation per epoch (validation runs every epoch)
  • Test validate every n steps (after N steps, a validation run is executed)
  • Test validate every n epochs (after N epochs, a validation run is executed)
  • Test max validation steps
  • Test validation split (the split should be applied accordingly; 0.2 should produce a validation dataset that is 20% of the primary dataset)
  • Test validation split from train_network.py arguments (--validation_split) as well as dataset_config.toml (validation_split=0.1)
  • Test validation seed (Seed is used for dataset shuffling only right now)
  • Test image latent caching (validation and training datasets)
  • Test tokenizing strategy (SD, SDXL, SD3, Flux)
  • Test text encoding strategy (SD, SDXL, SD3, Flux)
  • Test --network_train_text_encoder_only
  • Test --network_train_unet_only
  • Test training some text encoders (I think this is a feature?)
  • Test on SD1.5, SDXL, SD3, Flux LoRAs

Parameters

The validation dataset works with DreamBooth datasets (text/image pairs); the dataset is split into two parts, train_dataset and validation_dataset, according to the split ratio.

  • --validation_seed Seed for shuffling the validation dataset; the training --seed is used otherwise
  • --validation_split Proportion of images split out of the training dataset for validation
  • --validate_every_n_steps Run validation on the validation dataset every N steps; by default, validation only runs every epoch when a validation dataset is available
  • --validate_every_n_epochs Run validation every N epochs; by default, validation runs every epoch when a validation dataset is available
  • --max_validation_steps Maximum number of validation dataset items processed; by default, validation runs over the entire validation dataset

validation_seed and validation_split can also be set inside dataset_config.toml.
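As a rough sketch of what that could look like (the values are illustrative, and the exact placement of these keys in the dataset config may differ):

    # dataset_config.toml (illustrative sketch; key placement may differ)
    [general]
    validation_seed = 42       # seed used to shuffle the validation dataset
    validation_split = 0.1     # hold out 10% of the dataset for validation

    [[datasets]]
    resolution = 512

      [[datasets.subsets]]
      image_dir = "/path/to/images"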

I'm open to feedback on this approach, and on anything in the code that needs fixing for accuracy.

@kohya-ss (Owner) commented Jan 8, 2025

Not at all - your code is excellent and very well structured. Thank you for taking the time to submit this PR.

In config_util.py, the following will return two DatasetGroups (or one DatasetGroup and None):

    return (
        DatasetGroup(datasets),
        DatasetGroup(val_datasets) if val_datasets else None
    )

However flux_train.py etc. seem to expect one DatasetGroup.

        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

I haven't run the code yet, so I apologize if it actually works.

@rockerBOO (Contributor, Author) commented:
I have updated all calls to config_util.generate_dataset_group_by_blueprint to handle extracting from the tuple, and added a return type annotation to that function to help with type checking. Thanks for pointing it out.
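Based on the snippet quoted above, a minimal sketch of the updated call sites (variable names follow the existing scripts):

    # generate_dataset_group_by_blueprint now returns a tuple:
    # (train DatasetGroup, validation DatasetGroup or None)
    train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(
        blueprint.dataset_group
    )
    if val_dataset_group is None:
        logger.info("no validation dataset configured; skipping validation")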

@kohya-ss (Owner) left a comment

Thank you for the update. I'm sorry the code is complicated.

I'm currently checking the training script to see if it still works as before. There was one problem (in two places), so I would appreciate it if you could check it.

Review comment on library/train_util.py (resolved)
Commits pushed to this branch:

  • Divergence is the difference between training and validation, giving a clear value in the logs for the gap between the two
  • Add metadata recording for validation arguments
  • Add comments about the validation split, for clarity of intention
@rockerBOO (Contributor, Author) commented:

[Screenshot: Weights & Biases workspace showing training and validation loss charts]

Added a divergence value for step and epoch, indicating the difference between training and validation loss. This makes the gap easier to see without relying on overlapping charts. Maybe a different term would be better, since "divergence" might suggest the losses are moving apart, away from convergence. It might also be better to invert the sign so the value matches the direction of the loss values.
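For clarity, a minimal sketch of what the logged value amounts to (names are illustrative, assuming LossRecorder-style moving averages and an accelerate logger):

    # divergence = validation loss minus training loss for the same window
    divergence = val_loss_recorder.moving_average - loss_recorder.moving_average
    logs = {
        "loss/average": loss_recorder.moving_average,
        "val_loss/average": val_loss_recorder.moving_average,
        "divergence": divergence,
    }
    accelerator.log(logs, step=global_step)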

Fixed a bunch of things with regularization image datasets and repeats, and fixed some issues with validate-every-n-steps (which is important when using repeats and regularization images).

@rockerBOO (Contributor, Author) commented:
Bug: if text encoder outputs are cached and validation is enabled, then when validation runs process_batch, the line

    input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]

errors because batch["input_ids_list"] is None.
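One possible guard, sketched here as an assumption rather than the actual fix, is to skip the tokenized inputs when they are absent and rely on the cached outputs:

    # batch["input_ids_list"] is None when text encoder outputs are cached,
    # so only move the ids to the device when they are actually present
    if batch.get("input_ids_list") is not None:
        input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
    else:
        input_ids = None  # rely on the cached text encoder outputs instead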

@kohya-ss (Owner) left a comment

Sorry for the delay. I've started testing and have added review comments for what I've noticed so far. Please check them when you have time.

Review comments on train_network.py (5, resolved)
@rockerBOO (Contributor, Author) commented:

  • Added a zero check for LossRecorder
  • Added train_text_encoder/train_unet values for validation batches
  • Added val_dataset_group to assert_extra_args
  • Added numpy<=2.0 to requirements for compatibility (some libraries were not constraining it, which caused breakage)

@stepfunction83 commented Jan 26, 2025

Shit, I should've looked here before going nuts on my own over the past two days: #1898

Looks like I'm only a little over a year late to the party.

I've been working off of this since it was posted: https://github.com/spacepxl/demystifying-sd-finetuning

In my implementation, I refactored the training loop by extracting the loss function out of it. Within the loop, I run the test/validation loss calculations with gradient accumulation turned off, then run the loss calculation for the actual training sample and continue. Everything happens within a single global step. It's pretty straightforward and doesn't really take advantage of some of PyTorch's nicer features, but it's simple and gives a clean result.
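Roughly, the structure described above looks like this (a sketch with illustrative names, not the actual sd-scripts API):

    import torch

    def run_global_step(model, optimizer, accelerator, compute_loss, train_batch, val_batch):
        # validation pass first: no gradients, no accumulation
        model.eval()
        with torch.no_grad():
            val_loss = compute_loss(model, val_batch)
        model.train()

        # then the normal training pass for this step
        with accelerator.accumulate(model):
            loss = compute_loss(model, train_batch)
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
        return loss.detach(), val_loss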

It wasn't really clear from the code, but are you running the same noise/timesteps through the loss calculation each time for the validation set?

My understanding is that, ideally, the same samples are used for the test and validation sets each time, and each sample gets the same noise and timesteps for every loss calculation; that way the loss is computed consistently, with the model being the only variable that changes. To accomplish this, on the first iteration I create and record a state variable from the loss calculation for each test/validation sample, capturing the noise/timesteps/sigmas, and then replay those values for future loss calculations.
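A minimal sketch of that replay idea (all names are illustrative):

    import torch

    _val_state = {}  # sample key -> (noise, timesteps), recorded on first use

    def fixed_noise_and_timesteps(sample_key, latents, num_train_timesteps):
        # record noise/timesteps for each validation sample once, then replay
        # them so the model is the only variable that changes between runs
        if sample_key not in _val_state:
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, num_train_timesteps, (latents.shape[0],), device=latents.device
            )
            _val_state[sample_key] = (noise, timesteps)
        return _val_state[sample_key]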

@rockerBOO (Contributor, Author) commented:

There was code in previous versions to allow a distribution of timesteps to be set (instead of random), which I think aligns with what you're suggesting. It could be a single timestep, or we could cache a random timestep and reuse it. This was removed so we could merge this PR; we can add it back in a new PR, since it touches several different systems.

I think having more static or stable options might be a good idea for limited datasets, because the variability of a small dataset might skew the results. A validation dataset may have only a few items, say 2 to 10, so picking a couple of poor timesteps could keep the loss calculation from highlighting the right things. The current idea is to take a distribution of timesteps, like 50, 250, 500, 600, 900, and average over them (see the sketch below). This could also be applied to regular training to smooth out the variance from random timesteps, though at significantly increased training time.
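A sketch of that averaging idea (names and timestep values are illustrative; this is not what the PR currently does):

    import torch
    import torch.nn.functional as F

    FIXED_TIMESTEPS = [50, 250, 500, 600, 900]

    def averaged_validation_loss(unet, scheduler, latents, noise, cond):
        # evaluate the same sample at a fixed spread of timesteps and average,
        # instead of sampling a single random timestep per validation item
        losses = []
        for t in FIXED_TIMESTEPS:
            timesteps = torch.full(
                (latents.shape[0],), t, device=latents.device, dtype=torch.long
            )
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)
            pred = unet(noisy_latents, timesteps, cond).sample  # diffusers-style UNet
            losses.append(F.mse_loss(pred.float(), noise.float()))
        return torch.stack(losses).mean()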

Stabilizing the noise of the initial latents may be worth experimenting with, but storing those latents would make training take more memory, scaling with the validation set size.

We also need to consider how limited the datasets we're working with are, the training time, and what validation is meant to accomplish during training. As datasets get larger, these variances matter less (more examples), and smoothing the loss for the charts might be enough to highlight the trend.

For loss we also average over the other steps in an epoch, which smooths out the "current" loss, and we do the same for validation. But given the limited datasets, I think taking more timestep samples can help smooth it out with less data.

@stepfunction83 commented Jan 26, 2025 via email

@kohya-ss (Owner) commented:
This PR was about to be merged. I wish I had worked on it a little sooner...

It shouldn't be too difficult to use the same timesteps for each validation. We may consider addressing this in a separate PR in the future.

@kohya-ss kohya-ss changed the base branch from sd3 to val-loss January 26, 2025 12:06
@kohya-ss kohya-ss merged commit b833d47 into kohya-ss:val-loss Jan 26, 2025
2 checks passed
@kohya-ss (Owner) commented:

sdxl_train_textual_inversion.py and sdxl_train_control_net_lllite.py raise an error, so I merged this into a new branch. I will merge it into sd3 after I fix them. I will finish it today.

Thank you again for this great work!

@kohya-ss kohya-ss mentioned this pull request Jan 26, 2025
@stepfunction83 commented:
Is this only implemented for LoRA training?

@kohya-ss (Owner) commented:

> Is this only implemented for LoRA training?

Currently only available for LoRA training.
