Application of Self-Supervised Alignment with Mutual Information (SAMI) to math, reading, coding, and chain-of-thought

An Exploration of Self-Supervised Mutual Information Alignment for Multi-Task Settings

This repository contains the code for "An Exploration of Self-Supervised Mutual Information Alignment for Multi-Task Settings". We build on the initial SAMI reference repository to explore using SAMI for multi-task settings.
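At a glance, SAMI trains a model by maximizing a contrastive lower bound on the mutual information between principles (or task instructions) and the responses generated under them: each response should be most likely under the principle that produced it. The pure-Python sketch below illustrates that symmetric contrastive objective; the function name and the plain-list representation are ours for illustration, and the repository's actual implementation (batched, in PyTorch) differs in detail.

```python
import math

def sami_loss(logprobs):
    """Symmetric contrastive loss over a square matrix where logprobs[i][j]
    is the model's log-probability of response j under principle i.
    Minimizing it pushes probability mass onto the diagonal: each response
    becomes most likely under the principle it was generated with."""

    def cross_entropy(rows):
        # Mean over i of -log softmax(rows[i])[i], computed stably.
        total = 0.0
        for i, row in enumerate(rows):
            m = max(row)
            logz = m + math.log(sum(math.exp(x - m) for x in row))
            total += -(row[i] - logz)
        return total / len(rows)

    row_term = cross_entropy(logprobs)                    # principle -> response
    col_term = cross_entropy([list(c) for c in zip(*logprobs)])  # response -> principle
    return 0.5 * (row_term + col_term)
```

A diagonal-dominant matrix (responses well matched to their principles) yields a lower loss than a uniform one, which is what training exploits.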

Getting Started

  1. Clone the repository: git clone https://github.com/sohamgovande/sami-extension.git
  2. Change the directory to the cloned repository: cd sami-extension
  3. Create a new conda environment named 'sami' with Python version 3.10.0: conda create -n sami python=3.10.0
  4. Activate the newly created conda environment: conda activate sami
  5. Install the package in editable mode so that you can make changes to the code: pip install -e .

We successfully ran all data generation and training jobs on a machine with a single NVIDIA A100 GPU. If you run into out-of-memory (OOM) errors, try reducing the batch size or using a smaller model. Empirically, we find that mistral-7b works well on A40 GPUs, and mistral-tiny-base works well on CPU-only hosts.

📚 Experiment 1: Multi-Task Learning with SAMI for MT-Bench

Folder: experiments/mt_bench

Models: llama3.1-8b-instruct, llama3.1-8b

  1. Generate data (Optional): In our experiment, we use llama3.1-8b-instruct to generate training data for llama3.1-8b. You can adjust the model and other parameters in conf/generate.yaml. Then, run python generate_sami_data.py to generate the data. This step is optional, as we have provided the data in the training_data folder. If you do regenerate the training data, you must first download the datasets using experiments/mt_bench/mt_bench/input_data/training_data_generator.ipynb.
  2. Train via SAMI: You can adjust the model and other parameters in conf/train.yaml. Then, run python train.py to start training. (Looking to train via DPO? Check out Eric Mitchell's repo, which is what we used to do this.)
  3. Evaluate: You can adjust the model and other parameters in conf/evaluate.yaml. This will run the models to generate outputs on the MT-Bench dataset.
  4. Compute win rates: You can adjust the model and other parameters in conf/win_rates.yaml. This will compute the win rates between the models.
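Step 4's win-rate computation reduces to counting pairwise judge verdicts between the two models. The sketch below assumes a common convention in which ties count as half a win; the verdict labels and the `win_rate` helper are hypothetical, and conf/win_rates.yaml may configure this differently.

```python
from collections import Counter

def win_rate(verdicts):
    """Win rate of model A from pairwise judge verdicts.
    Each verdict is 'A', 'B', or 'tie'; ties count as half a win
    (a common convention, assumed here for illustration)."""
    if not verdicts:
        return 0.0
    counts = Counter(verdicts)
    return (counts["A"] + 0.5 * counts["tie"]) / len(verdicts)
```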

[Figure: Experiment 1]

🤔 Experiment 2: Chain-of-Thought Mathematical Reasoning

Folder: experiments/math

Model: mistral-7b

  1. Generate data (Optional): We use mistral-7b to self-generate data and filter out poor-quality responses as defined in generate.py. After modifying this file and/or conf/generate.yaml as needed, run python generate.py to generate the data. This step is optional, as we have provided the data in the training_data folder.
  2. Train via SAMI: You can adjust the model and other parameters in conf/train.yaml. Then, run python train.py to start training. To train via SFT, you can use Hugging Face's generic SFTTrainer (from the trl library), or Eric Mitchell's repo, which is what we used.
  3. Evaluate: You can adjust the model and other parameters in conf/evaluate.yaml. This will run the models to generate outputs on the math evaluation dataset.
  4. Compute accuracy: You can adjust the model and other parameters in conf/accuracy.yaml, then run python accuracy_rate.py. This will compute the model's accuracy.
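Step 4's accuracy computation amounts to extracting a final answer from each chain-of-thought completion and comparing it to the gold answer. The helpers below are a hypothetical sketch of that idea, not the repository's accuracy_rate.py, which may parse answers differently.

```python
import re

def extract_answer(completion):
    """Pull the last number from a chain-of-thought completion.
    A common heuristic for math benchmarks; assumed here for illustration."""
    numbers = re.findall(r"-?\d+(?:\.\d+)?", completion.replace(",", ""))
    return numbers[-1] if numbers else None

def accuracy(completions, gold_answers):
    """Fraction of completions whose extracted answer matches the gold answer."""
    correct = sum(
        extract_answer(c) == str(g) for c, g in zip(completions, gold_answers)
    )
    return correct / len(gold_answers)
```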

[Figures: Experiment 2; Experiment 2, Figure 2]
