
LoRA and DoRA PEFT support for Fine-Tuning TimesFM #104

Merged

merged 24 commits into google-research:master on Aug 6, 2024

Conversation

tanmayshishodia
Contributor

@tanmayshishodia tanmayshishodia commented Jul 16, 2024

Thank you for this great project and for providing open-source inference code.

What does this PR do?

  • A generic fine-tuning pipeline that supports four fine-tuning strategies:
    • Full Fine-Tuning
    • Linear Probing [fine-tunes only the residual blocks and the embedding layer]
    • LoRA [fine-tunes only a small number of parameters by decomposing the weight updates into low-rank matrices, making it efficient in terms of memory and compute]
    • DoRA [an extension of LoRA that decomposes the pre-trained weight into magnitude and direction components and uses LoRA for directional adaptation, improving learning capacity and stability without additional inference overhead; accepted at ICML 2024; see the sketch after this list]
  • Add testing framework [pytest]
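
A minimal NumPy sketch of the two parameterizations for reference (illustrative only; shapes, scaling, and parameter names are placeholders, not the actual implementation in adapters/lora_layers.py or adapters/dora_layers.py):

```python
import numpy as np

def lora_forward(x, w, lora_a, lora_b, scale=1.0):
    """LoRA: y = x @ W + scale * x @ (A @ B); W stays frozen, only A and B are trained."""
    return x @ w + scale * (x @ lora_a) @ lora_b

def dora_forward(x, w, lora_a, lora_b, dora_m, scale=1.0):
    """DoRA: merge the LoRA delta, normalize each column (direction),
    then rescale by the learned magnitude vector m."""
    w_merged = w + scale * lora_a @ lora_b                       # W + delta_W
    col_norm = np.linalg.norm(w_merged, axis=0, keepdims=True)   # ||W + delta_W||_c
    return x @ (dora_m * w_merged / col_norm)                    # m * direction

# Illustrative shapes: x (batch, d_in), w (d_in, d_out),
# lora_a (d_in, r), lora_b (r, d_out), dora_m (d_out,).
```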

Why is this PR needed?

  • Primary motivation for this PR is to leverage LoRA to enable efficient training of multiple adapter weights on various tasks, domains, and time-series datasets while maintaining the same base model.
  • Implementation of PEFT techniques are largely unexplored on time series foundational models. Help the research community analyze PEFT on these models.

Performance Comparison of LoRA/DoRA with Linear Probing

[Image: results table comparing LoRA and DoRA fine-tuning with linear probing and the base model]

Experiments were performed with a 60-20-20 train/val/test split. Black denotes the best result, blue the second best.
Benchmarking was done with context_len=128 and horizon_len=96; fine-tuning used context_len=128 and horizon_len=128.

Caveats

  • The loading of adapter weights is currently under-optimized.
  • I am new to the PaxML framework. Please assist with optimization wherever possible.

Functional Testing

  • Test that all params get updated in full fine-tuning.
  • Test that only the residual block params get updated in linear probing.
  • Test that only the LoRA params get updated in LoRA fine-tuning. Vary rank r.
  • Test that the LoRA params and the DoRA magnitude vector are updated in DoRA fine-tuning. Vary rank r. (A rough sketch of such a check follows this list.)
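
A rough sketch of what such a check could look like (the flat parameter dict and the name prefixes are assumptions, not the actual test code):

```python
import numpy as np

def assert_only_adapter_params_updated(params_before, params_after,
                                        adapter_prefixes=("lora_", "dora_")):
    """Adapter params must change after a training step; everything else must stay frozen.
    Both arguments are flat {param_name: np.ndarray} dicts."""
    for name, before in params_before.items():
        after = params_after[name]
        if any(prefix in name for prefix in adapter_prefixes):
            assert not np.allclose(before, after), f"adapter param {name} was not updated"
        else:
            assert np.allclose(before, after), f"frozen param {name} changed unexpectedly"
```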

Ref PRs

@rajatsen91
Collaborator

rajatsen91 commented Jul 16, 2024

Thanks @tanmayshishodia. These are very welcome contributions. Since the CL is pretty big, it will take us some time to review and merge:

  1. We added a notebook that can do linear probing (as well as full fine-tuning) under notebooks/finetuning.ipynb. This really needs just a one-line change; the rest of the notebook is just writing the training loop and dealing with paxml models. It might make sense to add LoRA and DoRA examples to that notebook.

  2. A general style nit: we are trying not to import layers individually but rather the whole module. For instance, instead of from praxis.layers.linears import Linear we would prefer from praxis.layers import linears and then linears.Linear (see the snippet after this list).

  3. Please let us know when your changes are ready to review and @siriuz42 and I can take a shot at reviewing.
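
For illustration, the preferred convention from point 2 would look like this (the subclass name is just an example):

```python
# Preferred: import the whole module, then reference members through it.
from praxis.layers import linears

class LoraLinear(linears.Linear):  # illustrative name for an adapter layer
    ...

# Avoid: importing individual layers directly.
# from praxis.layers.linears import Linear
```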

@rajatsen91 rajatsen91 self-assigned this Jul 16, 2024
@rajatsen91
Collaborator

rajatsen91 commented Jul 16, 2024

  4. Would you be able to set up testing using your favorite framework?

@tanmayshishodia
Contributor Author

Sure.

  1. This script is largely adapted from the example notebook in notebooks/finetuning.ipynb. I can add a LoRA/DoRA example there. The script is still useful since it lets you run multiple experiments simultaneously with different configurations. I will also add .sh scripts for common configurations so they can be run quickly.
  2. Sure, will make changes to address this.
  3. Sure. I will change the PR status to ready for review. There are some issues/bottlenecks I am aware of; I will mention them so you can help address them.
  4. I can set it up with pytest, if that sounds good. Shall I keep it in a separate PR or in this one?

Feel free to ask me any other questions you may have.

@rajatsen91
Collaborator

SGTM to all. For 4, pytest sounds fine.

You can add some tests to this PR itself. Something that covers the call function of the layers should be fine. Let us know if you have any praxis- or paxml-related questions.

@tanmayshishodia
Contributor Author

  1. The current workflow for loading the model for adapter fine-tuning is:
    • Load the base model checkpoint, which loads train_state and jit_decodes the model [tfm].
    • Create an instance of the fine-tune model [model].
    • Replace the attention and linear layers in the stacked transformer block of the model with the LoRA/DoRA layers defined in adapters/lora_layers.py and adapters/dora_layers.py. [ref]
    • Since tfm was loaded from the base model checkpoint, I have to manually add the LoRA/DoRA adapter weights for each layer [ref] (a rough sketch of this step follows below). The setup method defined in the layer files doesn't help with this initialization.

I believe there should be a better way to load the checkpoint along with the initialized LoRA and DoRA weights. I did try replacing the layers, instantiating, and then loading the checkpoint, but could not do so successfully. Let me know your thoughts on this.
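
For illustration, the manual step above amounts to something like the following, treating a layer's variables as a dict of arrays (all names, shapes, and the init scheme are assumptions, not the actual code):

```python
import numpy as np

def add_lora_params(layer_vars, rank):
    """Insert freshly initialized LoRA params next to a frozen pre-trained weight."""
    w = layer_vars["w"]                              # pre-trained weight, stays frozen
    d_in, d_out = w.shape
    layer_vars["lora_a"] = np.random.normal(0.0, 0.02, size=(d_in, rank)).astype(w.dtype)
    layer_vars["lora_b"] = np.zeros((rank, d_out), dtype=w.dtype)  # B = 0 so delta_W starts at 0
    return layer_vars
```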

  2. After fine-tuning, only the adapter weights are saved, which I manually extract using [ref]. The loading of the fine-tuned model is as follows:
    • Load the base model ckpt and jit_decode.
    • To load the adapter ckpt, I first create the config for an adapter model, extract the necessary var_weight_hparams to load the adapter ckpt, merge them using the same forward-pass logic, and then jit_decode the model again.

I believe this whole process can be further optimized (a sketch of the extract-and-merge step is below). Let me know if you have any ideas on how that can be done.
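
For illustration, the extract-and-merge step could look roughly like this (the flat parameter dict and name prefixes are assumptions):

```python
import numpy as np

def extract_adapter_params(flat_vars, prefixes=("lora_", "dora_")):
    """Keep only adapter params from a flat {name: array} dict for checkpointing."""
    return {k: v for k, v in flat_vars.items() if any(p in k for p in prefixes)}

def merge_lora_into_weight(w, lora_a, lora_b, scale=1.0):
    """Fold the adapter back into the frozen base weight: W' = W + scale * (A @ B)."""
    return w + scale * (lora_a @ lora_b)
```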

  3. Although the adapter weights are not currently initialized via the setup method, the dora_m param is initialized with the column norm of the pre-trained weight. How can this be done with WeightInit in the setup method so that it can be used later?
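
For reference, the intended initialization is simply the column-wise L2 norm of the frozen weight (NumPy sketch; the axis convention assumes a (d_in, d_out) weight):

```python
import numpy as np

def init_dora_magnitude(w):
    """dora_m starts as the per-column L2 norm of the pre-trained weight w of shape (d_in, d_out)."""
    return np.linalg.norm(w, axis=0)   # shape: (d_out,)
```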

@rajatsen91
Collaborator

Hi @tanmayshishodia,

Regarding "Replace the attention and linear layers in the stacked transformer block of the model with LoRA/DoRA layers defined in adapters/lora_layers.py, adapters/dora_layers.py"

AFAIK LoRA adds a low-rank adapter additively to the original weights, which are held fixed, as in $Wx + \Delta W x$. However, here you are replacing the original attention weights. Maybe this is just a terminology issue and you are just adding $\Delta W$ rather than removing $W$?

@tanmayshishodia
Contributor Author

tanmayshishodia commented Jul 17, 2024

Hi @rajatsen91

Yes, apologies for framing it incorrectly. The LoRA/DoRA layers defined in those files inherit from the original linear and attention layers. During the forward pass we take the original weight matrix, which stays fixed, and add the LoRA delta (A and B, which multiply to form $\Delta W$). [ref] Roughly, the pattern is as in the sketch below.
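
Illustrative sketch of that subclassing pattern (BaseLinear and the attribute names are placeholders, not the actual praxis layers):

```python
import numpy as np

class BaseLinear:
    """Stand-in for the original (frozen) linear layer."""
    def __init__(self, w):
        self.w = w
    def __call__(self, x):
        return x @ self.w

class LoRALinear(BaseLinear):
    """Adds the trainable low-rank delta on top of the frozen parent output."""
    def __init__(self, w, lora_a, lora_b, scale=1.0):
        super().__init__(w)
        self.lora_a, self.lora_b, self.scale = lora_a, lora_b, scale
    def __call__(self, x):
        y = super().__call__(x)                                   # frozen W x
        return y + self.scale * (x @ self.lora_a) @ self.lora_b   # + delta_W x, delta_W = A @ B
```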

Also sharing a diff between the weights before LoRA/DoRA training and after one training epoch: https://www.diffchecker.com/SOotQLkx/.

@tanmayshishodia tanmayshishodia marked this pull request as ready for review July 18, 2024 03:24
@tanmayshishodia tanmayshishodia changed the title LoRA/DoRA support for Fine-Tuning TimesFM LoRA and DoRA PEFT support for Fine-Tuning TimesFM Jul 19, 2024
@tanmayshishodia
Contributor Author

@rajatsen91 @siriuz42 PR is ready for first round of review. I have added a simple test case which tests inference. I will be adding more, let me know what scenarios need to be covered.

@rghosh08

@tanmayshishodia We, at Nutanix, greatly appreciate your effort as an intern in fine-tuning TimesFM. As your mentor, I am super proud to see this.

@rajatsen91 We, at Nutanix, are looking forward to contributing to your TimesFM project. We believe this will have a significant impact across industries. Thanks!

@rajatsen91
Collaborator

Thanks again for the PR. Our team is traveling this week; we will look into it next week.

peft/fft.sh Outdated
@@ -0,0 +1,22 @@
#!/bin/bash

Collaborator

I feel like different .sh scripts are not needed; one script with command-line options is good enough. Alternatively, we could avoid checking in these scripts and include example usage in the header comment of finetune.py.

Contributor Author

Okay, I will keep only one of them.

Running python3 finetune.py --help gives the following output, which is self-explanatory. Perhaps I can add it to the README?
[Screenshot: finetune.py --help output]

Collaborator

@rajatsen91 rajatsen91 left a comment

LGTM overall. Left a minor comment about the shell scripts.

@rajatsen91
Copy link
Collaborator

The numbers in the table "Performance Comparison of LoRA/DoRA with Linear Probing" do not match the numbers I get from the finetuning.ipynb notebook that I had. Can you please clarify what the differences are?

In particular, for the ETTm1 test split I get MAE 0.351 for the base model.

--adam-clip-threshold=1e2 \
--early-stop-patience=10 \
--datetime-col="date" \
--boundaries=1000 46080 57600 \
Collaborator

why is this 1000 here?

Contributor Author

My bad, thank you for pointing it out. I will remove this param from the .sh file. The script automatically uses a 60-20-20 split.

@tanmayshishodia
Contributor Author

tanmayshishodia commented Aug 3, 2024

The numbers in the table "Performance Comparison of LoRA/DoRA with Linear Probing" do not match the numbers I get from the finetuning.ipynb notebook that I had. Can you please clarify what the differences are?

In particular, for the ETTm1 test split I get MAE 0.351 for the base model.

Hi Rajat, all the experiments were run with boundaries at 60-20-20 % of the given dataset. I just noticed that even though the boundaries in the notebook are spaced in the same proportions, they don't cover the whole dataset: for ettm1 the number of data points is 69680, but the test boundary in the notebook is 57600, hence the mismatch. If you rerun the notebook with boundaries [41808, 55744, 69680] for ettm1, you should get the same result.
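
For reference, the boundaries follow directly from the 60-20-20 split:

```python
n = 69680                                     # total number of data points in ettm1
boundaries = [int(n * 0.6), int(n * 0.8), n]  # -> [41808, 55744, 69680]
```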

I will add this in the PR description.

@rajatsen91
Collaborator

OK, sounds great. I think from my side it is ready to merge; great work. It would be great if you could add a README.md in the peft folder along with the results from your table. Thanks again for the great contribution.

@rajatsen91
Collaborator

LGTM from my side. Will wait to hear from @siriuz42 and try to merge by EOD.

@rajatsen91 rajatsen91 merged commit 577e4e8 into google-research:master Aug 6, 2024
1 of 2 checks passed