This repo provides tools to compute feature attribution maps for Laplace-approximated LLMs. Given a Laplace-approximated model, the attribution can be computed with respect to any intermediate layer or the input tokens. Additionally, the repo contains an easy-to-use Flask app for visualizing the resulting attribution maps.
- Install `torch` following the PyTorch installation documentation.
- Install `flash-attn` via `FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip install flash-attn --no-build-isolation`.
- Install the package in editable mode via `pip install -e .` (do not forget the "."). For developers, install the package along with the extra dependencies via `pip install -e ".[dev]"`.

Install the pre-commit hooks via `pre-commit install`.
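After installation, a quick way to confirm that the dependencies are importable is a small check like the following. This is a minimal sketch, not a script shipped with the repo: it only checks the import names `torch`, `flash_attn` (the module provided by the `flash-attn` distribution) and `flask` (needed only for the app). The repo's own package is not checked because its import name is not stated here.

```python
import importlib.util

# Import names to check (assumed list; extend as needed).
REQUIRED = ["torch", "flash_attn"]
OPTIONAL = ["flask"]  # only needed for the visualization app


def missing_modules(names):
    """Return the subset of `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    if missing:
        print("Missing required modules:", ", ".join(missing))
    else:
        print("All required modules found.")
    missing_optional = missing_modules(OPTIONAL)
    if missing_optional:
        print("Missing optional modules (needed for the app):", ", ".join(missing_optional))
```

`find_spec` only locates the modules without importing them, so the check stays fast even for large packages such as `torch`.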
Run `python tools/fine_tune.py configs/ft_llama-2_arc-c.yaml -w workdirs/debug` to fine-tune the model and apply the Laplace approximation, or check the Python file for details.
To compute the attribution maps, run `python tools/attribution_cli.py configs/attr_llama-2_arc-c.yaml -w workdirs/debug/`, or check the Python file for details.
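The two steps above can be chained from a single driver script. This is a hypothetical sketch: `build_commands` and `run_pipeline` are illustrative helpers, not part of the repo, and it assumes that the attribution step picks up the fine-tuning outputs because both commands use the same work directory, as in the commands above. With `dry_run=True` it only prints the commands.

```python
import shlex
import subprocess

# Work directory used by both commands in this README (assumed to be shared).
WORK_DIR = "workdirs/debug"


def build_commands(work_dir=WORK_DIR):
    """Return the fine-tuning and attribution commands as argument lists."""
    return [
        ["python", "tools/fine_tune.py", "configs/ft_llama-2_arc-c.yaml", "-w", work_dir],
        ["python", "tools/attribution_cli.py", "configs/attr_llama-2_arc-c.yaml", "-w", work_dir],
    ]


def run_pipeline(work_dir=WORK_DIR, dry_run=True):
    """Print each step and, unless dry_run, execute it, stopping at the first failure."""
    for cmd in build_commands(work_dir):
        print("$", shlex.join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError if a step exits non-zero.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_pipeline(dry_run=True)
```

Passing argument lists (rather than a shell string) to `subprocess.run` avoids quoting issues with paths.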
The app requires `Flask`, which is already listed in the `dev` extra in `setup.cfg`.
Run `python tools/app.py /path/to/vis_attributions/` and then open http://127.0.0.1:5000 in your browser.
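If the page does not load, it can help to check whether the app's server is listening at all. The helper below is a hypothetical sketch, not part of the repo; it assumes the app serves on `127.0.0.1:5000` as stated above and treats any HTTP response (even an error status) as "server is up". It uses `http.client` directly, so environment proxy settings do not interfere with the local check.

```python
import http.client


def app_is_up(host="127.0.0.1", port=5000, timeout=2.0):
    """Return True if an HTTP server at host:port answers a GET / request."""
    conn = http.client.HTTPConnection(host, port, timeout=timeout)
    try:
        conn.request("GET", "/")
        conn.getresponse()  # any status code means something is serving
        return True
    except (OSError, http.client.HTTPException):
        # Connection refused, timeout, malformed response, ...
        return False
    finally:
        conn.close()


if __name__ == "__main__":
    print("App reachable:", app_is_up())
```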