diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index b86983e7..7a93ba75 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -92,7 +92,7 @@ def has_embedding(self) -> bool: return self.embedding and self.embedding[0] != -1 # placeholder def do_embedding(self, embed: Callable) -> None: - self.embedding = embed(self.text) + self.embedding = embed(self.get_content(MetadataMode.EMBED)) self.is_saved = False def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: diff --git a/tests/basic_tests/test_doc_node.py b/tests/basic_tests/test_doc_node.py index 7029c0c7..3cce441c 100644 --- a/tests/basic_tests/test_doc_node.py +++ b/tests/basic_tests/test_doc_node.py @@ -1,3 +1,4 @@ +from unittest.mock import MagicMock from lazyllm.tools.rag.store import DocNode, MetadataMode @@ -15,6 +16,12 @@ def setup_method(self): self.node.excluded_embed_metadata_keys = ["author"] self.node.excluded_llm_metadata_keys = ["date"] + def test_do_embedding(self): + """Test that do_embedding passes the correct content to the embed function.""" + mock_embed = MagicMock(return_value=[0.4, 0.5, 0.6]) + self.node.do_embedding(mock_embed) + mock_embed.assert_called_once_with(self.node.get_content(MetadataMode.EMBED)) + def test_node_creation(self): """Test the creation of a DocNode.""" assert self.node.text == self.text