feat!: Add extra for inference dependencies such as torch #5147

Merged: 7 commits, Jun 20, 2023. Changes shown from 4 commits.
8 changes: 4 additions & 4 deletions .github/workflows/tests.yml
@@ -210,7 +210,7 @@ jobs:
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Install Haystack
-       run: pip install .[elasticsearch,dev,preprocessing]
+       run: pip install .[elasticsearch,dev,preprocessing,inference]

      - name: Run tests
        run: |
@@ -608,7 +608,7 @@ jobs:
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Install Haystack
-       run: pip install .[dev,preprocessing]
+       run: pip install .[dev,preprocessing,inference]

      - name: Run tests
        run: |
@@ -662,7 +662,7 @@ jobs:
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Install Haystack
-       run: pip install .[dev,preprocessing]
+       run: pip install .[dev,preprocessing,inference]

      - name: Run tests
        run: |
@@ -716,7 +716,7 @@ jobs:
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Install Haystack
-       run: pip install .[dev,preprocessing]
+       run: pip install .[dev,preprocessing,inference]

      - name: Run tests
        run: |
2 changes: 2 additions & 0 deletions haystack/document_stores/memory.py
@@ -113,6 +113,8 @@ def __init__(
        self.bm25_parameters = bm25_parameters
        self.bm25: Dict[str, rank_bm25.BM25] = {}

+       torch_import.check()
+
Contributor:
Nitpick: I would always put these checks at the very first line of the __init__. If the deps are not there, there is no point doing anything else: running any super() call, making any other checks, etc. It should be the first issue the users are warned about.
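
A minimal sketch of the suggested ordering (hypothetical signature, shown only to illustrate the suggestion; torch_import is the LazyImport handle this PR introduces):

    from haystack.lazy_imports import LazyImport

    with LazyImport() as torch_import:
        import torch

    class InMemoryDocumentStore:
        def __init__(self, devices=None, use_gpu=True):
            # Fail fast: check optional deps before doing any other work,
            # so a missing torch is the first problem the user is told about.
            torch_import.check()
            super().__init__()
            # ... the rest of the init (device setup, BM25 config, etc.)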

self.devices, _ = initialize_device_settings(devices=devices, use_cuda=self.use_gpu, multi_gpu=False)
if len(self.devices) > 1:
logger.warning(
13 changes: 9 additions & 4 deletions haystack/modeling/infer.py
@@ -3,17 +3,21 @@
import os
import logging
from tqdm.auto import tqdm
-import torch
-from torch.utils.data.sampler import SequentialSampler
-from torch.utils.data import Dataset


from haystack.modeling.data_handler.dataloader import NamedDataLoader
from haystack.modeling.data_handler.processor import Processor, InferenceProcessor
from haystack.modeling.data_handler.samples import SampleBasket
-from haystack.modeling.utils import initialize_device_settings, set_all_seeds
from haystack.modeling.data_handler.inputs import QAInput
from haystack.modeling.model.adaptive_model import AdaptiveModel, BaseAdaptiveModel
from haystack.modeling.model.predictions import QAPred
+from haystack.lazy_imports import LazyImport
+
+with LazyImport() as torch_import:
+    import torch
+    from torch.utils.data.sampler import SequentialSampler
+    from torch.utils.data import Dataset
+    from haystack.modeling.utils import initialize_device_settings, set_all_seeds  # pylint: disable=ungrouped-imports
Contributor:
This should not be needed: this check is done in haystack/modeling/__init__.py, which always runs before this module: https://github.com/deepset-ai/haystack/blob/main/haystack/modeling/__init__.py

If that check does not work, let's rather understand why and fix it at that level. Otherwise we would have to try-catch all imports across all modeling modules, which I'd rather avoid 😅
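
For reference, a sketch of what that package-level guard plausibly looks like (assumed shape, reconstructed from the LazyImport pattern used in this PR; see the linked file for the actual code):

    # haystack/modeling/__init__.py (sketch)
    from haystack.lazy_imports import LazyImport

    with LazyImport() as torch_import:
        import torch

    # Raises a helpful ImportError as soon as haystack.modeling is imported
    # without torch installed, before submodules such as infer.py ever run.
    torch_import.check()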



logger = logging.getLogger(__name__)
@@ -73,6 +77,7 @@ def __init__(
        :return: An instance of the Inferencer.

        """
+       torch_import.check()
Contributor:
Same as above

        # Init device and distributed settings
        self.devices, _ = initialize_device_settings(devices=devices, use_cuda=gpu, multi_gpu=False)
        if len(self.devices) > 1:
1 change: 1 addition & 0 deletions haystack/nodes/question_generator/question_generator.py
@@ -93,6 +93,7 @@ def __init__(
        parameter is not used and a single CPU device is used for inference.

        """
+       torch_and_transformers_import.check()
ZanSara marked this conversation as resolved.
        super().__init__()
        self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False)
        if len(self.devices) > 1:
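
For context, the torch_and_transformers_import handle used above presumably comes from a module-level guard along these lines (a sketch; the exact imports guarded in question_generator.py are an assumption):

    from haystack.lazy_imports import LazyImport

    with LazyImport() as torch_and_transformers_import:
        import torch
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer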
18 changes: 11 additions & 7 deletions pyproject.toml
@@ -48,7 +48,7 @@ classifiers = [
dependencies = [
    "requests",
    "pydantic",
-   "transformers[torch,sentencepiece]==4.30.1",
+   "transformers==4.30.1",
    "pandas",
    "rank_bm25",
    "scikit-learn>=1.0.0",  # TF-IDF, SklearnQueryClassifier and metrics
@@ -62,16 +62,15 @@ dependencies = [
    "quantulum3",  # quantities extraction from text
    "posthog",  # telemetry
-   # audio's espnet-model-zoo requires huggingface-hub version <0.8 while we need >=0.5 to be able to use create_repo in FARMReader
-   "huggingface-hub>=0.5.0",
    "tenacity",  # retry decorator
    "sseclient-py",  # server side events for OpenAI streaming
    "more_itertools",  # utilities

    # Web Retriever
    "boilerpy3",

-   # See haystack/nodes/retriever/_embedding_encoder.py, _SentenceTransformersEmbeddingEncoder
-   "sentence-transformers>=2.2.0",
    # Multimodal Embedder haystack/nodes/retriever/multimodal/embedder.py
    "Pillow",

    # OpenAI tokenizer
    "tiktoken>=0.3.2",
@@ -89,6 +88,11 @@ dependencies = [
]

[project.optional-dependencies]
+inference = [
+    "transformers[torch,sentencepiece]==4.30.1",
+    "sentence-transformers>=2.2.0",  # See haystack/nodes/retriever/_embedding_encoder.py, _SentenceTransformersEmbeddingEncoder
+    "huggingface-hub>=0.5.0",
+]
elasticsearch = [
    "elasticsearch>=7.17,<8",
]
@@ -212,11 +216,11 @@ formatting = [
]

all = [
-    "farm-haystack[docstores,audio,crawler,preprocessing,file-conversion,pdf,ocr,ray,onnx,beir,metrics]",
+    "farm-haystack[inference,docstores,audio,crawler,preprocessing,file-conversion,pdf,ocr,ray,onnx,beir,metrics]",
]
all-gpu = [
    # beir is incompatible with faiss-gpu: https://github.com/beir-cellar/beir/issues/71
-    "farm-haystack[docstores-gpu,audio,crawler,preprocessing,file-conversion,pdf,ocr,ray,onnx-gpu,metrics]",
+    "farm-haystack[inference,docstores-gpu,audio,crawler,preprocessing,file-conversion,pdf,ocr,ray,onnx-gpu,metrics]",
]

[project.scripts]
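
Net effect of the packaging change: a plain install stays lightweight, and the heavy inference stack becomes opt-in. For example (package name as published on PyPI):

    pip install farm-haystack                  # core only, no torch
    pip install 'farm-haystack[inference]'     # adds transformers[torch,sentencepiece], sentence-transformers, huggingface-hub
    pip install 'farm-haystack[all]'           # now pulls in the inference extra as well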