Skip to content

Commit

Permalink
Merge branch 'main' into ODSC-63451/oci_odsc_embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
mrDzurb authored Dec 12, 2024
2 parents a01f809 + 095d410 commit 097e144
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 22 deletions.
4 changes: 2 additions & 2 deletions llama-index-core/llama_index/core/multi_modal_llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)
from llama_index.core.instrumentation import DispatcherSpanMixin
from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
from llama_index.core.schema import BaseComponent, ImageNode
from llama_index.core.schema import BaseComponent, ImageDocument, ImageNode


class MultiModalLLMMetadata(BaseModel):
Expand Down Expand Up @@ -217,7 +217,7 @@ def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
if not isinstance(input["image_documents"], list):
raise ValueError("image_documents must be a list.")
for doc in input["image_documents"]:
if not isinstance(doc, ImageNode):
if not isinstance(doc, (ImageDocument, ImageNode)):
raise ValueError(
"image_documents must be a list of ImageNode objects."
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def build_nodes_from_splits(
image_url=document.image_url,
excluded_embed_metadata_keys=document.excluded_embed_metadata_keys,
excluded_llm_metadata_keys=document.excluded_llm_metadata_keys,
metadata_seperator=document.metadata_seperator,
metadata_seperator=document.metadata_separator,
metadata_template=document.metadata_template,
text_template=document.text_template,
relationships=relationships,
Expand Down
133 changes: 120 additions & 13 deletions llama-index-core/llama_index/core/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)

import filetype
import requests
from dataclasses_json import DataClassJsonMixin
from deprecated import deprecated
from typing_extensions import Self
Expand Down Expand Up @@ -507,24 +508,29 @@ class MediaResource(BaseModel):
url: AnyUrl | None = Field(default=None, description="URL to reach this resource.")

@model_validator(mode="after")
def guess_mimetype(self) -> Self:
"""Guess the mimetype when possible.
def data_to_base64(self) -> Self:
"""If binary data was passed, store the resource as base64 and guess the mimetype when possible.
In case the model was built passing its content but without a mimetype,
In case the model was built passing binary data but without a mimetype,
we try to guess it using the filetype library. To avoid resource-intense
operations, we won't load the path or the URL to guess the mimetype.
"""
if not self.data or self.mimetype:
if not self.data:
return self

try:
# Check if data is already base64 encoded
decoded_data = base64.b64decode(self.data)
except Exception:
decoded_data = self.data
# Not base64 - encode it
self.data = base64.b64encode(self.data)

if not self.mimetype:
guess = filetype.guess(decoded_data)
self.mimetype = guess.mime if guess else None
except Exception as e:
logging.debug("Data is not base64 encoded, cannot guess mimetype")
finally:
return self

return self

@property
def hash(self) -> str:
Expand Down Expand Up @@ -597,7 +603,7 @@ def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
def set_content(self, value: str) -> None:
"""Set the text content of the node.
Provided for backward compatibility, set self.text instead.
Provided for backward compatibility, set self.text_resource instead.
"""
self.text_resource = MediaResource(text=value)

Expand Down Expand Up @@ -628,7 +634,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Make TextNode forward-compatible with Node by supporting 'text_resource' in the constructor."""
if "text_resource" in kwargs:
tr = kwargs.pop("text_resource")
kwargs["text"] = tr["text"]
if isinstance(tr, MediaResource):
kwargs["text"] = tr.text
else:
kwargs["text"] = tr["text"]
super().__init__(*args, **kwargs)

text: str = Field(default="", description="Text content of the node.")
Expand Down Expand Up @@ -976,7 +985,7 @@ def __str__(self) -> str:
version="0.12.2",
reason="'get_doc_id' is deprecated, access the 'id_' property instead.",
)
def get_doc_id(self) -> str:
def get_doc_id(self) -> str: # pragma: nocover
return self.id_

def to_langchain_format(self) -> LCDocument:
Expand Down Expand Up @@ -1096,13 +1105,111 @@ def from_cloud_document(
)


class ImageDocument(Document, ImageNode):
"""Data document containing an image."""
class ImageDocument(Document):
    """Backward compatible wrapper around Document containing an image.

    The legacy ``image``, ``image_path``, ``image_url`` and
    ``image_mimetype`` attributes are exposed as properties that read from
    and write to ``image_resource``; ``text_embedding`` is mapped onto the
    ``"dense"`` entry of ``text_resource.embeddings``.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Build the document, translating legacy image keywords.

        Args:
            **kwargs: Keyword arguments for ``Document``. The legacy keys
                ``image``, ``image_path``, ``image_url``, ``image_mimetype``
                and ``text_embedding`` are removed before forwarding. The
                first truthy one of ``image``, ``image_path``, ``image_url``
                (in that order) becomes ``image_resource``.
        """
        image = kwargs.pop("image", None)
        image_path = kwargs.pop("image_path", None)
        image_url = kwargs.pop("image_url", None)
        image_mimetype = kwargs.pop("image_mimetype", None)
        # Removed so the legacy keyword is not forwarded to Document.
        # NOTE(review): the value is discarded here, exactly as before this
        # change; confirm whether it should be stored on text_resource.
        kwargs.pop("text_embedding", None)

        # Forward an explicitly supplied mimetype to the resource instead of
        # dropping it. When none is supplied the MediaResource is built
        # exactly as before.
        resource_kwargs: dict[str, Any] = {}
        if image_mimetype is not None:
            resource_kwargs["mimetype"] = image_mimetype

        if image:
            kwargs["image_resource"] = MediaResource(data=image, **resource_kwargs)
        elif image_path:
            kwargs["image_resource"] = MediaResource(
                path=image_path, **resource_kwargs
            )
        elif image_url:
            kwargs["image_resource"] = MediaResource(url=image_url, **resource_kwargs)

        super().__init__(**kwargs)

    @property
    def image(self) -> str | None:
        """Return the image resource's data decoded as UTF-8 text, if any."""
        if self.image_resource and self.image_resource.data:
            return self.image_resource.data.decode("utf-8")
        return None

    @image.setter
    def image(self, image: str) -> None:
        """Replace the image resource with one built from ``image``."""
        self.image_resource = MediaResource(data=image.encode("utf-8"))

    @property
    def image_path(self) -> str | None:
        """Return the image resource's path as a string, if any."""
        if self.image_resource and self.image_resource.path:
            return str(self.image_resource.path)
        return None

    @image_path.setter
    def image_path(self, image_path: str) -> None:
        """Replace the image resource with one pointing at ``image_path``."""
        self.image_resource = MediaResource(path=Path(image_path))

    @property
    def image_url(self) -> str | None:
        """Return the image resource's URL as a string, if any."""
        if self.image_resource and self.image_resource.url:
            return str(self.image_resource.url)
        return None

    @image_url.setter
    def image_url(self, image_url: str) -> None:
        """Replace the image resource with one pointing at ``image_url``."""
        self.image_resource = MediaResource(url=AnyUrl(url=image_url))

    @property
    def image_mimetype(self) -> str | None:
        """Return the image resource's mimetype, or None without a resource."""
        if self.image_resource:
            return self.image_resource.mimetype
        return None

    @image_mimetype.setter
    def image_mimetype(self, image_mimetype: str) -> None:
        """Set the mimetype on the existing image resource.

        Does nothing when the document has no image resource.
        """
        if self.image_resource:
            self.image_resource.mimetype = image_mimetype

    @property
    def text_embedding(self) -> list[float] | None:
        """Return the ``"dense"`` embedding of the text resource, if any."""
        if self.text_resource and self.text_resource.embeddings:
            return self.text_resource.embeddings.get("dense")
        return None

    @text_embedding.setter
    def text_embedding(self, embeddings: list[float]) -> None:
        """Store ``embeddings`` as the ``"dense"`` text embedding.

        Does nothing when the document has no text resource.
        """
        if self.text_resource:
            if self.text_resource.embeddings is None:
                self.text_resource.embeddings = {}
            self.text_resource.embeddings["dense"] = embeddings

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string."""
        return "ImageDocument"

    def resolve_image(
        self, as_base64: bool = False, *, timeout: float | None = 60.0
    ) -> BytesIO:
        """Resolve an image such that PIL can read it.

        Args:
            as_base64 (bool): whether the resolved image should be returned
                as base64-encoded bytes
            timeout (float | None): timeout in seconds passed to
                ``requests.get`` when the image is fetched from a URL;
                ``None`` waits indefinitely.

        Returns:
            A ``BytesIO`` holding the image bytes; empty when the document
            has no image resource.

        Raises:
            ValueError: if the image resource has no data, path or URL.
            requests.HTTPError: if fetching the URL returns an error status.
        """
        resource = self.image_resource
        if resource is None:
            return BytesIO()

        if resource.data is not None:
            # Stored data is handed back untouched for as_base64 and
            # base64-decoded otherwise.
            if as_base64:
                return BytesIO(resource.data)
            return BytesIO(base64.b64decode(resource.data))

        if resource.path is not None:
            img_bytes = resource.path.read_bytes()
        elif resource.url is not None:
            # Load the image from the URL. A timeout prevents hanging on an
            # unresponsive host, and an error status is raised rather than
            # returning the error page body as if it were image data.
            response = requests.get(str(resource.url), timeout=timeout)
            response.raise_for_status()
            img_bytes = response.content
        else:
            raise ValueError("No image found in the chat message!")

        if as_base64:
            return BytesIO(base64.b64encode(img_bytes))
        return BytesIO(img_bytes)


@dataclass
class QueryBundle(DataClassJsonMixin):
Expand Down
6 changes: 4 additions & 2 deletions llama-index-core/tests/schema/test_media_resource.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

from llama_index.core.bridge.pydantic import AnyUrl
from llama_index.core.schema import MediaResource

Expand All @@ -21,10 +23,10 @@ def test_hash():
assert (
MediaResource(
data=b"test bytes",
path="foo/bar/baz",
path=Path("foo/bar/baz"),
url=AnyUrl("http://example.com"),
text="some text",
).hash
== "7ac964db7843a9ffb37cda7b5b9822b0f84111d6a271b4991dd26d1fc68490d3"
== "04414a5f03ad7fa055229b4d3690d47427cb0b65bc7eb8f770d1ecbd54ab4909"
)
assert MediaResource().hash == ""
2 changes: 1 addition & 1 deletion llama-index-core/tests/schema/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ def test_hash():
node.text_resource = MediaResource(text="some text", mimetype="text/plain")
node.video_resource = MediaResource(data=b"some video", mimetype="video/mpeg")
assert (
node.hash == "ee411edd3dffb27470eef165ccf4df9fabaa02e7c7c39415950d3ac4d7e35e61"
node.hash == "6f08712269634de7e53e62a3aaee59d60e9a32a43bc05284a21244f960f0cda4"
)
Loading

0 comments on commit 097e144

Please sign in to comment.