diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index cc84e4e8..2abe2275 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -61,6 +61,7 @@ jobs: run: | set -ex cd ${{ env.CI_PATH }} + pip install -r tests/requirements.txt realpath . env | grep '^SCC' export LAZYLLM_SCO_ENV_NAME=lazyllm diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py index 2319a7fb..d92a70f9 100644 --- a/lazyllm/tools/rag/__init__.py +++ b/lazyllm/tools/rag/__init__.py @@ -4,6 +4,9 @@ from .transform import SentenceSplitter, LLMParser, NodeTransform from .index import register_similarity from .store import DocNode +from .readers import (PDFReader, DocxReader, HWPReader, PPTXReader, ImageReader, IPYNBReader, EpubReader, + MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader) +from .dataReader import SimpleDirectoryReader __all__ = [ @@ -16,4 +19,17 @@ "register_similarity", "register_reranker", "DocNode", + "PDFReader", + "DocxReader", + "HWPReader", + "PPTXReader", + "ImageReader", + "IPYNBReader", + "EpubReader", + "MarkdownReader", + "MboxReader", + "PandasCSVReader", + "PandasExcelReader", + "VideoAudioReader", + "SimpleDirectoryReader", ] diff --git a/lazyllm/tools/rag/dataReader.py b/lazyllm/tools/rag/dataReader.py new file mode 100644 index 00000000..c2cc20f3 --- /dev/null +++ b/lazyllm/tools/rag/dataReader.py @@ -0,0 +1,260 @@ +""" +The overall process of SimpleDirectoryReader is borrowed from LLAMA_INDEX, but we have added a customized part +based on it, that is, allowing users to register custom rules instead of processing only based on file suffixes. +""" +import os +import mimetypes +import multiprocessing +import fnmatch +from tqdm import tqdm +from datetime import datetime +from functools import reduce +from itertools import repeat +from typing import Dict, Optional, List, Callable, Type +from pathlib import Path, PurePosixPath, PurePath +from fsspec import AbstractFileSystem +from lazyllm import ModuleBase, LOG +from .store import DocNode +from .readers import (ReaderBase, PDFReader, DocxReader, HWPReader, PPTXReader, ImageReader, IPYNBReader, + EpubReader, MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader, + get_default_fs, is_default_fs) + +def _file_timestamp_format(timestamp: float, include_time: bool = False) -> Optional[str]: + try: + if include_time: + return datetime.utcfromtimestamp(timestamp).strftime("%Y-%m-%dT%H:%M:%SZ") + return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d") + except Exception: + return None + +class _DefaultFileMetadataFunc: + def __init__(self, fs: Optional[AbstractFileSystem] = None): + self._fs = fs or get_default_fs() + + def __call__(self, file_path: str) -> Dict: + stat_result = self._fs.stat(file_path) + + try: + file_name = os.path.basename(str(stat_result['name'])) + except Exception: + file_name = os.path.basename(file_path) + + creation_date = _file_timestamp_format(stat_result.get("created")) + last_modified_date = _file_timestamp_format(stat_result.get("mtime")) + last_accessed_date = _file_timestamp_format(stat_result.get("atime")) + default_meta = { + "file_path": file_path, + "file_name": file_name, + "file_type": mimetypes.guess_type(file_path)[0], + "file_size": stat_result.get("size"), + "creation_date": creation_date, + "last_modified_date": last_modified_date, + "last_accessed_date": last_accessed_date, + } + + return {meta_key: meta_value for meta_key, meta_value in default_meta.items() if meta_value is not None} + 
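As an aside, the file_metadata hook shown above can be replaced by the caller. A minimal sketch of a user-supplied callable handed to the SimpleDirectoryReader defined below; the ./data directory is a placeholder and the reduced field set is an arbitrary choice:

import mimetypes
import os

from lazyllm.tools.rag import SimpleDirectoryReader

def slim_metadata(file_path: str) -> dict:
    # Keep only two of the fields that _DefaultFileMetadataFunc would produce.
    return {"file_name": os.path.basename(file_path),
            "file_type": mimetypes.guess_type(file_path)[0]}

reader = SimpleDirectoryReader(input_dir="./data", file_metadata=slim_metadata)
for node in reader():  # the reader is a ModuleBase, so the instance itself is callable
    print(node)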
+class SimpleDirectoryReader(ModuleBase): + default_file_readers: Dict[str, Type[ReaderBase]] = { + "*.pdf": PDFReader, + "*.docx": DocxReader, + "*.hwp": HWPReader, + "*.pptx": PPTXReader, + "*.ppt": PPTXReader, + "*.pptm": PPTXReader, + "*.gif": ImageReader, + "*.jpeg": ImageReader, + "*.jpg": ImageReader, + "*.png": ImageReader, + "*.webp": ImageReader, + "*.ipynb": IPYNBReader, + "*.epub": EpubReader, + "*.md": MarkdownReader, + "*.mbox": MboxReader, + "*.csv": PandasCSVReader, + "*.xls": PandasExcelReader, + "*.xlsx": PandasExcelReader, + "*.mp3": VideoAudioReader, + "*.mp4": VideoAudioReader, + } + + def __init__(self, input_dir: Optional[str] = None, input_files: Optional[List] = None, + exclude: Optional[List] = None, exclude_hidden: bool = True, recursive: bool = False, + encoding: str = "utf-8", filename_as_id: bool = False, required_exts: Optional[List[str]] = None, + file_extractor: Optional[Dict[str, Callable]] = None, fs: Optional[AbstractFileSystem] = None, + file_metadata: Optional[Callable[[str], Dict]] = None, num_files_limit: Optional[int] = None, + return_trace: bool = False) -> None: + super().__init__(return_trace=return_trace) + + if (not input_dir and not input_files) or (input_dir and input_files): + raise ValueError("Must provide either `input_dir` or `input_files`.") + + self._fs = fs or get_default_fs() + self._encoding = encoding + + self._exclude = exclude + self._recursive = recursive + self._exclude_hidden = exclude_hidden + self._required_exts = required_exts + self._num_files_limit = num_files_limit + self._Path = Path if is_default_fs(self._fs) else PurePosixPath + + if input_files: + self._input_files = [] + for path in input_files: + if not self._fs.isfile(path): + raise ValueError(f"File {path} does not exist.") + input_file = self._Path(path) + self._input_files.append(input_file) + elif input_dir: + if not self._fs.isdir(input_dir): + raise ValueError(f"Directory {input_dir} does not exist.") + self._input_dir = self._Path(input_dir) + self._input_files = self._add_files(self._input_dir) + + self._file_extractor = file_extractor or {} + + self._file_metadata = file_metadata or _DefaultFileMetadataFunc(self._fs) + self._filename_as_id = filename_as_id + + def _add_files(self, input_dir: Path) -> List[Path]: # noqa: C901 + all_files = set() + rejected_files = set() + rejected_dirs = set() + + if self._exclude is not None: + for excluded_pattern in self._exclude: + if self._recursive: + excluded_glob = self._Path(input_dir) / self._Path("**") / excluded_pattern + else: + excluded_glob = self._Path(input_dir) / excluded_pattern + for file in self._fs.glob(str(excluded_glob)): + if self._fs.isdir(file): + rejected_dirs.add(self._Path(file)) + else: + rejected_files.add(self._Path(file)) + + file_refs: List[str] = [] + if self._recursive: + file_refs = self._fs.glob(str(input_dir) + "/**/*") + else: + file_refs = self._fs.glob(str(input_dir) + "/*") + + for ref in file_refs: + ref = self._Path(ref) + is_dir = self._fs.isdir(ref) + skip_hidden = self._exclude_hidden and self._is_hidden(ref) + skip_bad_exts = (self._required_exts is not None and ref.suffix not in self._required_exts) + skip_excluded = ref in rejected_files + if not skip_excluded: + if is_dir: + ref_parent_dir = ref + else: + ref_parent_dir = self._fs._parent(ref) + for rejected_dir in rejected_dirs: + if str(ref_parent_dir).startswith(str(rejected_dir)): + skip_excluded = True + LOG.warning(f"Skipping {ref} because it in parent dir " + f"{ref_parent_dir} which is in {rejected_dir}.") + break 
+ + if is_dir or skip_hidden or skip_bad_exts or skip_excluded: + continue + else: + all_files.add(ref) + + new_input_files = sorted(all_files) + + if len(new_input_files) == 0: + raise ValueError(f"No files found in {input_dir}.") + if self._num_files_limit is not None and self._num_files_limit > 0: + new_input_files = new_input_files[0: self._num_files_limit] + + LOG.debug(f"[SimpleDirectoryReader] Total files add: {len(new_input_files)}") + + LOG.info(f"input_files: {new_input_files}") + return new_input_files + + def _is_hidden(self, path: Path) -> bool: + return any(part.startswith(".") and part not in [".", ".."] for part in path.parts) + + def _exclude_metadata(self, documents: List[DocNode]) -> List[DocNode]: + for doc in documents: + doc.excluded_embed_metadata_keys.extend( + ["file_name", "file_type", "file_size", "creation_date", + "last_modified_date", "last_accessed_date"]) + doc.excluded_llm_metadata_keys.extend( + ["file_name", "file_type", "file_size", "creation_date", + "last_modified_date", "last_accessed_date"]) + return documents + + @staticmethod + def load_file(input_file: Path, file_metadata: Callable[[str], Dict], file_extractor: Dict[str, Callable], + filename_as_id: bool = False, encoding: str = "utf-8", pathm: PurePath = Path, + fs: Optional[AbstractFileSystem] = None) -> List[DocNode]: + metadata: Optional[dict] = None + documents: List[DocNode] = [] + + if file_metadata is not None: metadata = file_metadata(str(input_file)) + + file_reader_patterns = list(file_extractor.keys()) + + for pattern in file_reader_patterns: + pt = str(pathm(pattern)) + match_pattern = pt if pt.startswith("*") else os.path.join(str(pathm.cwd()), pt) + if fnmatch.fnmatch(input_file, match_pattern): + reader = file_extractor[pattern] + reader = reader() if isinstance(reader, type) else reader + kwargs = {"extra_info": metadata} + if fs and not is_default_fs(fs): kwargs['fs'] = fs + docs = reader(input_file, **kwargs) + + if filename_as_id: + for i, doc in enumerate(docs): + doc.uid = f"{input_file!s}_index_{i}" + documents.extend(docs) + break + else: + fs = fs or get_default_fs() + with fs.open(input_file, encoding=encoding) as f: + data = f.read().decode(encoding) + + doc = DocNode(text=data, metadata=metadata or {}) + if filename_as_id: doc.uid = str(input_file) + documents.append(doc) + + return documents + + def _load_data(self, show_progress: bool = False, num_workers: Optional[int] = None, + fs: Optional[AbstractFileSystem] = None) -> List[DocNode]: + documents = [] + + fs = fs or self._fs + process_file = self._input_files + file_readers = self._file_extractor.copy() + for key, func in self.default_file_readers.items(): + if key not in file_readers: file_readers[key] = func + + if num_workers and num_workers >= 1: + if num_workers > multiprocessing.cpu_count(): + LOG.warning("Specified num_workers exceed number of CPUs in the system. 
" + "Setting `num_workers` down to the maximum CPU count.") + with multiprocessing.get_context("spawn").Pool(num_workers) as p: + results = p.starmap(SimpleDirectoryReader.load_file, + zip(process_file, repeat(self._file_metadata), repeat(file_readers), + repeat(self._filename_as_id), repeat(self._encoding), repeat(self._Path), + repeat(self._fs))) + documents = reduce(lambda x, y: x + y, results) + else: + if show_progress: + process_file = tqdm(self._input_files, desc="Loading files", unit="file") + for input_file in process_file: + documents.extend( + SimpleDirectoryReader.load_file(input_file=input_file, file_metadata=self._file_metadata, + file_extractor=file_readers, filename_as_id=self._filename_as_id, + encoding=self._encoding, pathm=self._Path, fs=self._fs)) + + return self._exclude_metadata(documents) + + def forward(self, *args, **kwargs) -> List[DocNode]: + return self._load_data(*args, **kwargs) diff --git a/lazyllm/tools/rag/data_loaders.py b/lazyllm/tools/rag/data_loaders.py index e8135ae8..61d0392a 100644 --- a/lazyllm/tools/rag/data_loaders.py +++ b/lazyllm/tools/rag/data_loaders.py @@ -1,28 +1,26 @@ -from typing import List, Optional +from typing import List, Optional, Dict from .store import DocNode, LAZY_ROOT_NAME from lazyllm import LOG - +from .dataReader import SimpleDirectoryReader class DirectoryReader: - def __init__(self, input_files: List[str]) -> None: + def __init__(self, input_files: List[str], local_readers: Optional[Dict] = None, + global_readers: Optional[Dict] = None) -> None: self._input_files = input_files + self._local_readers = local_readers + self._global_readers = global_readers def load_data(self, input_files: Optional[List[str]] = None) -> List[DocNode]: input_files = input_files or self._input_files - from llama_index.core import SimpleDirectoryReader - + file_readers = self._local_readers.copy() + for key, func in self._global_readers.items(): + if key not in file_readers: file_readers[key] = func LOG.info(f"DirectoryReader loads data, input files: {input_files}") - reader = SimpleDirectoryReader(input_files=input_files) + reader = SimpleDirectoryReader(input_files=input_files, file_extractor=file_readers) nodes: List[DocNode] = [] - for doc in reader.load_data(): - node = DocNode( - text=doc.text, - group=LAZY_ROOT_NAME, - ) - node.metadata = doc.metadata - node.excluded_embed_metadata_keys = doc.excluded_embed_metadata_keys - node.excluded_llm_metadata_keys = doc.excluded_llm_metadata_keys - nodes.append(node) + for doc in reader(): + doc.group = LAZY_ROOT_NAME + nodes.append(doc) if not nodes: LOG.warning( f"No nodes load from path {self.input_files}, please check your data path." 
diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index dffd7e16..08025f6d 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -25,9 +25,10 @@ def wrapper(*args, **kwargs) -> List[float]: class DocImpl: - def __init__(self, embed, doc_files=Optional[List[str]], **kwargs): + def __init__(self, embed, doc_files=Optional[List[str]], local_readers: Optional[Dict] = None, + global_readers: Optional[Dict] = None, **kwargs): super().__init__() - self.directory_reader = DirectoryReader(doc_files) + self.directory_reader = DirectoryReader(doc_files, local_readers=local_readers, global_readers=global_readers) self.node_groups: Dict[str, Dict] = {LAZY_ROOT_NAME: {}} self._create_node_group_default() self.embed = embed_wrapper(embed) diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index 7e4a311a..4c27ea85 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -1,7 +1,7 @@ from functools import partial import os -from typing import Callable, Optional +from typing import Callable, Optional, Dict import lazyllm from lazyllm import ModuleBase, ServerModule, TrainableModule @@ -12,6 +12,8 @@ class Document(ModuleBase): + _registered_file_reader: Dict[str, Callable] = {} + def __init__(self, dataset_path: str, embed: Optional[TrainableModule] = None, create_ui: bool = True, launcher=None): super().__init__() @@ -21,15 +23,14 @@ def __init__(self, dataset_path: str, embed: Optional[TrainableModule] = None, dataset_path = defatult_path self._create_ui = create_ui launcher = launcher if launcher else lazyllm.launchers.remote(sync=False) + self._local_file_reader: Dict[str, Callable] = {} + self._impl = DocGroupImpl(dataset_path=dataset_path, embed=embed, local_readers=self._local_file_reader, + global_readers=self._registered_file_reader) if create_ui: - self._impl = DocGroupImpl(dataset_path=dataset_path, embed=embed) doc_manager = DocManager(self._impl) self.doc_server = ServerModule(doc_manager, launcher=launcher) - self.web = DocWebModule(doc_server=self.doc_server) - else: - self._impl = DocGroupImpl(dataset_path=dataset_path, embed=embed) def forward(self, func_name: str, *args, **kwargs): if self._create_ui: @@ -51,3 +52,17 @@ def create_node_group( self, name: str, transform: Callable, parent: str = LAZY_ROOT_NAME, **kwargs ) -> None: self._impl.create_node_group(name, transform, parent, **kwargs) + + def add_reader(self, pattern: str, func: Callable): + self._local_file_reader[pattern] = func + + @classmethod + def register_global_reader(cls, pattern: str, func: Optional[Callable] = None): + if func is not None: + cls._registered_file_reader[pattern] = func + + def decorator(klass): + if callable(klass): cls._registered_file_reader[pattern] = klass + else: raise TypeError(f"The registered object {klass} is not a callable object.") + return klass + return decorator diff --git a/lazyllm/tools/rag/group_doc.py b/lazyllm/tools/rag/group_doc.py index aac17afb..b19efbce 100644 --- a/lazyllm/tools/rag/group_doc.py +++ b/lazyllm/tools/rag/group_doc.py @@ -1,6 +1,6 @@ import os import shutil -from typing import Callable, List +from typing import Callable, List, Optional, Dict from .doc_impl import DocImpl import lazyllm from .store import LAZY_ROOT_NAME @@ -10,7 +10,8 @@ class DocGroupImpl(lazyllm.ModuleBase): - def __init__(self, dataset_path, embed) -> None: + def __init__(self, dataset_path, embed, local_readers: Optional[Dict] = None, + global_readers: Optional[Dict] = None) -> None: 
super().__init__() self._dataset_path = dataset_path self._embed = embed @@ -22,7 +23,8 @@ def __init__(self, dataset_path, embed) -> None: file_paths = self._list_all_files(self.dataset_path, lambda x: DATA_DIR in x) self._impl: DocImpl = DocImpl( - doc_files=file_paths, embed=self._embed, doc_name="lazyllm_doc" + doc_files=file_paths, embed=self._embed, local_readers=local_readers, global_readers=global_readers, + doc_name="lazyllm_doc" ) @property diff --git a/lazyllm/tools/rag/readers/__init__.py b/lazyllm/tools/rag/readers/__init__.py new file mode 100644 index 00000000..ff3182b6 --- /dev/null +++ b/lazyllm/tools/rag/readers/__init__.py @@ -0,0 +1,30 @@ +from .readerBase import LazyLLMReaderBase as ReaderBase, get_default_fs, is_default_fs +from .pdfReader import PDFReader +from .docxReader import DocxReader +from .hwpReader import HWPReader +from .pptxReader import PPTXReader +from .imageReader import ImageReader +from .ipynbReader import IPYNBReader +from .epubReader import EpubReader +from .markdownReader import MarkdownReader +from .mboxreader import MboxReader +from .pandasReader import PandasCSVReader, PandasExcelReader +from .videoAudioReader import VideoAudioReader + +__all__ = [ + "ReaderBase", + "get_default_fs", + "is_default_fs", + "PDFReader", + "DocxReader", + "HWPReader", + "PPTXReader", + "ImageReader", + "IPYNBReader", + "EpubReader", + "MarkdownReader", + "MboxReader", + "PandasCSVReader", + "PandasExcelReader", + "VideoAudioReader", +] diff --git a/lazyllm/tools/rag/readers/docxReader.py b/lazyllm/tools/rag/readers/docxReader.py new file mode 100644 index 00000000..ff472013 --- /dev/null +++ b/lazyllm/tools/rag/readers/docxReader.py @@ -0,0 +1,25 @@ +from pathlib import Path +from fsspec import AbstractFileSystem +from typing import Dict, Optional, List + +from .readerBase import LazyLLMReaderBase +from ..store import DocNode + +class DocxReader(LazyLLMReaderBase): + def _load_data(self, file: Path, extra_info: Optional[Dict] = None, + fs: Optional[AbstractFileSystem] = None) -> List[DocNode]: + if not isinstance(file, Path): file = Path(file) + try: + import docx2txt + except ImportError: + raise ImportError("docx2txt is required to read Microsoft Word files: `pip install docx2txt`") + + if fs: + with fs.open(file) as f: + text = docx2txt.process(f) + else: + text = docx2txt.process(file) + metadata = {"file_name": file.name} + if extra_info is not None: metadata.update(extra_info) + + return [DocNode(text=text, metadata=metadata)] diff --git a/lazyllm/tools/rag/readers/epubReader.py b/lazyllm/tools/rag/readers/epubReader.py new file mode 100644 index 00000000..0e208dbf --- /dev/null +++ b/lazyllm/tools/rag/readers/epubReader.py @@ -0,0 +1,33 @@ +from pathlib import Path +from typing import Dict, List, Optional +from fsspec import AbstractFileSystem + +from .readerBase import LazyLLMReaderBase +from ..store import DocNode +from lazyllm import LOG + +class EpubReader(LazyLLMReaderBase): + def _load_data(self, file: Path, extra_info: Optional[Dict] = None, + fs: Optional[AbstractFileSystem] = None) -> List[DocNode]: + try: + import ebooklib + import html2text + from ebooklib import epub + except ImportError: + raise ImportError("Please install extra dependencies that are required " + "for the EpubReader: `pip install EbookLib html2text`") + + if not isinstance(file, Path): file = Path(file) + + if fs: + LOG.warning("fs was specified but EpubReader doesn't support loading from " + "fsspec filesystems. 
Will load from local filesystem instead.") + + text_list = [] + book = epub.read_epub(file, options={"ignore_ncs": True}) + + for item in book.get_items(): + if item.get_type() == ebooklib.ITEM_DOCUMENT: + text_list.append(html2text.html2text(item.get_content().decode("utf-8"))) + text = "\n".join(text_list) + return [DocNode(text=text, metadata=extra_info or {})] diff --git a/lazyllm/tools/rag/readers/hwpReader.py b/lazyllm/tools/rag/readers/hwpReader.py new file mode 100644 index 00000000..9678336e --- /dev/null +++ b/lazyllm/tools/rag/readers/hwpReader.py @@ -0,0 +1,92 @@ +from fsspec import AbstractFileSystem +from pathlib import Path +import struct +from typing import Optional, Dict, List, Any +import zlib + +from .readerBase import LazyLLMReaderBase +from ..store import DocNode +from lazyllm import LOG + +class HWPReader(LazyLLMReaderBase): + def __init__(self, return_trace: bool = True) -> None: + super().__init__(return_trace=return_trace) + self._FILE_HEADER_SECTION = "FileHeader" + self._HWP_SUMMARY_SECTION = "\x05HwpSummaryInformation" + self._SECTION_NAME_LENGTH = len("Section") + self._BODYTEXT_SECTION = "BodyText" + self._HWP_TEXT_TAGS = [67] + self._text = "" + + def _load_data(self, file: Path, extra_info: Optional[Dict] = None, + fs: Optional[AbstractFileSystem] = None) -> List[DocNode]: + try: + import olefile + except ImportError: + raise ImportError("olefile is required to read hwp files: `pip install olefile`") + + if fs: + LOG.warning("fs was specified but HWPReader doesn't support loading from " + "fsspec filesystems. Will load from local filesystem instead.") + + if not isinstance(file, Path): file = Path(file) + + load_file = olefile.OleFileIO(file) + file_dir = load_file.listdir() + if self._is_valid(file_dir) is False: raise Exception("Not Valid HwpFile") + + result_text = self._get_text(load_file, file_dir) + return [DocNode(text=result_text, metadata=extra_info or {})] + + def _is_valid(self, dirs: List[str]) -> bool: + if [self._FILE_HEADER_SECTION] not in dirs: return False + return [self._HWP_SUMMARY_SECTION] in dirs + + def _text_to_docnode(self, text: str, extra_info: Optional[Dict] = None) -> DocNode: + return DocNode(text=text, metadata=extra_info or {}) + + def _get_text(self, load_file: Any, file_dirs: List[str]) -> str: + sections = self._get_body_sections(file_dirs) + text = "" + for section in sections: + text += self._get_text_from_section(load_file, section) + text += "\n" + + self._text = text + return self._text + + def _get_body_sections(self, dirs: List[str]) -> List[str]: + m = [] + for d in dirs: + if d[0] == self._BODYTEXT_SECTION: + m.append(int(d[1][self._SECTION_NAME_LENGTH:])) + + return ["BodyText/Section" + str(x) for x in sorted(m)] + + def _is_compressed(self, load_file: Any) -> bool: + header = load_file.openstream("FileHeader") + header_data = header.read() + return (header_data[36] & 1) == 1 + + def _get_text_from_section(self, load_file: Any, section: str) -> str: + bodytext = load_file.openstream(section) + data = bodytext.read() + + unpacked_data = (zlib.decompress(data, -15) if self._is_compressed(load_file) else data) + size = len(unpacked_data) + + i = 0 + text = "" + while i < size: + header = struct.unpack_from("<I", unpacked_data, i)[0] + rec_type = header & 0x3FF + rec_len = (header >> 20) & 0xFFF + + if rec_type in self._HWP_TEXT_TAGS: + rec_data = unpacked_data[i + 4: i + 4 + rec_len] + text += rec_data.decode("utf-16") + text += "\n" + + i += 4 + rec_len + return text diff --git a/lazyllm/tools/rag/readers/imageReader.py 
b/lazyllm/tools/rag/readers/imageReader.py new file mode 100644 index 00000000..ee610bbc --- /dev/null +++ b/lazyllm/tools/rag/readers/imageReader.py @@ -0,0 +1,102 @@ +import base64 +import re +from io import BytesIO +from pathlib import Path +from typing import Dict, List, Optional, cast +from fsspec import AbstractFileSystem +from PIL import Image + +from .readerBase import LazyLLMReaderBase, infer_torch_device +from ..store import DocNode + +def img_2_b64(image: Image, format: str = "JPEG") -> str: + buff = BytesIO() + image.save(buff, format=format) + return cast(str, base64.b64encode(buff.getvalue())) + +def b64_2_img(data: str) -> Image: + buff = BytesIO(base64.b64decode(data)) + return Image.open(buff) + +class ImageReader(LazyLLMReaderBase): + def __init__(self, parser_config: Optional[Dict] = None, keep_image: bool = False, parse_text: bool = False, + text_type: str = "text", pytesseract_model_kwargs: Optional[Dict] = None, + return_trace: bool = True) -> None: + super().__init__(return_trace=return_trace) + self._text_type = text_type + if parser_config is None and parse_text: + if text_type == "plain_text": + try: + import pytesseract + except ImportError: + raise ImportError("Please install extra dependencies that are required for the ImageReader " + "when text_type is 'plain_text': `pip install pytesseract`") + + processor = None + model = pytesseract + else: + try: + import sentencepiece # noqa + import torch # noqa + from PIL import Image # noqa + from transformers import DonutProcessor, VisionEncoderDecoderModel + except ImportError: + raise ImportError("Please install extra dependencies that are required for the " + "ImageCaptionReader: `pip install torch transformers sentencepiece Pillow`") + + processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2") + model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2") + parser_config = {'processor': processor, 'model': model} + + self._parser_config = parser_config + self._keep_image = keep_image + self._parse_text = parse_text + self._pytesseract_model_kwargs = pytesseract_model_kwargs or {} + + def _load_data(self, file: Path, extra_info: Optional[Dict] = None, + fs: Optional[AbstractFileSystem] = None) -> List[DocNode]: + if not isinstance(file, Path): file = Path(file) + + if fs: + with fs.open(path=file) as f: + image = Image.open(BytesIO(f.read())) + else: + image = Image.open(file) + + if image.mode != "RGB": image = image.convert("RGB") + + image_str: Optional[str] = None # noqa + if self._keep_image: image_str = img_2_b64(image) # noqa + + text_str: str = "" + if self._parse_text: + assert self._parser_config is not None + model = self._parser_config["model"] + processor = self._parser_config["processor"] + + if processor: + device = infer_torch_device() + model.to(device) + + task_prompt = "<s_cord-v2>" + decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, + return_tensors='pt').input_ids + pixel_values = processor(image, return_tensors='pt').pixel_values + + output = model.generate(pixel_values.to(device), decoder_input_ids=decoder_input_ids.to(device), + max_length=model.decoder.config.max_position_embeddings, early_stopping=True, + pad_token_id=processor.tokenizer.pad_token_id, + eos_token_id=processor.tokenizer.eos_token_id, use_cache=True, num_beams=3, + bad_words_ids=[[processor.tokenizer.unk_token_id]], + return_dict_in_generate=True) + + sequence = processor.batch_decode(output.sequences)[0] + sequence = 
sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "") + text_str = re.sub(r"<.*?>", "", sequence, count=1).strip() + else: + import pytesseract + + model = cast(pytesseract, self._parser_config['model']) + text_str = model.image_to_string(image, **self._pytesseract_model_kwargs) + + return [DocNode(text=text_str, metadata=extra_info or {})] diff --git a/lazyllm/tools/rag/readers/ipynbReader.py b/lazyllm/tools/rag/readers/ipynbReader.py new file mode 100644 index 00000000..66c0e192 --- /dev/null +++ b/lazyllm/tools/rag/readers/ipynbReader.py @@ -0,0 +1,37 @@ +import re +from pathlib import Path +from typing import Dict, List, Optional +from fsspec import AbstractFileSystem + +from .readerBase import LazyLLMReaderBase +from ..store import DocNode + +class IPYNBReader(LazyLLMReaderBase): + def __init__(self, parser_config: Optional[Dict] = None, concatenate: bool = False, return_trace: bool = True): + super().__init__(return_trace=return_trace) + self._parser_config = parser_config + self._concatenate = concatenate + + def _load_data(self, file: Path, extra_info: Optional[Dict] = None, + fs: Optional[AbstractFileSystem] = None) -> List[DocNode]: + if not isinstance(file, Path): file = Path(file) + + if file.name.endswith(".ipynb"): + try: + import nbconvert + except ImportError: + raise ImportError("Please install nbconvert `pip install nbconvert`") + + if fs: + with fs.open(file, encoding='utf-8') as f: + doc_str = nbconvert.exporters.ScriptExporter().from_file(f)[0] + else: + doc_str = nbconvert.exporters.ScriptExporter().from_file(file)[0] + + splits = re.split(r"In\[\d+\]:", doc_str) + splits.pop(0) + + if self._concatenate: docs = [DocNode(text="\n\n".join(splits), metadata=extra_info or {})] + else: docs = [DocNode(text=s, metadata=extra_info or {}) for s in splits] + + return docs diff --git a/lazyllm/tools/rag/readers/markdownReader.py b/lazyllm/tools/rag/readers/markdownReader.py new file mode 100644 index 00000000..c1748f55 --- /dev/null +++ b/lazyllm/tools/rag/readers/markdownReader.py @@ -0,0 +1,67 @@ +import re +from pathlib import Path +from fsspec import AbstractFileSystem +from fsspec.implementations.local import LocalFileSystem +from typing import Dict, List, Optional, Tuple + +from .readerBase import LazyLLMReaderBase +from ..store import DocNode + +class MarkdownReader(LazyLLMReaderBase): + def __init__(self, remove_hyperlinks: bool = True, remove_images: bool = True, return_trace: bool = True) -> None: + super().__init__(return_trace=return_trace) + self._remove_hyperlinks = remove_hyperlinks + self._remove_images = remove_images + + def _markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]: + markdown_tups: List[Tuple[Optional[str], str]] = [] + lines = markdown_text.split("\n") + + current_header = None + current_lines = [] + in_code_block = False + + for line in lines: + if line.startswith("```"): in_code_block = not in_code_block + + header_match = re.match(r"^#+\s", line) + if not in_code_block and header_match: + if current_header is not None or len(current_lines) > 0: + markdown_tups.append((current_header, "\n".join(current_lines))) + current_header = line + current_lines.clear() + else: + current_lines.append(line) + + markdown_tups.append((current_header, "\n".join(current_lines))) + return [(key if key is None else re.sub(r"#", "", key).strip(), re.sub(r"<.*?>", "", value),) + for key, value in markdown_tups] + + def remove_images(self, content: str) -> str: + pattern = r"!{1}\[\[(.*)\]\]" + 
return re.sub(pattern, "", content) + + def remove_hyperlinks(self, content: str) -> str: + pattern = r"\[(.*)\]\((.*)\)" + return re.sub(pattern, r"\1", content) + + def _parse_tups(self, filepath: Path, errors: str = "ignore", + fs: Optional[AbstractFileSystem] = None) -> List[Tuple[Optional[str], str]]: + fs = fs or LocalFileSystem() + + with fs.open(filepath, encoding="utf-8") as f: + content = f.read().decode(encoding="utf-8") + + if self._remove_hyperlinks: content = self.remove_hyperlinks(content) + if self._remove_images: content = self.remove_images(content) + return self._markdown_to_tups(content) + + def _load_data(self, file: Path, extra_info: Optional[Dict] = None, + fs: Optional[AbstractFileSystem] = None) -> List[DocNode]: + if not isinstance(file, Path): file = Path(file) + + tups = self._parse_tups(file, fs=fs) + results = [DocNode(text=value if header is None else f"\n\n{header}\n{value}", metadata=extra_info or {}) + for header, value in tups] + + return results diff --git a/lazyllm/tools/rag/readers/mboxreader.py b/lazyllm/tools/rag/readers/mboxreader.py new file mode 100644 index 00000000..567854c8 --- /dev/null +++ b/lazyllm/tools/rag/readers/mboxreader.py @@ -0,0 +1,68 @@ +from pathlib import Path +from typing import Dict, List, Optional +from fsspec import AbstractFileSystem + +from .readerBase import LazyLLMReaderBase +from ..store import DocNode +from lazyllm import LOG + +class MboxReader(LazyLLMReaderBase): + DEFAULT_MESSAGE_FORMAT: str = ( + "Date: {_date}\n" + "From: {_from}\n" + "To: {_to}\n" + "Subject: {_subject}\n" + "Content: {_content}" + ) + + def __init__(self, max_count: int = 0, message_format: str = DEFAULT_MESSAGE_FORMAT, + return_trace: bool = True) -> None: + try: + from bs4 import BeautifulSoup # noqa + except ImportError: + raise ImportError("`BeautifulSoup` package not found: `pip install beautifulsoup4`") + + super().__init__(return_trace=return_trace) + self._max_count = max_count + self._message_format = message_format + + def _load_data(self, file: Path, extra_info: Optional[Dict] = None, + fs: Optional[AbstractFileSystem] = None) -> List[DocNode]: + import mailbox + from email.parser import BytesParser + from email.policy import default + from bs4 import BeautifulSoup + + if fs: + LOG.warning("fs was specified but MboxReader doesn't support loading from " + "fsspec filesystems. 
Will load from local filesystem instead.") + + i = 0 + results: List[str] = [] + bytes_parser = BytesParser(policy=default).parse + mbox = mailbox.mbox(file, factory=bytes_parser) + + for _, _msg in enumerate(mbox): + try: + msg: mailbox.mboxMessage = _msg + if msg.is_multipart(): + for part in msg.walk(): + ctype = part.get_content_type() + cdispo = str(part.get("Content-Disposition")) + if ctype == "text/plain" and "attachment" not in cdispo: + content = part.get_payload(decode=True) + break + else: + content = msg.get_payload(decode=True) + + soup = BeautifulSoup(content) + stripped_content = " ".join(soup.get_text().split()) + msg_string = self._message_format.format(_date=msg["date"], _from=msg["from"], _to=msg["to"], + _subject=msg["subject"], _content=stripped_content) + results.append(msg_string) + except Exception as e: + LOG.warning(f"Failed to parse message:\n{_msg}\n with exception {e}") + + i += 1 + if self._max_count > 0 and i >= self._max_count: break + return [DocNode(text=result, metadata=extra_info or {}) for result in results] diff --git a/lazyllm/tools/rag/readers/pandasReader.py b/lazyllm/tools/rag/readers/pandasReader.py new file mode 100644 index 00000000..bbe2cb60 --- /dev/null +++ b/lazyllm/tools/rag/readers/pandasReader.py @@ -0,0 +1,71 @@ +from pathlib import Path +from typing import Dict, List, Optional +from fsspec import AbstractFileSystem +import importlib +import pandas as pd + +from .readerBase import LazyLLMReaderBase +from ..store import DocNode + +class PandasCSVReader(LazyLLMReaderBase): + def __init__(self, concat_rows: bool = True, col_joiner: str = ", ", row_joiner: str = "\n", + pandas_config: Optional[Dict] = None, return_trace: bool = True) -> None: + super().__init__(return_trace=return_trace) + self._concat_rows = concat_rows + self._col_joiner = col_joiner + self._row_joiner = row_joiner + self._pandas_config = pandas_config or {} + + def _load_data(self, file: Path, extra_info: Optional[Dict] = None, + fs: Optional[AbstractFileSystem] = None) -> List[DocNode]: + if not isinstance(file, Path): file = Path(file) + + if fs: + with fs.open(file) as f: + df = pd.read_csv(f, **self._pandas_config) + else: + df = pd.read_csv(file, **self._pandas_config) + + text_list = df.apply(lambda row: (self._col_joiner).join(row.astype(str).tolist()), axis=1).tolist() + + if self._concat_rows: return [DocNode(text=(self._row_joiner).join(text_list), metadata=extra_info or {})] + else: return [DocNode(text=text, metadata=extra_info or {}) for text in text_list] + +class PandasExcelReader(LazyLLMReaderBase): + def __init__(self, concat_rows: bool = True, sheet_name: Optional[str] = None, + pandas_config: Optional[Dict] = None, return_trace: bool = True) -> None: + super().__init__(return_trace=return_trace) + self._concat_rows = concat_rows + self._sheet_name = sheet_name + self._pandas_config = pandas_config or {} + + def _load_data(self, file: Path, extra_info: Optional[Dict] = None, + fs: Optional[AbstractFileSystem] = None) -> List[DocNode]: + openpyxl_spec = importlib.util.find_spec("openpyxl") + if openpyxl_spec is not None: pass + else: raise ImportError("Please install openpyxl to read Excel files. 
" + "You can install it with `pip install openpyxl`") + + if not isinstance(file, Path): file = Path(file) + if fs: + with fs.open(file) as f: + dfs = pd.read_excel(f, self._sheet_name, **self._pandas_config) + else: + dfs = pd.read_excel(file, self._sheet_name, **self._pandas_config) + + documents = [] + if isinstance(dfs, pd.DataFrame): + df = dfs.fillna("") + text_list = (df.astype(str).apply(lambda row: " ".join(row.values), axis=1).tolist()) + + if self._concat_rows: documents.append(DocNode(text="\n".join(text_list), metadata=extra_info or {})) + else: documents.extend([DocNode(text=text, metadata=extra_info or {}) for text in text_list]) + else: + for df in dfs.values(): + df = df.fillna("") + text_list = (df.astype(str).apply(lambda row: " ".join(row), axis=1).tolist()) + + if self._concat_rows: documents.append(DocNode(text="\n".join(text_list), metadata=extra_info or {})) + else: documents.extend([DocNode(text=text, metadata=extra_info or {}) for text in text_list]) + + return documents diff --git a/lazyllm/tools/rag/readers/pdfReader.py b/lazyllm/tools/rag/readers/pdfReader.py new file mode 100644 index 00000000..8982c424 --- /dev/null +++ b/lazyllm/tools/rag/readers/pdfReader.py @@ -0,0 +1,45 @@ +import io +from tenacity import retry, stop_after_attempt +from pathlib import Path +from typing import Dict, List, Optional +from fsspec import AbstractFileSystem + +from .readerBase import LazyLLMReaderBase, get_default_fs, is_default_fs +from ..store import DocNode + +RETRY_TIMES = 3 + +class PDFReader(LazyLLMReaderBase): + def __init__(self, return_full_document: bool = False, return_trace: bool = True) -> None: + super().__init__(return_trace=return_trace) + self._return_full_document = return_full_document + + @retry(stop=stop_after_attempt(RETRY_TIMES)) + def _load_data(self, file: Path, extra_info: Optional[Dict] = None, + fs: Optional[AbstractFileSystem] = None) -> List[DocNode]: + if not isinstance(file, Path): file = Path(file) + + try: + import pypdf + except ImportError: + raise ImportError("pypdf is required to read PDF files: `pip install pypdf`") + + fs = fs or get_default_fs() + with fs.open(file, 'rb') as fp: + stream = fp if is_default_fs(fs) else io.BytesIO(fp.read()) + pdf = pypdf.PdfReader(stream) + num_pages = len(pdf.pages) + docs = [] + if self._return_full_document: + metadata = {"file_name": file.name} + if extra_info is not None: metadata.update(extra_info) + text = "\n".join(pdf.pages[page].extract_text() for page in range(num_pages)) + docs.append(DocNode(text=text, metadata=metadata)) + else: + for page in range(num_pages): + page_text = pdf.pages[page].extract_text() + page_label = pdf.page_labels[page] + metadata = {"page_label": page_label, "file_name": file.name} + if extra_info is not None: metadata.update(extra_info) + docs.append(DocNode(text=page_text, metadata=metadata)) + return docs diff --git a/lazyllm/tools/rag/readers/pptxReader.py b/lazyllm/tools/rag/readers/pptxReader.py new file mode 100644 index 00000000..8085844d --- /dev/null +++ b/lazyllm/tools/rag/readers/pptxReader.py @@ -0,0 +1,81 @@ +import os +import tempfile +from fsspec import AbstractFileSystem +from pathlib import Path +from typing import Optional, Dict, List + +from .readerBase import LazyLLMReaderBase, infer_torch_device +from ..store import DocNode + +class PPTXReader(LazyLLMReaderBase): + def __init__(self, return_trace: bool = True) -> None: + try: + import torch # noqa + from PIL import Image # noqa + from pptx import Presentation # noqa + from transformers import 
(AutoTokenizer, VisionEncoderDecoderModel, ViTFeatureExtractor,) + except ImportError: + raise ImportError("Please install extra dependencies that are required for the " + "PPTXReader: `pip install torch transformers python-pptx Pillow`") + + super().__init__(return_trace=return_trace) + model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning") + feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning") + tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning") + + self._parser_config = {"feature_extractor": feature_extractor, "model": model, "tokenizer": tokenizer} + + def _caption_image(self, tmp_image_file: str) -> str: + from PIL import Image + + model = self._parser_config['model'] + feature_extractor = self._parser_config['feature_extractor'] + tokenizer = self._parser_config['tokenizer'] + + device = infer_torch_device() + model.to(device) + + max_length = 16 + num_beams = 4 + gen_kwargs = {"max_length": max_length, "num_beams": num_beams} + + i_image = Image.open(tmp_image_file) + if i_image.mode != "RGB": i_image = i_image.convert(mode="RGB") + + pixel_values = feature_extractor(images=[i_image], return_tensors="pt").pixel_values + pixel_values = pixel_values.to(device) + + output_ids = model.generate(pixel_values, **gen_kwargs) + + preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + return preds[0].strip() + + def _load_data(self, file: Path, extra_info: Optional[Dict] = None, + fs: Optional[AbstractFileSystem] = None) -> List[DocNode]: + from pptx import Presentation + + if not isinstance(file, Path): file = Path(file) + + if fs: + with fs.open(file) as f: + presentation = Presentation(f) + else: + presentation = Presentation(file) + + result = "" + for i, slide in enumerate(presentation.slides): + result += f"\n\nSlide #{i}: \n" + for shape in slide.shapes: + if hasattr(shape, "image"): + image = shape.image + image_bytes = image.blob + f = tempfile.NamedTemporaryFile("wb", delete=False) + try: + f.write(image_bytes) + f.close() + result += f"\n Image: {self._caption_image(f.name)}\n\n" + finally: + os.unlink(f.name) + + if hasattr(shape, "text"): result += f"{shape.text}\n" + return [DocNode(text=result, metadata=extra_info or {})] diff --git a/lazyllm/tools/rag/readers/readerBase.py b/lazyllm/tools/rag/readers/readerBase.py new file mode 100644 index 00000000..70515e52 --- /dev/null +++ b/lazyllm/tools/rag/readers/readerBase.py @@ -0,0 +1,38 @@ +import fsspec +from fsspec.implementations.local import LocalFileSystem +from typing import Iterable, List + +from ....common import LazyLLMRegisterMetaClass +from ..store import DocNode +from lazyllm.module import ModuleBase + +class LazyLLMReaderBase(ModuleBase, metaclass=LazyLLMRegisterMetaClass): + def __init__(self, *args, return_trace: bool = True, **kwargs): + super().__init__(return_trace=return_trace) + + def _lazy_load_data(self, *args, **load_kwargs) -> Iterable[DocNode]: + raise NotImplementedError(f"{self.__class__.__name__} does not implement lazy_load_data method.") + + def _load_data(self, *args, **load_kwargs) -> List[DocNode]: + return list(self._lazy_load_data(*args, **load_kwargs)) + + def forward(self, *args, **kwargs) -> List[DocNode]: + return self._load_data(*args, **kwargs) + + +def get_default_fs(): + return LocalFileSystem() + +def is_default_fs(fs: fsspec.AbstractFileSystem) -> bool: + return isinstance(fs, LocalFileSystem) or not fs.auto_mkdir + +def infer_torch_device() -> str: + try: + 
has_cuda = torch.cuda.is_available() + except NameError: + import torch + has_cuda = torch.cuda.is_available() + + if has_cuda: return "cuda" + if torch.backends.mps.is_available(): return "mps" + return "cpu" diff --git a/lazyllm/tools/rag/readers/readme.md b/lazyllm/tools/rag/readers/readme.md new file mode 100644 index 00000000..c3e8e627 --- /dev/null +++ b/lazyllm/tools/rag/readers/readme.md @@ -0,0 +1 @@ +Each reader module is borrowed from LLAMA_INDEX, but we have added customized parts: every reader inherits from the ModuleBase base class, which makes all reader modules callable. diff --git a/lazyllm/tools/rag/readers/videoAudioReader.py b/lazyllm/tools/rag/readers/videoAudioReader.py new file mode 100644 index 00000000..02236e75 --- /dev/null +++ b/lazyllm/tools/rag/readers/videoAudioReader.py @@ -0,0 +1,48 @@ +from pathlib import Path +from typing import Dict, List, Optional, cast +from fsspec import AbstractFileSystem + +from .readerBase import LazyLLMReaderBase +from ..store import DocNode + +class VideoAudioReader(LazyLLMReaderBase): + def __init__(self, model_version: str = "base", return_trace: bool = True) -> None: + super().__init__(return_trace=return_trace) + self._model_version = model_version + + try: + import whisper + except ImportError: + raise ImportError("Please install OpenAI whisper model " + "`pip install git+https://github.com/openai/whisper.git` to use the model") + + model = whisper.load_model(self._model_version) + self._parser_config = {"model": model} + + def _load_data(self, file: Path, extra_info: Optional[Dict] = None, + fs: Optional[AbstractFileSystem] = None) -> List[DocNode]: + import whisper + + if not isinstance(file, Path): file = Path(file) + + if file.name.endswith("mp4"): + try: + from pydub import AudioSegment + except ImportError: + raise ImportError("Please install pydub `pip install pydub`") + + if fs: + with fs.open(file, 'rb') as f: + video = AudioSegment.from_file(f, format="mp4") + else: + video = AudioSegment.from_file(file, format="mp4") + + audio = video.split_to_mono()[0] + file_str = str(file)[:-4] + ".mp3" + audio.export(file_str, format="mp3") + + model = cast(whisper.Whisper, self._parser_config["model"]) + result = model.transcribe(str(file)) + + transcript = result['text'] + return [DocNode(text=transcript, metadata=extra_info or {})] diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index 91669445..e304ab5c 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -28,12 +28,13 @@ def __init__( group: Optional[str] = None, embedding: Optional[List[float]] = None, parent: Optional["DocNode"] = None, + metadata: Optional[Dict[str, Any]] = None, ) -> None: self.uid: str = uid if uid else str(uuid.uuid4()) self.text: Optional[str] = text self.group: Optional[str] = group self.embedding: Optional[List[float]] = embedding or None - self._metadata: Dict[str, Any] = {} + self._metadata: Dict[str, Any] = metadata or {} # Metadata keys that are excluded from text for the embed model. self._excluded_embed_metadata_keys: List[str] = [] # Metadata keys that are excluded from text for the LLM. 
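With the new constructor parameter above, metadata can be attached when a DocNode is created instead of being assigned afterwards, which is what the readers in this patch rely on. A small sketch with illustrative field values:

from lazyllm.tools.rag import DocNode

node = DocNode(text="hello world",
               metadata={"file_name": "hello.txt", "file_type": "text/plain"})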
diff --git a/tests/basic_tests/test_rag_reader.py b/tests/basic_tests/test_rag_reader.py new file mode 100644 index 00000000..0d479fcc --- /dev/null +++ b/tests/basic_tests/test_rag_reader.py @@ -0,0 +1,64 @@ +import os +import lazyllm +from lazyllm import Document +from lazyllm.tools.rag.readers import ReaderBase +from lazyllm.tools.rag import SimpleDirectoryReader, DocNode + +class YmlReader(ReaderBase): + def _load_data(self, file, extra_info=None, fs=None): + with open(file, 'r') as f: + data = f.read() + node = DocNode(text=data, metadata=extra_info or {}) + node.text = "Call the class YmlReader." + return [node] + +def processYml(file, extra_info=None): + with open(file, 'r') as f: + data = f.read() + return [DocNode(text=data, metadata=extra_info or {})] + +class TestRagReader(object): + @classmethod + def setup_class(cls): + cls.doc = Document(dataset_path="ci_data/rag_reader", create_ui=False) + cls.datasets = os.path.join(lazyllm.config['data_path'], "ci_data/rag_reader/default/__data/sources") + + def test_reader_file(self): + files = [os.path.join(self.datasets, "联网搜索.pdf"), os.path.join(self.datasets, "说明文档测试.docx")] + reader = SimpleDirectoryReader(input_files=files) + docs = [] + for doc in reader(): + print(doc) + docs.append(doc) + print(len(docs)) + assert len(docs) == 3 + + def test_reader_dir(self): + input_dir = self.datasets + reader = SimpleDirectoryReader(input_dir=input_dir, + exclude=["*.jpg", "*.mp3", "*.yml", "*.pdf", ".docx", "*.pptx"]) + docs = [] + for doc in reader(): + print(doc) + docs.append(doc) + print(len(docs)) + assert len(docs) == 13 + + def test_register_local_reader(self): + self.doc.add_reader("**/*.yml", processYml) + files = [os.path.join(self.datasets, "reader_test.yml")] + docs = self.doc._impl._impl.directory_reader.load_data(input_files=files) + assert len(docs) == 1 + + def test_register_global_reader(self): + Document.register_global_reader("**/*.yml", processYml) + files = [os.path.join(self.datasets, "reader_test.yml")] + docs = self.doc._impl._impl.directory_reader.load_data(input_files=files) + assert len(docs) == 1 + + def test_register_local_and_global_reader(self): + Document.register_global_reader("**/*.yml", processYml) + self.doc.add_reader("**/*.yml", YmlReader) + files = [os.path.join(self.datasets, "reader_test.yml")] + docs = self.doc._impl._impl.directory_reader.load_data(input_files=files) + assert docs[0].text == "Call the class YmlReader." diff --git a/tests/requirements.txt b/tests/requirements.txt index a9bc6c2e..953766fd 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1 +1,3 @@ -wikipedia \ No newline at end of file +wikipedia +docx2txt +olefile
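The tests above exercise register_global_reader only as a plain call; for completeness, a sketch of the decorator form added in document.py, together with an instance-level add_reader registration that takes precedence for the same pattern. The dataset path mirrors the tests and the *.yml pattern is illustrative:

from lazyllm import Document
from lazyllm.tools.rag import DocNode
from lazyllm.tools.rag.readers import ReaderBase

@Document.register_global_reader("**/*.yml")  # decorator form: applies to every Document instance
class GlobalYmlReader(ReaderBase):
    def _load_data(self, file, extra_info=None, fs=None):
        with open(file, 'r') as f:
            return [DocNode(text=f.read(), metadata=extra_info or {})]

def local_yml(file, extra_info=None):
    with open(file, 'r') as f:
        return [DocNode(text=f.read(), metadata=extra_info or {})]

doc = Document(dataset_path="ci_data/rag_reader", create_ui=False)
doc.add_reader("**/*.yml", local_yml)  # instance-level registration wins over the global reader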