Instruct tuning for lora/finetune? #484
Comments
Need confirmation, but the input inside `text` is multiline.
I think it's still just going to optimize for completion of the full text field and doesn't differentiate between the input and the output, at least based on the lora code in mlx-lm: https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/lora.py
Looking at the mlx-lm code, I think we need to adjust the Dataset class to account for various dataset types (instruct, chat, etc.). That eventually starts to turn into an axolotl-type project, but for simplicity it is probably enough to be able to pass custom dataset types. Then I think tuner.trainer needs to be modified, either by changing the default loss function or by adding a new one for instruct templates. Let me know if I'm wrong here, but from what I'm seeing, the only training function available right now is for completions.
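To make the "custom dataset types" idea concrete, here is a minimal sketch in plain Python. It is not the mlx_lm `Dataset` class: `TextDataset`, `InstructDataset`, `make_dataset` and `toy_tokenize` are hypothetical names, and the record keys follow the `prompt`/`text` example from the issue. Each item also carries the prompt length, so that a loss function could later tell input tokens from output tokens.

```python
from typing import Callable, Dict, List, Tuple

Tokenizer = Callable[[str], List[int]]
Record = Dict[str, str]


class TextDataset:
    """Completion-style records: {"text": ...}. Every token is a target."""

    def __init__(self, records: List[Record], tokenize: Tokenizer):
        self.records = records
        self.tokenize = tokenize

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> Tuple[List[int], int]:
        tokens = self.tokenize(self.records[idx]["text"])
        return tokens, 0  # prompt length 0: nothing would be masked


class InstructDataset(TextDataset):
    """Instruct-style records: {"prompt": ..., "text": ...} (assumed layout)."""

    def __getitem__(self, idx: int) -> Tuple[List[int], int]:
        record = self.records[idx]
        prompt_tokens = self.tokenize(record["prompt"])
        output_tokens = self.tokenize(record["text"])
        # Return the concatenated sequence plus where the prompt ends.
        return prompt_tokens + output_tokens, len(prompt_tokens)


def make_dataset(records: List[Record], tokenize: Tokenizer) -> TextDataset:
    """Pick a dataset type from the keys of the first record."""
    if records and "prompt" in records[0]:
        return InstructDataset(records, tokenize)
    return TextDataset(records, tokenize)


def toy_tokenize(text: str) -> List[int]:
    # Stand-in tokenizer: one fake token id (the word length) per word.
    return [len(word) for word in text.split()]
```

A real implementation would use the model's tokenizer and would have to decide how to handle special tokens at the prompt/output boundary; this sketch ignores both.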
I believe this fork handles it correctly: https://github.com/chimezie/mlx-tuning-fork/blob/main/src/mlx_tuning_fork/training.py (#235)

edit: Saw this PR (#213). It looks like the goal is to keep the lora purely as an example, but I do think it may cause confusion for folks trying to do SFT on instruct- or chat-style datasets. Maybe just an edit to the LoRA.md?
We have a more fully featured version of lora in mlx-lm: https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md. Assuming it doesn't add much code complexity, I think it would be cool to update it to support alternative losses / styles of training. Depending on how niche and/or complex the approach is, it may also make sense to do it as a standalone package. I'll mark this issue as an enhancement for now.
The recent changes that allow the loss and iterate_batches functions to be specified for the tuning process have made this a lot more straightforward. I have done this in mlx-tuning-fork, a happily shrinking thin layer over mlx_lm.lora. I can create a PR specifically for instruction tuning with (optional) masking of the input in the loss calculation. However, depending on how this particular kind of tuning is specified in configs/options, I don't know how niche that would be.
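As an illustration of what "masking of the input in the loss calculation" means, here is a framework-free sketch in plain Python. It is not the mlx_lm or mlx-tuning-fork loss: `masked_next_token_loss` and its parameters are hypothetical, and it assumes that `logits[t]` holds the model's scores for the token that follows `tokens[t]` and that the first `prompt_len` tokens are the input.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary-sized vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def masked_next_token_loss(
    logits: List[List[float]],
    tokens: List[int],
    prompt_len: int,
    mask_prompt: bool = True,
) -> float:
    """Mean cross-entropy over next-token targets.

    Assumes len(logits) == len(tokens) - 1. With mask_prompt=True, targets
    that are still part of the prompt are skipped, so only output tokens
    contribute to the loss.
    """
    total, count = 0.0, 0
    for t, step_logits in enumerate(logits):
        target_pos = t + 1  # logits[t] predicts tokens[t + 1]
        if mask_prompt and target_pos < prompt_len:
            continue  # target is a prompt token: excluded from the loss
        total += -log_softmax(step_logits)[tokens[target_pos]]
        count += 1
    return total / max(count, 1)  # avoid division by zero if all masked
```

With `mask_prompt=False` this reduces to the ordinary completion loss over the whole sequence, which is why the masking can be exposed as an optional switch.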
MLX is growing fast, and the community will soon build a lot around it.
Gentle ping regarding this: #1086. I don't think this approach would be too niche or add too much complexity (depending on how feasible it is to continue to rely on apply_chat_templates for handling the prompt formatting while keeping the distinction between where input tokens end and output tokens begin).
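One way to keep that distinction while still delegating formatting to a template is to render the conversation twice, once without and once with the final assistant reply, and take the length of the shorter rendering as the boundary. The sketch below only illustrates the idea: `toy_chat_template`, `toy_tokenize` and `split_point` are hypothetical stand-ins, and it assumes that the templated prompt tokenizes to a prefix of the templated full conversation, which a real tokenizer and template may not guarantee.

```python
from typing import Dict, List, Tuple

Message = Dict[str, str]


def toy_chat_template(messages: List[Message]) -> str:
    # Hypothetical template, loosely modelled on [INST] ... [/INST] formatting.
    parts = []
    for message in messages:
        if message["role"] == "user":
            parts.append(f"[INST] {message['content']} [/INST]")
        else:
            parts.append(f"{message['content']} </s>")
    return " ".join(parts)


def toy_tokenize(text: str) -> List[str]:
    # Stand-in tokenizer: whitespace-separated words act as tokens.
    return text.split()


def split_point(messages: List[Message]) -> Tuple[List[str], int]:
    """Return (full_tokens, prompt_len); the last message is the reply."""
    prompt_tokens = toy_tokenize(toy_chat_template(messages[:-1]))
    full_tokens = toy_tokenize(toy_chat_template(messages))
    # The boundary is only meaningful if the prompt is an exact prefix.
    if full_tokens[: len(prompt_tokens)] != prompt_tokens:
        raise ValueError("templated prompt is not a prefix of the conversation")
    return full_tokens, len(prompt_tokens)
```

The explicit prefix check is the important design choice here: if a template or tokenizer merges tokens across the boundary, it is safer to fail loudly than to mask the wrong positions.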
Fixes ml-explore#484

Add support for instruct tuning with input/output pairs and alternative loss functions.

* **llms/mlx_lm/lora.py**
  - Add `CompletionsDataset` class to support input/output pairs.
  - Modify `Dataset` class to handle different dataset types.
  - Update `main` function to include new dataset type.
* **llms/mlx_lm/tuner/trainer.py**
  - Modify `default_loss` function to support alternative loss functions.
  - Add new `instruct_loss` function for instruct tuning.
* **llms/mlx_lm/LORA.md**
  - Add instructions for instruct tuning with input/output pairs.
  - Update documentation to include alternative loss functions.
* **llms/tests/test_datasets.py**
  - Add tests for `CompletionsDataset` and `create_dataset` functions.
* **llms/tests/test_trainer.py**
  - Add tests for `default_loss` and `instruct_loss` functions.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/ml-explore/mlx-examples/issues/484?shareId=XXXX-XXXX-XXXX-XXXX).
Please correct me if I'm wrong, but it looks like the current examples for lora training all build a loss function around completion, which lines up with the lora example of using only the 'text' field from the jsonl dataset.
Are there any forks or plans to allow for instruct tuning, where the input is an input prompt, and the loss function is targeting the input/output pair?
Or did I miss something?
Thanks!
edit: example below:
{
"prompt": "[INST] Your input prompt here[/INST]",
"text": "The expected output result here"
}
Whereas it looks like the current lora process is:
{
"text": "Predict what comes [next]"
}