feat!: Add extra for inference dependencies such as torch #5147
Changes from 4 commits:
- 01ee103
- 3e63947
- 208f993
- 75b012b
- f0be1d2
- 032a8a4
- c84adaa
@@ -3,17 +3,21 @@
 import os
 import logging
 from tqdm.auto import tqdm
-import torch
-from torch.utils.data.sampler import SequentialSampler
-from torch.utils.data import Dataset

 from haystack.modeling.data_handler.dataloader import NamedDataLoader
 from haystack.modeling.data_handler.processor import Processor, InferenceProcessor
 from haystack.modeling.data_handler.samples import SampleBasket
-from haystack.modeling.utils import initialize_device_settings, set_all_seeds
 from haystack.modeling.data_handler.inputs import QAInput
 from haystack.modeling.model.adaptive_model import AdaptiveModel, BaseAdaptiveModel
 from haystack.modeling.model.predictions import QAPred
+from haystack.lazy_imports import LazyImport

+with LazyImport() as torch_import:
+    import torch
+    from torch.utils.data.sampler import SequentialSampler
+    from torch.utils.data import Dataset
+    from haystack.modeling.utils import initialize_device_settings, set_all_seeds  # pylint: disable=ungrouped-imports

 logger = logging.getLogger(__name__)

Review comment (on this hunk): This should not be needed: this check is done in … If that check does not work, let's rather understand why and fix it at that level. Otherwise we would have to try-catch all imports across all modeling modules, which I'd rather avoid 😅
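The hunk above moves the torch imports into a `LazyImport` context manager so that importing the module no longer requires torch to be installed; the error surfaces only when `check()` is called. A minimal sketch of how such a context manager can work is below — this is a simplified stand-in, not Haystack's actual `haystack.lazy_imports.LazyImport` implementation, and `a_missing_optional_dependency` is a hypothetical package name.

```python
# Minimal sketch of a LazyImport-style context manager. Simplified for
# illustration; the real haystack.lazy_imports.LazyImport may differ.
class LazyImport:
    """Swallow ImportError at import time; re-raise it only when check() is called."""

    def __init__(self, message="Install the missing optional dependency to use this feature."):
        self._message = message
        self._import_error = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # ModuleNotFoundError is a subclass of ImportError, so issubclass covers both
        if exc_type is not None and issubclass(exc_type, ImportError):
            self._import_error = exc_value
            return True  # suppress the error: module import succeeds anyway
        return False

    def check(self):
        # Fail only when the optional dependency is actually needed
        if self._import_error is not None:
            raise ImportError(self._message) from self._import_error


# Module-level import of a missing package does not crash:
with LazyImport("Install torch to use the Inferencer.") as torch_import:
    import a_missing_optional_dependency  # hypothetical absent package

# The failure is deferred until the dependency is actually used:
try:
    torch_import.check()
except ImportError as err:
    print("Deferred failure:", err)
```

Because `__exit__` returns `True` for import errors, the module stays importable for users who never touch the torch-dependent code paths, which is the point of the inference extra this PR introduces.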
|
@@ -73,6 +77,7 @@ def __init__(
        :return: An instance of the Inferencer.
        """
+       torch_import.check()

        # Init device and distributed settings
        self.devices, _ = initialize_device_settings(devices=devices, use_cuda=gpu, multi_gpu=False)
        if len(self.devices) > 1:

Review comment (on this hunk): Same as above.
|
Review comment (nitpick): I would always put these checks at the very first line of the init. If the deps are not there, there is no point doing anything else — running any super(), making any other checks, etc. It should be the first issue the users are warned about.