Harrison/text splitter (#5417)
Adds support for keeping separators around when using the recursive text splitter.
hwchase17 authored May 29, 2023
1 parent cf5803e commit 72f99ff
Showing 3 changed files with 99 additions and 59 deletions.
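This commit lets a splitter keep the separator that produced each piece, re-attaching it to the text that follows instead of discarding it. A minimal usage sketch of the new option, assuming only the keep_separator keyword and defaults introduced in the diff below (it defaults to True for RecursiveCharacterTextSplitter and False for the base TextSplitter):

from langchain.text_splitter import RecursiveCharacterTextSplitter

# keep_separator=True (the new default for the recursive splitter) retains the
# "\n\n" / "\n" / " " boundaries inside the returned chunks.
splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n", " ", ""],  # tried coarsest-first, matching the defaults below
    chunk_size=40,
    chunk_overlap=0,
    keep_separator=True,
)
chunks = splitter.split_text("class Foo:\n\n    def bar():\n        pass")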
42 changes: 10 additions & 32 deletions docs/modules/indexes/text_splitters/examples/python.ipynb
@@ -42,17 +42,17 @@
" \n",
"def foo():\n",
"\n",
"def testing_func():\n",
"def testing_func_with_long_name():\n",
"\n",
"def bar():\n",
"\"\"\"\n",
"python_splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0)"
"python_splitter = PythonCodeTextSplitter(chunk_size=40, chunk_overlap=0)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6cdc55f3",
"id": "8cc33770",
"metadata": {},
"outputs": [],
"source": [
@@ -62,15 +62,16 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "8cc33770",
"id": "f5f70775",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='Foo:\\n\\n def bar():', lookup_str='', metadata={}, lookup_index=0),\n",
" Document(page_content='foo():\\n\\ndef testing_func():', lookup_str='', metadata={}, lookup_index=0),\n",
" Document(page_content='bar():', lookup_str='', metadata={}, lookup_index=0)]"
"[Document(page_content='class Foo:\\n\\n def bar():', metadata={}),\n",
" Document(page_content='def foo():', metadata={}),\n",
" Document(page_content='def testing_func_with_long_name():', metadata={}),\n",
" Document(page_content='def bar():', metadata={})]"
]
},
"execution_count": 4,
@@ -82,33 +83,10 @@
"docs"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "de625e08-c440-489d-beed-020b6c53bf69",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"['Foo:\\n\\n def bar():', 'foo():\\n\\ndef testing_func():', 'bar():']"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"python_splitter.split_text(python_text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "55aadd84-75ca-48ae-9b84-b39c368488ed",
"id": "6e096d42",
"metadata": {},
"outputs": [],
"source": []
@@ -130,7 +108,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.9.1"
},
"vscode": {
"interpreter": {
76 changes: 56 additions & 20 deletions langchain/text_splitter.py
@@ -3,6 +3,7 @@

import copy
import logging
import re
from abc import ABC, abstractmethod
from typing import (
AbstractSet,
@@ -27,6 +28,23 @@
TS = TypeVar("TS", bound="TextSplitter")


def _split_text(text: str, separator: str, keep_separator: bool) -> List[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
if len(_splits) % 2 == 0:
splits += _splits[-1:]
splits = [_splits[0]] + splits
else:
splits = text.split(separator)
else:
splits = list(text)
return [s for s in splits if s != ""]
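# Illustration only, not part of the commit: with the capturing group, re.split keeps each
# matched separator, and the comprehension re-attaches it to the segment that follows, e.g.
#   _split_text("a\n\nb\n\nc", "\n\n", keep_separator=True)   # -> ["a", "\n\nb", "\n\nc"]
#   _split_text("a\n\nb\n\nc", "\n\n", keep_separator=False)  # -> ["a", "b", "c"]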


class TextSplitter(BaseDocumentTransformer, ABC):
"""Interface for splitting text into chunks."""

@@ -35,8 +53,16 @@ def __init__(
chunk_size: int = 4000,
chunk_overlap: int = 200,
length_function: Callable[[str], int] = len,
keep_separator: bool = False,
):
"""Create a new TextSplitter."""
"""Create a new TextSplitter.
Args:
chunk_size: Maximum size of chunks to return
chunk_overlap: Overlap in characters between chunks
length_function: Function that measures the length of given chunks
keep_separator: Whether or not to keep the separator in the chunks
"""
if chunk_overlap > chunk_size:
raise ValueError(
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
@@ -45,6 +71,7 @@ def __init__(
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._length_function = length_function
self._keep_separator = keep_separator

@abstractmethod
def split_text(self, text: str) -> List[str]:
@@ -211,11 +238,9 @@ def __init__(self, separator: str = "\n\n", **kwargs: Any):
def split_text(self, text: str) -> List[str]:
"""Split incoming text and return chunks."""
# First we naively split the large input into a bunch of smaller ones.
if self._separator:
splits = text.split(self._separator)
else:
splits = list(text)
return self._merge_splits(splits, self._separator)
splits = _split_text(text, self._separator, self._keep_separator)
_separator = "" if self._keep_separator else self._separator
return self._merge_splits(splits, _separator)
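# Illustration only, not part of the commit: when keep_separator=True the separator is already
# embedded at the start of each split, so the pieces are merged with "" to avoid doubling it.
# A hypothetical construction (keep_separator is forwarded through **kwargs to TextSplitter):
#   CharacterTextSplitter(separator="\n\n", keep_separator=True, chunk_size=10, chunk_overlap=0)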


class TokenTextSplitter(TextSplitter):
@@ -274,45 +299,56 @@ class RecursiveCharacterTextSplitter(TextSplitter):
that works.
"""

def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any):
def __init__(
self,
separators: Optional[List[str]] = None,
keep_separator: bool = True,
**kwargs: Any,
):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
super().__init__(keep_separator=keep_separator, **kwargs)
self._separators = separators or ["\n\n", "\n", " ", ""]

def split_text(self, text: str) -> List[str]:
def _split_text(self, text: str, separators: List[str]) -> List[str]:
"""Split incoming text and return chunks."""
final_chunks = []
# Get appropriate separator to use
separator = self._separators[-1]
for _s in self._separators:
separator = separators[-1]
new_separators = None
for i, _s in enumerate(separators):
if _s == "":
separator = _s
break
if _s in text:
separator = _s
new_separators = separators[i + 1 :]
break
# Now that we have the separator, split the text
if separator:
splits = text.split(separator)
else:
splits = list(text)

splits = _split_text(text, separator, self._keep_separator)
# Now go merging things, recursively splitting longer texts.
_good_splits = []
_separator = "" if self._keep_separator else separator
for s in splits:
if self._length_function(s) < self._chunk_size:
_good_splits.append(s)
else:
if _good_splits:
merged_text = self._merge_splits(_good_splits, separator)
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
_good_splits = []
other_info = self.split_text(s)
final_chunks.extend(other_info)
if new_separators is None:
final_chunks.append(s)
else:
other_info = self._split_text(s, new_separators)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(_good_splits, separator)
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
return final_chunks

def split_text(self, text: str) -> List[str]:
return self._split_text(text, self._separators)
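# Illustration only, not part of the commit: because each recursive call now receives
# separators[i + 1:], an oversized piece is retried with the remaining, finer separators
# rather than the full default list; when no finer separator is left (new_separators is None),
# the piece is appended as-is. With the defaults, a piece produced by splitting on "\n\n"
# that is still too long is next split on "\n", then " ", then character by character.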


class NLTKTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at sentences using NLTK."""
40 changes: 33 additions & 7 deletions tests/unit_tests/test_text_splitter.py
@@ -4,9 +4,23 @@
from langchain.docstore.document import Document
from langchain.text_splitter import (
CharacterTextSplitter,
PythonCodeTextSplitter,
RecursiveCharacterTextSplitter,
)

FAKE_PYTHON_TEXT = """
class Foo:
def bar():
def foo():
def testing_func():
def bar():
"""


def test_character_text_splitter() -> None:
"""Test splitting by character count."""
@@ -135,15 +149,16 @@ def test_iterative_text_splitter() -> None:
"Okay then",
"f f f f.",
"This is a",
"a weird",
"weird",
"text to",
"write, but",
"gotta test",
"the",
"splittingg",
"ggg",
"write,",
"but gotta",
"test the",
"splitting",
"gggg",
"some how.",
"Bye!\n\n-H.",
"Bye!",
"-H.",
]
assert output == expected_output

@@ -168,3 +183,14 @@ def test_split_documents() -> None:
Document(page_content="z", metadata={"source": "1"}),
]
assert splitter.split_documents(docs) == expected_output


def test_python_text_splitter() -> None:
splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0)
splits = splitter.split_text(FAKE_PYTHON_TEXT)
split_0 = """class Foo:\n\n def bar():"""
split_1 = """def foo():"""
split_2 = """def testing_func():"""
split_3 = """def bar():"""
expected_splits = [split_0, split_1, split_2, split_3]
assert splits == expected_splits
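To exercise the new behavior locally, the updated unit tests can be run with pytest (assuming a development install of the repository):

pytest tests/unit_tests/test_text_splitter.py -k "iterative or python"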
