
Add curriculum learning callback #1256

Merged: 1 commit merged into main from cl_callback on Jun 25, 2024
Conversation

b-chu (Contributor) commented Jun 5, 2024:

Curriculum learning callback

Requirements

  • Requires StreamingDataset
  • Each datamix must have a duration and train_loader
  • Duration must be positive and defined in terms of epochs or tokens
  • Duration units must match max_duration
  • The total duration of the schedule must equal max_duration (see the validation sketch after this list)
  • The part of the schedule that has already been trained on cannot be changed in future resumption runs
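The duration constraints above can be checked mechanically. Below is a minimal, hypothetical sketch of such a validation; parse_duration and validate_schedule are invented names for illustration and are not part of the actual callback.

import re

def parse_duration(duration: str) -> tuple[int, str]:
    # Split a duration string such as '5000000tok' or '1ep' into (value, unit).
    match = re.fullmatch(r'(\d+)(tok|ep)', duration)
    if match is None:
        raise ValueError(f'Invalid duration: {duration!r}')
    value, unit = int(match.group(1)), match.group(2)
    if value <= 0:
        raise ValueError('Duration must be positive')
    return value, unit

def validate_schedule(schedule: list[dict], max_duration: str) -> None:
    # Each datamix needs a duration and a train_loader, all duration units
    # must match max_duration, and the durations must sum to max_duration.
    max_value, max_unit = parse_duration(max_duration)
    total = 0
    for datamix in schedule:
        if 'duration' not in datamix or 'train_loader' not in datamix:
            raise ValueError('Each datamix needs a duration and a train_loader')
        value, unit = parse_duration(datamix['duration'])
        if unit != max_unit:
            raise ValueError(f'Unit {unit!r} does not match max_duration unit {max_unit!r}')
        total += value
    if total != max_value:
        raise ValueError('Schedule durations must sum to max_duration')

For example, a schedule of three 5000000tok datamixes would only pass this check if max_duration were 15000000tok.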

Features

  • Defines schedules by specifying datamixes in terms of training duration
  • Supports epochs and tokens for each iteration
  • Enables single-run curriculum learning
  • Swaps the entire dataloader at the start of each iteration
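To illustrate the last feature (swapping the entire dataloader at the start of each iteration), here is a rough, framework-agnostic sketch. It is not the actual implementation; the real callback builds on Composer's callback and event system, and the state object and build_loader callables below are simplified, hypothetical stand-ins.

class CurriculumLearningSketch:
    """Illustrative only: select and install a datamix's dataloader per iteration."""

    def __init__(self, schedule):
        # schedule: list of dicts with 'duration_tokens' (int) and
        # 'build_loader' (a zero-argument callable returning a dataloader).
        self.schedule = schedule
        self.index = 0

    def iteration_start(self, state):
        # Swap in the dataloader for the current datamix and tell the trainer
        # how long this iteration should last.
        entry = self.schedule[self.index]
        state.train_dataloader = entry['build_loader']()
        state.iteration_length_tokens = entry['duration_tokens']

    def iteration_end(self, state):
        # Move on to the next datamix. The index would be part of the callback's
        # checkpoint state, so a resumed run picks up the correct datamix and
        # the already-trained portion of the schedule stays fixed.
        self.index += 1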

Other

  • Refactor build_tokenizer to avoid circular dependencies
  • Refactor process_init_dist to avoid circular dependencies
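The refactors above follow the usual way of breaking a circular import: move the offending import from module scope into the function that needs it. A generic sketch of the pattern, using torch.distributed as a stand-in rather than the actual llm-foundry code:

def initialize_dist_lazy(backend: str = 'nccl') -> None:
    # Importing inside the function means the dependency is resolved at call
    # time, after both modules have finished loading, so no import cycle forms.
    import torch.distributed as dist

    if dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend=backend)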

Manual tests

  • Matches old callback behavior (screenshot attached)
  • Resumes correctly in the middle of the schedule (screenshot attached)
  • Resumes correctly when a new datamix is added to the schedule (screenshot attached)
  • Resumes correctly when the callback is added after the initial training run (screenshot attached)

API

Old API (a separate run had to be started for each datamix):

train_loader:
  <some params>
callbacks:
  curriculum_learning:
    dataset_index: 0

Start a new run

train_loader:
  <some params>
callbacks:
  curriculum_learning:
    dataset_index: 1

Start a new run

train_loader:
  <some params>
callbacks:
  curriculum_learning:
    dataset_index: 2

New API:

train_loader:
  <dataloader parameters>
callbacks:
  curriculum_learning:
  - duration: <number>tok
    train_loader:  # matches top level train_loader
      <dataloader parameters>
  - duration: <number>tok
    train_loader:
      <dataloader parameters>
  - duration: <number>tok
    train_loader:
      <dataloader parameters>
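A side benefit of expressing the schedule as consecutive durations is that resumption only needs the current timestamp to find the active datamix. A hypothetical sketch of that lookup, with names invented for illustration:

def active_datamix_index(durations_tok: list[int], elapsed_tok: int) -> int:
    # Walk the schedule, consuming each datamix's token budget, until the
    # remaining elapsed tokens fall inside the current datamix.
    remaining = elapsed_tok
    for index, duration in enumerate(durations_tok):
        if remaining < duration:
            return index
        remaining -= duration
    raise ValueError('Elapsed tokens exceed the total scheduled duration')

# Example: with three 5,000,000-token datamixes, a run resumed at
# 7,500,000 tokens continues with the second datamix (index 1).
assert active_datamix_index([5_000_000] * 3, 7_500_000) == 1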

b-chu changed the title from "Refactor curriculum learning callback" to "Add curriculum learning callback scheduling" on Jun 6, 2024
b-chu changed the title from "Add curriculum learning callback scheduling" to "Add curriculum learning callback" on Jun 6, 2024
b-chu force-pushed the cl_callback branch 2 times, most recently from 29b3b64 to 2765769 on June 11, 2024 17:51
b-chu marked this pull request as ready for review on June 11, 2024 18:35
b-chu requested a review from a team as a code owner on June 11, 2024 18:35
b-chu marked this pull request as draft on June 11, 2024 18:36
snarayan21 (Contributor) commented:
@b-chu, about the new API, a couple of questions:

train_loader:
  <some params>
callbacks:
  curriculum_learning:
    duration: 5000000tok
    schedule:
    - duration: 5000000tok
      train_loader:
        <some params>
    - duration: 5000000tok
      train_loader:
        <some params>
  1. So I still have to specify train_loader as a top-level entry?
  2. The first duration specified is for the top-level train_loader?

snarayan21 (Contributor) commented Jun 11, 2024:

Also, I'm worried about the loss curves in the plots you shared; they don't look fully deterministic to me. What model size and batch size were you running at, and with which datasets? Longer training runs with a bigger model and a small batch size, without shuffling, would be helpful so that we can determine whether the loss curves are actually deterministic. Just looking at the first few steps, most training runs will look pretty similar regardless of the data ordering.

mvpatel2000 (Collaborator) left a comment:

This needs a composer release first, right?

b-chu force-pushed the cl_callback branch 19 times, most recently from bce0270 to 596b761 on June 13, 2024 16:45
mvpatel2000 (Collaborator) left a comment:
First pass, mostly LGTM but some minor comments.

Review comments on llmfoundry/callbacks/curriculum_learning_callback.py (2 threads, outdated and resolved)
mvpatel2000 (Collaborator) left a comment:
Second pass, LGTM besides a few minor comments. Requiring review from @milocress.

Review comments on llmfoundry/callbacks/curriculum_learning_callback.py (4 threads, outdated and resolved)
milocress (Contributor) left a comment:
LGTM, just a couple of questions.

b-chu force-pushed the cl_callback branch 4 times, most recently from 08d9b7f to 42406c0 on June 24, 2024 17:15
b-chu requested a review from dakinggg on June 24, 2024 17:19
b-chu force-pushed the cl_callback branch 5 times, most recently from 32579a0 to 6cc5c10 on June 24, 2024 18:14
dakinggg (Collaborator) left a comment:
Will approve after comments are addressed; overall LGTM.

b-chu force-pushed the cl_callback branch 3 times, most recently from ec28600 to a5fa8a5 on June 24, 2024 23:06
b-chu requested a review from dakinggg on June 24, 2024 23:07
b-chu enabled auto-merge (squash) on June 24, 2024 23:23
b-chu force-pushed the cl_callback branch 2 times, most recently from ddf2876 to 6f95810 on June 25, 2024 14:39
b-chu merged commit ef14849 into main on Jun 25, 2024
11 checks passed
dakinggg deleted the cl_callback branch on August 6, 2024 18:41