diff --git a/haystack/components/audio/whisper_local.py b/haystack/components/audio/whisper_local.py index 00f73fe9f6..e94697e442 100644 --- a/haystack/components/audio/whisper_local.py +++ b/haystack/components/audio/whisper_local.py @@ -104,7 +104,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "LocalWhisperTranscriber": The deserialized component. """ init_params = data["init_parameters"] - if init_params["device"] is not None: + if init_params.get("device") is not None: init_params["device"] = ComponentDevice.from_dict(init_params["device"]) return default_from_dict(cls, data) diff --git a/haystack/components/embedders/sentence_transformers_document_embedder.py b/haystack/components/embedders/sentence_transformers_document_embedder.py index fb90c2b370..6e4916a5b1 100644 --- a/haystack/components/embedders/sentence_transformers_document_embedder.py +++ b/haystack/components/embedders/sentence_transformers_document_embedder.py @@ -126,9 +126,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedde Deserialized component. """ init_params = data["init_parameters"] - if init_params["device"] is not None: + if init_params.get("device") is not None: init_params["device"] = ComponentDevice.from_dict(init_params["device"]) - deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + deserialize_secrets_inplace(init_params, keys=["token"]) return default_from_dict(cls, data) def warm_up(self): diff --git a/haystack/components/embedders/sentence_transformers_text_embedder.py b/haystack/components/embedders/sentence_transformers_text_embedder.py index d5e569eca5..f4dfc14e9d 100644 --- a/haystack/components/embedders/sentence_transformers_text_embedder.py +++ b/haystack/components/embedders/sentence_transformers_text_embedder.py @@ -116,9 +116,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder": Deserialized component. 
""" init_params = data["init_parameters"] - if init_params["device"] is not None: + if init_params.get("device") is not None: init_params["device"] = ComponentDevice.from_dict(init_params["device"]) - deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + deserialize_secrets_inplace(init_params, keys=["token"]) return default_from_dict(cls, data) def warm_up(self): diff --git a/haystack/components/extractors/named_entity_extractor.py b/haystack/components/extractors/named_entity_extractor.py index b8083742f9..bfd5e0dbc7 100644 --- a/haystack/components/extractors/named_entity_extractor.py +++ b/haystack/components/extractors/named_entity_extractor.py @@ -221,7 +221,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor": """ try: init_params = data["init_parameters"] - if init_params["device"] is not None: + if init_params.get("device") is not None: init_params["device"] = ComponentDevice.from_dict(init_params["device"]) init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]] return default_from_dict(cls, data) diff --git a/haystack/components/rankers/sentence_transformers_diversity.py b/haystack/components/rankers/sentence_transformers_diversity.py index 9a6fea0199..8319532892 100644 --- a/haystack/components/rankers/sentence_transformers_diversity.py +++ b/haystack/components/rankers/sentence_transformers_diversity.py @@ -142,9 +142,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDiversityRanker The deserialized component. 
""" init_params = data["init_parameters"] - if init_params["device"] is not None: + if init_params.get("device") is not None: init_params["device"] = ComponentDevice.from_dict(init_params["device"]) - deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + deserialize_secrets_inplace(init_params, keys=["token"]) return default_from_dict(cls, data) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: diff --git a/haystack/components/rankers/transformers_similarity.py b/haystack/components/rankers/transformers_similarity.py index 58da869007..980de5fd90 100644 --- a/haystack/components/rankers/transformers_similarity.py +++ b/haystack/components/rankers/transformers_similarity.py @@ -176,11 +176,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "TransformersSimilarityRanker": :returns: Deserialized component. """ - deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) init_params = data["init_parameters"] - if init_params["device"] is not None: + if init_params.get("device") is not None: init_params["device"] = ComponentDevice.from_dict(init_params["device"]) - deserialize_hf_model_kwargs(init_params["model_kwargs"]) + if init_params.get("model_kwargs") is not None: + deserialize_hf_model_kwargs(init_params["model_kwargs"]) + deserialize_secrets_inplace(init_params, keys=["token"]) return default_from_dict(cls, data) diff --git a/haystack/components/readers/extractive.py b/haystack/components/readers/extractive.py index 03a090d783..85d90a74cf 100644 --- a/haystack/components/readers/extractive.py +++ b/haystack/components/readers/extractive.py @@ -170,10 +170,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExtractiveReader": Deserialized component. 
""" init_params = data["init_parameters"] - deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) - if init_params["device"] is not None: + deserialize_secrets_inplace(init_params, keys=["token"]) + if init_params.get("device") is not None: init_params["device"] = ComponentDevice.from_dict(init_params["device"]) - deserialize_hf_model_kwargs(init_params["model_kwargs"]) + if init_params.get("model_kwargs") is not None: + deserialize_hf_model_kwargs(init_params["model_kwargs"]) return default_from_dict(cls, data) diff --git a/releasenotes/notes/hf-models-from-dict-default-values-47c2c73136ea6643.yaml b/releasenotes/notes/hf-models-from-dict-default-values-47c2c73136ea6643.yaml new file mode 100644 index 0000000000..71b09ebf7f --- /dev/null +++ b/releasenotes/notes/hf-models-from-dict-default-values-47c2c73136ea6643.yaml @@ -0,0 +1,3 @@ +fixes: + - | + This updates the components, TransformersSimilarityRanker, SentenceTransformersDiversityRanker, SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder and LocalWhisperTranscriber from_dict methods to work when loading with init_parameters only containing required parameters. 
diff --git a/test/components/audio/test_whisper_local.py b/test/components/audio/test_whisper_local.py index ee2c07c314..d30ce3eebc 100644 --- a/test/components/audio/test_whisper_local.py +++ b/test/components/audio/test_whisper_local.py @@ -74,6 +74,13 @@ def test_from_dict(self): assert transcriber.whisper_params == {} assert transcriber._model is None + def test_from_dict_no_default_parameters(self): + data = {"type": "haystack.components.audio.whisper_local.LocalWhisperTranscriber", "init_parameters": {}} + transcriber = LocalWhisperTranscriber.from_dict(data) + assert transcriber.model == "large" + assert transcriber.device == ComponentDevice.resolve_device(None) + assert transcriber.whisper_params == {} + def test_from_dict_none_device(self): data = { "type": "haystack.components.audio.whisper_local.LocalWhisperTranscriber", diff --git a/test/components/embedders/test_sentence_transformers_document_embedder.py b/test/components/embedders/test_sentence_transformers_document_embedder.py index 4f495e2ea8..67a8a840a9 100644 --- a/test/components/embedders/test_sentence_transformers_document_embedder.py +++ b/test/components/embedders/test_sentence_transformers_document_embedder.py @@ -137,6 +137,25 @@ def test_from_dict(self): assert component.trust_remote_code assert component.meta_fields_to_embed == ["meta_field"] + def test_from_dict_no_default_parameters(self): + component = SentenceTransformersDocumentEmbedder.from_dict( + { + "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder", + "init_parameters": {}, + } + ) + assert component.model == "sentence-transformers/all-mpnet-base-v2" + assert component.device == ComponentDevice.resolve_device(None) + assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False) + assert component.prefix == "" + assert component.suffix == "" + assert component.batch_size == 32 + assert component.progress_bar is True + assert 
component.normalize_embeddings is False + assert component.embedding_separator == "\n" + assert component.trust_remote_code is False + assert component.meta_fields_to_embed == [] + def test_from_dict_none_device(self): init_parameters = { "model": "model", diff --git a/test/components/embedders/test_sentence_transformers_text_embedder.py b/test/components/embedders/test_sentence_transformers_text_embedder.py index 7eb712e673..8c603f6ebf 100644 --- a/test/components/embedders/test_sentence_transformers_text_embedder.py +++ b/test/components/embedders/test_sentence_transformers_text_embedder.py @@ -122,6 +122,22 @@ def test_from_dict(self): assert component.normalize_embeddings is False assert component.trust_remote_code is False + def test_from_dict_no_default_parameters(self): + data = { + "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder", + "init_parameters": {}, + } + component = SentenceTransformersTextEmbedder.from_dict(data) + assert component.model == "sentence-transformers/all-mpnet-base-v2" + assert component.device == ComponentDevice.resolve_device(None) + assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False) + assert component.prefix == "" + assert component.suffix == "" + assert component.batch_size == 32 + assert component.progress_bar is True + assert component.normalize_embeddings is False + assert component.trust_remote_code is False + def test_from_dict_none_device(self): data = { "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder", diff --git a/test/components/extractors/test_named_entity_extractor.py b/test/components/extractors/test_named_entity_extractor.py index 180bee41e8..a4826c1e95 100644 --- a/test/components/extractors/test_named_entity_extractor.py +++ b/test/components/extractors/test_named_entity_extractor.py @@ -40,6 +40,17 @@ def test_named_entity_extractor_serde(): _ = 
NamedEntityExtractor.from_dict(serde_data) +def test_named_entity_extractor_from_dict_no_default_parameters_hf(): + data = { + "type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor", + "init_parameters": {"backend": "HUGGING_FACE", "model": "dslim/bert-base-NER"}, + } + extractor = NamedEntityExtractor.from_dict(data) + + assert extractor._backend.model_name == "dslim/bert-base-NER" + assert extractor._backend.device == ComponentDevice.resolve_device(None) + + # tests for NamedEntityExtractor serialization/deserialization in a pipeline def test_named_entity_extractor_pipeline_serde(tmp_path): extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER") diff --git a/test/components/rankers/test_sentence_transformers_diversity.py b/test/components/rankers/test_sentence_transformers_diversity.py index f012a7d9a2..ba3b10ae5c 100644 --- a/test/components/rankers/test_sentence_transformers_diversity.py +++ b/test/components/rankers/test_sentence_transformers_diversity.py @@ -144,6 +144,25 @@ def test_from_dict_none_device(self): assert ranker.meta_fields_to_embed == [] assert ranker.embedding_separator == "\n" + def test_from_dict_no_default_parameters(self): + data = { + "type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker", + "init_parameters": {}, + } + ranker = SentenceTransformersDiversityRanker.from_dict(data) + + assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2" + assert ranker.top_k == 10 + assert ranker.device == ComponentDevice.resolve_device(None) + assert ranker.similarity == "cosine" + assert ranker.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False) + assert ranker.query_prefix == "" + assert ranker.document_prefix == "" + assert ranker.query_suffix == "" + assert ranker.document_suffix == "" + assert ranker.meta_fields_to_embed == [] + assert ranker.embedding_separator == "\n" + 
def test_to_dict_with_custom_init_parameters(self): component = SentenceTransformersDiversityRanker( model="sentence-transformers/msmarco-distilbert-base-v4", diff --git a/test/components/rankers/test_transformers_similarity.py b/test/components/rankers/test_transformers_similarity.py index 04b9fe425b..9e083ffa68 100644 --- a/test/components/rankers/test_transformers_similarity.py +++ b/test/components/rankers/test_transformers_similarity.py @@ -172,6 +172,27 @@ def test_from_dict(self): "device_map": ComponentDevice.resolve_device(None).to_hf(), } + def test_from_dict_no_default_parameters(self): + data = { + "type": "haystack.components.rankers.transformers_similarity.TransformersSimilarityRanker", + "init_parameters": {}, + } + + component = TransformersSimilarityRanker.from_dict(data) + assert component.device is None + assert component.model_name_or_path == "cross-encoder/ms-marco-MiniLM-L-6-v2" + assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False) + assert component.top_k == 10 + assert component.query_prefix == "" + assert component.document_prefix == "" + assert component.meta_fields_to_embed == [] + assert component.embedding_separator == "\n" + assert component.scale_score + assert component.calibration_factor == 1.0 + assert component.score_threshold is None + # no model_kwargs were supplied, so only the default device_map is expected + assert component.model_kwargs == {"device_map": ComponentDevice.resolve_device(None).to_hf()} + @patch("torch.sigmoid") @patch("torch.sort") def test_embed_meta(self, mocked_sort, mocked_sigmoid): diff --git a/test/components/readers/test_extractive.py b/test/components/readers/test_extractive.py index 2f2b098d47..9c42c44254 100644 --- a/test/components/readers/test_extractive.py +++ b/test/components/readers/test_extractive.py @@ -243,6 +243,25 @@ def test_from_dict(): } +def test_from_dict_no_default_parameters(): + data = {"type": "haystack.components.readers.extractive.ExtractiveReader", "init_parameters": {}} + + 
component = ExtractiveReader.from_dict(data) + assert component.model_name_or_path == "deepset/roberta-base-squad2-distilled" + assert component.device is None + assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False) + assert component.top_k == 20 + assert component.score_threshold is None + assert component.max_seq_length == 384 + assert component.stride == 128 + assert component.max_batch_size is None + assert component.answers_per_seq is None + assert component.no_answer + assert component.calibration_factor == 0.1 + assert component.overlap_threshold == 0.01 + assert component.model_kwargs == {"device_map": ComponentDevice.resolve_device(None).to_hf()} + + def test_from_dict_no_token(): data = { "type": "haystack.components.readers.extractive.ExtractiveReader",