Instruct tuning for lora/finetune? #484

Open
fblissjr opened this issue Feb 24, 2024 · 8 comments · May be fixed by #1211
Labels: enhancement (New feature or request)

Comments

@fblissjr

fblissjr commented Feb 24, 2024

Please correct me if I'm wrong, but it looks like the current examples for lora training all build a loss function around completion, which lines up with the lora example of using only the 'text' field from the jsonl dataset.

Are there any forks or plans to allow for instruct tuning, where the dataset provides an input prompt and an expected output, and the loss function targets the input/output pair?

Or did I miss something?

Thanks!

edit: example below:

{
"prompt": "[INST] Your input prompt here[/INST]",
"text": "The expected output result here"
}

Whereas it looks like the current lora process is:
{
"text": "Predict what comes [next]"
}

@Solido

Solido commented Feb 24, 2024

Need confirmation, but the input inside text can be multiline.
My understanding is that the line break separates the input from the completion.

@fblissjr
Author

fblissjr commented Feb 24, 2024

I think it's still just going to optimize for completion of the full text field, and doesn't differentiate between the input and output? At least based on the lora code in mlx_lm.

from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/lora.py

class Dataset:
    """
    Light-weight wrapper to hold lines from a jsonl file
    """

    def __init__(self, path: Path, key: str = "text"):
        if not path.exists():
            self._data = None
        else:
            with open(path, "r") as fid:
                self._data = [json.loads(l) for l in fid]
        self._key = key

    def __getitem__(self, idx: int):
        return self._data[idx][self._key]

    def __len__(self):
        if self._data is None:
            return 0
        return len(self._data)

@fblissjr
Author

Looking at the mlx_lm code, I think we need to adjust the Dataset class to account for various dataset types (instruct, chat, etc.). That starts to turn into an axolotl-style project eventually, but for simplicity it's probably enough to allow passing custom dataset types (rough sketch below).

Then I think tuner.trainer needs its default loss function modified (or a new one added for instruct templates).

Let me know if I'm wrong here, but from what I'm seeing, the only training path available right now is plain completion.
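
Something like this rough sketch is what I mean by a custom dataset type; the class name and keys are hypothetical, not anything that exists in mlx_lm today:

import json
from pathlib import Path


class PromptCompletionDataset:
    """
    Sketch: wrapper for {"prompt": ..., "completion": ...} records in a jsonl file.
    """

    def __init__(self, path: Path, prompt_key: str = "prompt", completion_key: str = "completion"):
        with open(path, "r") as fid:
            self._data = [json.loads(l) for l in fid]
        self._prompt_key = prompt_key
        self._completion_key = completion_key

    def __getitem__(self, idx: int):
        record = self._data[idx]
        # Return the pair so the trainer can mask the prompt tokens out of the loss.
        return record[self._prompt_key], record[self._completion_key]

    def __len__(self):
        return len(self._data)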

@fblissjr
Author

fblissjr commented Feb 24, 2024

I believe this fork handles it correctly: https://github.com/chimezie/mlx-tuning-fork/blob/main/src/mlx_tuning_fork/training.py (#235)

edit: saw this PR (#213), looks like the goal is to keep the LoRA code purely as an example, but I do think it may cause confusion for folks trying to do SFT on instruct- or chat-style datasets. Maybe just an edit to LORA.md?

@awni
Member

awni commented Feb 25, 2024

We have a more fully featured version of LoRA in mlx_lm: https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md

Assuming it doesn't add much code complexity, I think it would be cool to update it to support alternative losses / styles of training. Depending on how niche and/or complex the approach is, it may also make sense to do it as a standalone package. I'll mark this issue as an enhancement for now.

awni added the enhancement (New feature or request) label Feb 25, 2024
@chimezie
Contributor

chimezie commented Mar 8, 2024

The recent changes that allow the loss and iterate_batches functions to be specified for the tuning process have made this a lot more straightforward. I have done this in mlx-tuning-fork, a happily shrinking thin layer over mlx_lm.lora. I can create a PR specifically for instruction tuning with (optional) masking of the input in the loss calculation.

However, depending on how this particular kind of tuning is specified in the configs/options, I'm not sure how niche it would end up being.
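
Roughly, the masked loss I have in mind looks like this; this is a sketch only, and the function signature and mask convention are assumptions rather than the actual mlx_lm trainer API:

import mlx.core as mx
import mlx.nn as nn


def input_masked_loss(model, inputs, targets, lengths, prompt_lengths):
    # Per-token cross-entropy over the whole sequence.
    logits = model(inputs).astype(mx.float32)
    ce = nn.losses.cross_entropy(logits, targets, reduction="none")

    # Keep only positions inside the response: after the prompt tokens
    # and before any padding.
    positions = mx.arange(targets.shape[1])[None, :]
    response_mask = mx.logical_and(
        positions >= prompt_lengths[:, None], positions < lengths[:, None]
    )

    ntoks = response_mask.sum()
    loss = (ce * response_mask).sum() / ntoks
    return loss, ntoks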

@Solido

Solido commented Mar 8, 2024

MLX is growing fast and the community will soon build a lot around it.
Anything that can serve as common ground for those projects should be welcome.
I'm working exclusively on instruct data myself and patiently waiting for more options.

@chimezie
Contributor

chimezie commented Nov 3, 2024

Gentle ping regarding this: #1086. I don't think this approach would be too niche or add too much complexity (depending on how feasible it is to continue relying on apply_chat_template for the prompt formatting while keeping track of where the input tokens end and the output tokens begin).
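
For example, with a Hugging Face tokenizer the boundary can be recovered by templating the conversation with and without the assistant turn. This is just a sketch: the model name and variable names are illustrative, and it assumes the templated prompt is a token-level prefix of the full conversation, which not every chat template guarantees.

from transformers import AutoTokenizer

# Hypothetical choice of model; any tokenizer with a chat template works the same way.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

messages = [{"role": "user", "content": "Your input prompt here"}]

# Prompt-only tokens, ending with the assistant generation prefix.
prompt_tokens = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True
)

# Full conversation, including the expected output.
full_tokens = tokenizer.apply_chat_template(
    messages + [{"role": "assistant", "content": "The expected output result here"}],
    tokenize=True,
)

# Everything before this index is input and can be masked out of the loss.
prompt_length = len(prompt_tokens)
response_tokens = full_tokens[prompt_length:]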

anupamme added a commit to anupamme/mlx-examples that referenced this issue Jan 20, 2025
Fixes ml-explore#484

Add support for instruct tuning with input/output pairs and alternative loss functions.

* **llms/mlx_lm/lora.py**
  - Add `CompletionsDataset` class to support input/output pairs.
  - Modify `Dataset` class to handle different dataset types.
  - Update `main` function to include new dataset type.

* **llms/mlx_lm/tuner/trainer.py**
  - Modify `default_loss` function to support alternative loss functions.
  - Add new `instruct_loss` function for instruct tuning.

* **llms/mlx_lm/LORA.md**
  - Add instructions for instruct tuning with input/output pairs.
  - Update documentation to include alternative loss functions.

* **llms/tests/test_datasets.py**
  - Add tests for `CompletionsDataset` and `create_dataset` functions.

* **llms/tests/test_trainer.py**
  - Add tests for `default_loss` and `instruct_loss` functions.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/ml-explore/mlx-examples/issues/484?shareId=XXXX-XXXX-XXXX-XXXX).
anupamme linked a pull request Jan 20, 2025 that will close this issue