This repository contains the code for "An Exploration of Self-Supervised Mutual Information Alignment for Multi-Task Settings". We build on the original SAMI reference repository to explore applying SAMI in multi-task settings.
- Clone the repository:
  ```bash
  git clone https://github.com/sohamgovande/sami-extension.git
  ```
- Change into the cloned repository:
  ```bash
  cd sami-extension
  ```
- Create a new conda environment named `sami` with Python 3.10.0:
  ```bash
  conda create -n sami python=3.10.0
  ```
- Activate the newly created conda environment:
  ```bash
  conda activate sami
  ```
- Install the package in editable mode so that you can make changes to the code:
  ```bash
  pip install -e .
  ```
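Before launching any jobs, it can help to verify that the environment sees a GPU; this one-liner assumes `torch` is installed as a dependency of the package:

```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```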
We successfully ran all data generation and training jobs on a machine with one NVIDIA A100 GPU. If you run into out-of-memory (OOM) errors, try reducing the batch size or using a smaller model. Empirically, we find that `mistral-7b` works well on A40 GPUs, and `mistral-tiny-base` works well on CPU-only hosts.
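For example, one way to reduce memory pressure is to lower the batch size in the relevant training config before rerunning. This is only a sketch: the key name `batch_size` and the values shown are assumptions, so check the actual `conf/train.yaml` in each experiment folder:

```bash
# Hypothetical edit: the exact key name and default depend on conf/train.yaml.
sed -i 's/^batch_size: 8/batch_size: 2/' conf/train.yaml
# Rerun training with the smaller batch size.
python train.py
```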
Folder: `experiments/mt_bench`
Models: `llama3.1-8b-instruct`, `llama3.1-8b`
- Generate data (optional): In our experiment, we use `llama3.1-8b-instruct` to generate data to train `llama3.1-8b`. You can adjust the model and other parameters in `conf/generate.yaml`, then run `python generate_sami_data.py` to generate the data. This step is optional, as we have provided the data in the `training_data` folder. Note that if you need to generate training data, you must first download the datasets using `experiments/mt_bench/mt_bench/input_data/training_data_generator.ipynb`.
- Train via SAMI: You can adjust the model and other parameters in `conf/train.yaml`, then run `python train.py` to start training. (Looking to train via DPO? Check out Eric Mitchell's repo, which is what we used to do this.)
- Evaluate: You can adjust the model and other parameters in `conf/evaluate.yaml`. This will run the models to generate outputs on the MT-Bench dataset.
- Compute win rates: You can adjust the model and other parameters in `conf/win_rates.yaml`. This will compute the win rates between the models. (A consolidated command sketch for the full pipeline follows this list.)
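Putting the four steps together, an end-to-end run looks roughly like the following. `generate_sami_data.py` and `train.py` are named above; the evaluation and win-rate entry points (`evaluate.py`, `win_rates.py`) are assumptions here, so check the folder for the actual script names:

```bash
cd experiments/mt_bench

# 1. (Optional) Generate training data with llama3.1-8b-instruct; download the
#    source datasets first via input_data/training_data_generator.ipynb.
python generate_sami_data.py

# 2. Train llama3.1-8b via SAMI with the settings in conf/train.yaml.
python train.py

# 3. Generate model outputs on MT-Bench (script name assumed; see conf/evaluate.yaml).
python evaluate.py

# 4. Compute pairwise win rates (script name assumed; see conf/win_rates.yaml).
python win_rates.py
```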
Folder: `experiments/math`
Model: `mistral-7b`
- Generate data (optional): We use `mistral-7b` to self-generate data and filter out poor-quality responses, as defined in `generate.py`. After editing this file and/or `conf/generate.yaml`, run `python generate.py` to generate the data. This step is optional, as we have provided the data in the `training_data` folder.
- Train via SAMI: You can adjust the model and other parameters in `conf/train.yaml`, then run `python train.py` to start training. To train via SFT, you can use a generic Hugging Face `SFTTrainer`, or Eric Mitchell's repo, which we used to do this.
- Evaluate: You can adjust the model and other parameters in `conf/evaluate.yaml`. This will run the models to generate outputs on the evaluation dataset.
- Compute accuracy: You can adjust the model and other parameters in `conf/accuracy.yaml`, then run `python accuracy_rate.py`. This will compute the accuracy of the model. (A consolidated command sketch follows this list.)
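As with MT-Bench, the math pipeline can be run end to end roughly as follows. `generate.py`, `train.py`, and `accuracy_rate.py` are named above; the evaluation entry point (`evaluate.py`) is an assumption, so check the folder for the actual script name:

```bash
cd experiments/math

# 1. (Optional) Self-generate and filter training data with mistral-7b
#    (filtering logic lives in generate.py; settings in conf/generate.yaml).
python generate.py

# 2. Train mistral-7b via SAMI with the settings in conf/train.yaml.
python train.py

# 3. Generate model outputs for evaluation (script name assumed; see conf/evaluate.yaml).
python evaluate.py

# 4. Score the generated outputs using the settings in conf/accuracy.yaml.
python accuracy_rate.py
```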