From 28d0ce6211bcb9da99aec560769f2f3ec5f52975 Mon Sep 17 00:00:00 2001 From: Wessel Bruinsma Date: Tue, 20 Aug 2024 12:08:51 +0200 Subject: [PATCH] Extend docs --- docs/beware.md | 18 ++++++++++++++++++ docs/finetuning.md | 19 ++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/docs/beware.md b/docs/beware.md index f2d5f71..632dfec 100644 --- a/docs/beware.md +++ b/docs/beware.md @@ -35,3 +35,21 @@ you should do two things: 1. Set `torch.use_deterministic_algorithms(True)` to make PyTorch operations deterministic. 2. Set `model.eval()` to disable drop-out. + +## Loading a Checkpoint Onto an Extended Model + +If you changed the model and added or removed parameters, you need to set `strict=False` when +loading a checkpoint `Aurora.load_checkpoint(..., strict=False)`. +Importantly, enabling or disabling LoRA for a model that was trained respectively without or +with LoRA changes the parameters! + +## Extending the Model with New Surface-Level Variables + +Whereas we have attempted to design a robust and flexible model, +inevitably some unfortunate design choices slipped through. + +A notable unfortunate design choice is that extending the model with a new surface-level +variable breaks compatibility with existing checkpoints. +It is possible to hack around this in a relatively simple way. +We are working on a more principled fix. +Please open an issue if this is a problem for you. diff --git a/docs/finetuning.md b/docs/finetuning.md index 274471b..f8726d0 100644 --- a/docs/finetuning.md +++ b/docs/finetuning.md @@ -10,4 +10,21 @@ model = Aurora(use_lora=False) # Model is not fine-tuned. model.load_checkpoint("wbruinsma/aurora", "aurora-0.25-pretrained.ckpt") ``` -More specific instructions coming soon. +You are also free to extend the model for your particular use case. +In that case, it might be that you add or remove parameters. +Then `Aurora.load_checkpoint` will error, +because the existing checkpoint now mismatches with the model's parameters. +Simply set `Aurora.load_checkpoint(..., strict=False)`: + +```python +from aurora import Aurora + + +model = Aurora(...) + +... # Modify `model` + +model.load_checkpoint("wbruinsma/aurora", "aurora-0.25-pretrained.ckpt", strict=False) +``` + +More instructions coming soon!