diff --git a/docs/_src/api/api/reader.md b/docs/_src/api/api/reader.md
index 1ce4ed4e73..bc58a11b60 100644
--- a/docs/_src/api/api/reader.md
+++ b/docs/_src/api/api/reader.md
@@ -333,6 +333,25 @@ Saves the Reader model so that it can be reused at a later point in time.
 
 - `directory`: Directory where the Reader model should be saved
+
+
+#### FARMReader.save\_to\_remote
+
+```python
+def save_to_remote(repo_id: str, private: Optional[bool] = None, commit_message: str = "Add new model to Hugging Face.")
+```
+
+Saves the Reader model to the Hugging Face Model Hub under the given repo_id. For this to work:
+
+- Be logged in to Hugging Face on your machine via transformers-cli
+- Have git lfs installed (https://packagecloud.io/github/git-lfs/install); you can test it by running `git lfs --version`
+
+**Arguments**:
+
+- `repo_id`: A namespace (user or an organization) and a repo name, separated by a '/', of the model you want to save to Hugging Face
+- `private`: Set to True to make the model repository private
+- `commit_message`: Commit message while saving to Hugging Face
+
 #### FARMReader.predict\_batch
diff --git a/haystack/nodes/reader/farm.py b/haystack/nodes/reader/farm.py
index b4355672d7..971afd5f5e 100644
--- a/haystack/nodes/reader/farm.py
+++ b/haystack/nodes/reader/farm.py
@@ -4,8 +4,12 @@ import multiprocessing
 from pathlib import Path
 from collections import defaultdict
+import os
+import tempfile
 from time import perf_counter
+
 import torch
+from huggingface_hub import create_repo, HfFolder, Repository
 
 from haystack.errors import HaystackError
 from haystack.modeling.data_handler.data_silo import DataSilo, DistillationDataSilo
@@ -688,6 +692,58 @@ def save(self, directory: Path):
         self.inferencer.model.save(directory)
         self.inferencer.processor.save(directory)
 
+    def save_to_remote(
+        self, repo_id: str, private: Optional[bool] = None, commit_message: str = "Add new model to Hugging Face."
+    ):
+        """
+        Saves the Reader model to the Hugging Face Model Hub under the given repo_id. For this to work:
+        - Be logged in to Hugging Face on your machine via transformers-cli
+        - Have git lfs installed (https://packagecloud.io/github/git-lfs/install); you can test it by running `git lfs --version`
+
+        :param repo_id: A namespace (user or an organization) and a repo name, separated by a '/', of the model you want to save to Hugging Face
+        :param private: Set to True to make the model repository private
+        :param commit_message: Commit message while saving to Hugging Face
+        """
+        # Note: This function was inspired by the save_to_hub function in the sentence-transformers repo
+        # (https://github.com/UKPLab/sentence-transformers/), especially for the git-lfs tracking.
+
+        token = HfFolder.get_token()
+        if token is None:
+            raise ValueError(
+                "To save this reader model to Hugging Face, make sure you log in to the hub on this computer by typing `transformers-cli login`."
+            )
+
+        repo_url = create_repo(token=token, repo_id=repo_id, private=private, repo_type=None, exist_ok=True)
+
+        transformer_models = self.inferencer.model.convert_to_transformers()
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            repo = Repository(tmp_dir, clone_from=repo_url)
+
+            self.inferencer.processor.tokenizer.save_pretrained(tmp_dir)
+
+            # convert_to_transformers (above) creates one model per prediction head.
+            # As FARMReader models only have one head (QA), we save the first model.
+            transformer_models[0].save_pretrained(tmp_dir)
+
+            large_files = []
+            for root, dirs, files in os.walk(tmp_dir):
+                for filename in files:
+                    file_path = os.path.join(root, filename)
+                    rel_path = os.path.relpath(file_path, tmp_dir)
+
+                    if os.path.getsize(file_path) > (5 * 1024 * 1024):
+                        large_files.append(rel_path)
+
+            if len(large_files) > 0:
+                logger.info("Track files with git lfs: {}".format(", ".join(large_files)))
+                repo.lfs_track(large_files)
+
+            logger.info("Push model to the hub. This might take a while")
+            commit_url = repo.push_to_hub(commit_message=commit_message)
+
+        return commit_url
+
     def predict_batch(
         self,
         queries: List[str],
diff --git a/setup.cfg b/setup.cfg
index 33913a9f6b..cbc9f75c23 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -75,6 +75,8 @@ install_requires =
     # azure-core>=1.23 needs typing-extensions>=4.0.1
     # pip unfortunately backtracks into the databind direction ultimately getting lost.
     azure-core<1.23
+    # audio's espnet-model-zoo requires huggingface-hub version <0.8 while we need >=0.5 to be able to use create_repo in FARMReader
+    huggingface-hub<0.8.0,>=0.5.0
     # Preprocessing
     more_itertools  # for windowing
@@ -157,7 +159,6 @@ audio =
     espnet
     espnet-model-zoo
     pydub
-    huggingface-hub<0.8.0
 beir =
     beir; platform_system != 'Windows'
 crawler =
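As a quick illustration of the new API, here is a minimal usage sketch; the model name and `repo_id` below are placeholders, and it assumes you are already logged in via `transformers-cli login` and have git lfs installed:

```python
from haystack.nodes import FARMReader

# Load any QA model into a FARMReader; "deepset/roberta-base-squad2" is just an example.
reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")

# Push the model and tokenizer to the Hugging Face Model Hub.
# "my-user/my-reader-model" is a placeholder repo_id (namespace/repo-name).
commit_url = reader.save_to_remote(
    repo_id="my-user/my-reader-model",
    private=True,
    commit_message="Add fine-tuned reader model.",
)
print(commit_url)
```

After the push completes, the same `repo_id` should be usable as `model_name_or_path` when constructing a `FARMReader`, since the converted model is stored in the standard transformers format.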