Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Langchain/New-feature: Added Haskell support in langchain.text_splitter module #16191

Merged
merged 11 commits into from
Mar 29, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"id": "a9e37aa1",
"metadata": {},
"outputs": [],
Expand All @@ -35,7 +35,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 5,
"id": "e21a2434",
"metadata": {},
"outputs": [
Expand All @@ -61,10 +61,14 @@
" 'html',\n",
" 'sol',\n",
" 'csharp',\n",
" 'cobol']"
" 'cobol',\n",
" 'c',\n",
" 'lua',\n",
" 'perl',\n",
" 'haskell']"
]
},
"execution_count": 2,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -564,13 +568,50 @@
"c_docs"
]
},
{
"cell_type": "markdown",
"id": "af9de667-230e-4c2a-8c5f-122a28515d97",
"metadata": {},
"source": [
"## Haskell\n",
"Here's an example using the Haskell text splitter:"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "688185b5",
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='main :: IO ()'),\n",
" Document(page_content='main = do\\n putStrLn \"Hello, World!\"\\n-- Some'),\n",
" Document(page_content='sample functions\\nadd :: Int -> Int -> Int\\nadd x y'),\n",
" Document(page_content='= x + y')]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"HASKELL_CODE = \"\"\"\n",
"main :: IO ()\n",
"main = do\n",
" putStrLn \"Hello, World!\"\n",
"-- Some sample functions\n",
"add :: Int -> Int -> Int\n",
"add x y = x + y\n",
"\"\"\"\n",
"haskell_splitter = RecursiveCharacterTextSplitter.from_language(\n",
" language=Language.HASKELL, chunk_size=50, chunk_overlap=0\n",
")\n",
"haskell_docs = haskell_splitter.create_documents([HASKELL_CODE])\n",
"haskell_docs"
]
}
],
"metadata": {
Expand Down
1 change: 1 addition & 0 deletions libs/text-splitters/langchain_text_splitters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ class Language(str, Enum):
C = "c"
LUA = "lua"
PERL = "perl"
HASKELL = "haskell"


@dataclass(frozen=True)
Expand Down
40 changes: 39 additions & 1 deletion libs/text-splitters/langchain_text_splitters/character.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,45 @@ def get_separators_for_language(language: Language) -> List[str]:
" ",
"",
]

elif language == Language.HASKELL:
return [
# Split along function definitions
"\nmain :: ",
"\nmain = ",
"\nlet ",
"\nin ",
"\ndo ",
"\nwhere ",
"\n:: ",
"\n= ",
# Split along type declarations
"\ndata ",
"\nnewtype ",
"\ntype ",
"\n:: ",
# Split along module declarations
"\nmodule ",
# Split along import statements
"\nimport ",
"\nqualified ",
"\nimport qualified ",
# Split along typeclass declarations
"\nclass ",
"\ninstance ",
# Split along case expressions
"\ncase ",
# Split along guards in function definitions
"\n| ",
# Split along record field declarations
"\ndata ",
"\n= {",
"\n, ",
# Generic fallbacks: blank lines, single newlines, spaces, then the empty string
"\n\n",
"\n",
" ",
"",
]
else:
raise ValueError(
f"Language {language} is not supported! "
Expand Down
32 changes: 32 additions & 0 deletions libs/text-splitters/tests/unit_tests/test_text_splitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,38 @@ def test_solidity_code_splitter() -> None:
]


def test_haskell_code_splitter() -> None:
# Verifies that a RecursiveCharacterTextSplitter created for Language.HASKELL
# splits a small Haskell snippet into exactly the expected list of chunks.
#
# Build the splitter under test with CHUNK_SIZE (defined elsewhere in this
# file) as the chunk size and a chunk overlap of zero.
splitter = RecursiveCharacterTextSplitter.from_language(
Language.HASKELL, chunk_size=CHUNK_SIZE, chunk_overlap=0
)
# Sample input: a type signature, a `do` block, a line comment and a
# two-argument function.
# NOTE(review): the expected chunks below depend on the exact whitespace
# inside this literal; the indentation appears to have been stripped in this
# rendering -- confirm against the repository before editing the string.
code = """
main :: IO ()
main = do
putStrLn "Hello, World!"

-- Some sample functions
add :: Int -> Int -> Int
add x y = x + y
"""
# Adjusted expected chunks to account for indentation and newlines
# (listed in output order; compared by exact list equality below).
expected_chunks = [
"main ::",
"IO ()",
"main = do",
"putStrLn",
'"Hello, World!"',
"--",
"Some sample",
"functions",
"add :: Int ->",
"Int -> Int",
"add x y = x",
"+ y",
]
# Run the splitter and require the result to match the expected list exactly.
chunks = splitter.split_text(code)
assert chunks == expected_chunks


@pytest.mark.requires("lxml")
def test_html_header_text_splitter(tmp_path: Path) -> None:
splitter = HTMLHeaderTextSplitter(
Expand Down
Loading