Allows specifying chunk size and overlap with /learn (jupyterlab#267)
* Allows specifying chunk size and overlap with /learn

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactored as per PR review comments

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Documents -c and -o options

* Update docs/source/users/index.md

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jason Weill <[email protected]>
3 people authored Jul 18, 2023
1 parent 584878d commit d8066a3
Showing 3 changed files with 52 additions and 24 deletions.
14 changes: 14 additions & 0 deletions docs/source/users/index.md
@@ -328,6 +328,20 @@ To clear the local vector database, you can run `/learn -d` and Jupyter AI will
alt='Screen shot of a "/learn -d" command and a response.'
class="screenshot" />

Some models work better with custom chunk size and chunk overlap values when learning documents with `/learn`. To override the defaults,
use the `-c` (`--chunk-size`) and `-o` (`--chunk-overlap`) options.

```
# default chunk size and chunk overlap
/learn <directory>
# chunk size of 500, and chunk overlap of 50
/learn -c 500 -o 50 <directory>
# chunk size of 1000, and chunk overlap of 200
/learn --chunk-size 1000 --chunk-overlap 200 <directory>
```

### Additional chat commands

To clear the chat panel, use the `/clear` command. This does not reset the AI model; the model may still remember previous messages that you sent it, and it may use them to inform its responses.
57 changes: 33 additions & 24 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
@@ -6,7 +6,13 @@
from dask.distributed import Client as DaskClient
from jupyter_ai.document_loaders.directory import get_embeddings, split
from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter
from jupyter_ai.models import HumanChatMessage, IndexedDir, IndexMetadata
from jupyter_ai.models import (
DEFAULT_CHUNK_OVERLAP,
DEFAULT_CHUNK_SIZE,
HumanChatMessage,
IndexedDir,
IndexMetadata,
)
from jupyter_core.paths import jupyter_data_dir
from langchain import FAISS
from langchain.schema import BaseRetriever, Document
@@ -30,12 +36,20 @@ def __init__(
super().__init__(*args, **kwargs)
self.root_dir = root_dir
self.dask_client_future = dask_client_future
self.chunk_size = 2000
self.chunk_overlap = 100
self.parser.prog = "/learn"
self.parser.add_argument("-v", "--verbose", action="store_true")
self.parser.add_argument("-d", "--delete", action="store_true")
self.parser.add_argument("-l", "--list", action="store_true")
self.parser.add_argument(
"-c", "--chunk-size", action="store", default=DEFAULT_CHUNK_SIZE, type=int
)
self.parser.add_argument(
"-o",
"--chunk-overlap",
action="store",
default=DEFAULT_CHUNK_OVERLAP,
type=int,
)
self.parser.add_argument("path", nargs=argparse.REMAINDER)
self.index_name = "default"
self.index = None
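
As a quick illustration (not part of this diff), here is a standalone sketch of how argparse maps the new `-c`/`--chunk-size` and `-o`/`--chunk-overlap` flags onto `args.chunk_size` and `args.chunk_overlap`; the literal defaults shown mirror `DEFAULT_CHUNK_SIZE` and `DEFAULT_CHUNK_OVERLAP` from `models.py`.

```python
# Standalone sketch (not part of the commit): how the new /learn flags parse.
import argparse

parser = argparse.ArgumentParser(prog="/learn")
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument("-c", "--chunk-size", action="store", default=2000, type=int)
parser.add_argument("-o", "--chunk-overlap", action="store", default=100, type=int)
parser.add_argument("path", nargs=argparse.REMAINDER)

# Explicit values override the defaults; argparse exposes them with underscores.
args = parser.parse_args(["-c", "500", "-o", "50", "docs/"])
print(args.chunk_size, args.chunk_overlap, args.path)  # 500 50 ['docs/']

# Without the flags, the defaults apply.
args = parser.parse_args(["docs/"])
print(args.chunk_size, args.chunk_overlap)  # 2000 100
```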
@@ -102,7 +116,7 @@ async def _process_message(self, message: HumanChatMessage):
if args.verbose:
self.reply(f"Loading and splitting files for {load_path}", message)

await self.learn_dir(load_path)
await self.learn_dir(load_path, args.chunk_size, args.chunk_overlap)
self.save()

response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them.
@@ -119,27 +133,18 @@ def _build_list_response(self):
{dir_list}"""
return message

async def learn_dir(self, path: str):
async def learn_dir(self, path: str, chunk_size: int, chunk_overlap: int):
dask_client = await self.dask_client_future
splitter_kwargs = {"chunk_size": chunk_size, "chunk_overlap": chunk_overlap}
splitters = {
".py": PythonCodeTextSplitter(
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
),
".md": MarkdownTextSplitter(
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
),
".tex": LatexTextSplitter(
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
),
".ipynb": NotebookSplitter(
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
),
".py": PythonCodeTextSplitter(**splitter_kwargs),
".md": MarkdownTextSplitter(**splitter_kwargs),
".tex": LatexTextSplitter(**splitter_kwargs),
".ipynb": NotebookSplitter(**splitter_kwargs),
}
splitter = ExtensionSplitter(
splitters=splitters,
default_splitter=RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
),
default_splitter=RecursiveCharacterTextSplitter(**splitter_kwargs),
)

delayed = split(path, splitter=splitter)
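
For readers unfamiliar with the pattern above, the following dependency-free sketch (not part of this diff) shows how one shared `**splitter_kwargs` dict is unpacked into per-extension splitters with a fallback for unknown file types; `FixedSizeSplitter` and `ExtensionDispatcher` are hypothetical stand-ins for the LangChain and Jupyter AI splitters used in `learn.py`.

```python
# Hypothetical sketch (not part of the commit) of per-extension splitter dispatch.
import os


class FixedSizeSplitter:
    """Naive splitter: fixed-size character windows with overlap."""

    def __init__(self, chunk_size: int, chunk_overlap: int):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def split_text(self, text: str) -> list[str]:
        step = max(self.chunk_size - self.chunk_overlap, 1)
        return [text[i : i + self.chunk_size] for i in range(0, len(text), step)]


class ExtensionDispatcher:
    """Pick a splitter by file extension, falling back to a default."""

    def __init__(self, splitters: dict, default_splitter):
        self.splitters = splitters
        self.default_splitter = default_splitter

    def split_file(self, path: str, text: str) -> list[str]:
        ext = os.path.splitext(path)[1]
        splitter = self.splitters.get(ext, self.default_splitter)
        return splitter.split_text(text)


splitter_kwargs = {"chunk_size": 500, "chunk_overlap": 50}
dispatcher = ExtensionDispatcher(
    splitters={".md": FixedSizeSplitter(**splitter_kwargs)},
    default_splitter=FixedSizeSplitter(**splitter_kwargs),
)
print(len(dispatcher.split_file("README.md", "x" * 2000)))  # 5 chunks of <=500 chars
```

Unpacking a single kwargs dict keeps the chunking parameters consistent across every file type handled by `/learn`.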
@@ -149,14 +154,18 @@ async def learn_dir(self, path: str):
delayed = get_embeddings(doc_chunks, em_provider_cls, em_provider_args)
embedding_records = await dask_client.compute(delayed)
self.index.add_embeddings(*embedding_records)
self._add_dir_to_metadata(path)
self._add_dir_to_metadata(path, chunk_size, chunk_overlap)
self.prev_em_id = em_provider_cls.id + ":" + em_provider_args["model_id"]

def _add_dir_to_metadata(self, path: str):
def _add_dir_to_metadata(self, path: str, chunk_size: int, chunk_overlap: int):
dirs = self.metadata.dirs
index = next((i for i, dir in enumerate(dirs) if dir.path == path), None)
if not index:
dirs.append(IndexedDir(path=path))
dirs.append(
IndexedDir(
path=path, chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
)
self.metadata.dirs = dirs

async def delete_and_relearn(self):
@@ -213,7 +222,7 @@ async def relearn(self, metadata: IndexMetadata):
for dir in metadata.dirs:
# TODO: do not relearn directories in serial, but instead
# concurrently or in parallel
await self.learn_dir(dir.path)
await self.learn_dir(dir.path, dir.chunk_size, dir.chunk_overlap)

self.save()

5 changes: 5 additions & 0 deletions packages/jupyter-ai/jupyter_ai/models.py
@@ -3,6 +3,9 @@
from jupyter_ai_magics.providers import AuthStrategy, Field
from pydantic import BaseModel

DEFAULT_CHUNK_SIZE = 2000
DEFAULT_CHUNK_OVERLAP = 100


# the type of message used to chat with the agent
class ChatRequest(BaseModel):
@@ -86,6 +89,8 @@ class ListProvidersResponse(BaseModel):

class IndexedDir(BaseModel):
path: str
chunk_size: int = DEFAULT_CHUNK_SIZE
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP


class IndexMetadata(BaseModel):
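
A standalone sketch (not part of this diff) of how the new defaults behave on the pydantic model: directories indexed before this change deserialize with the default chunk parameters, so `relearn()` can always pass `dir.chunk_size` and `dir.chunk_overlap`.

```python
# Standalone sketch (not part of the commit): default chunk parameters on IndexedDir.
from pydantic import BaseModel

DEFAULT_CHUNK_SIZE = 2000
DEFAULT_CHUNK_OVERLAP = 100


class IndexedDir(BaseModel):
    path: str
    chunk_size: int = DEFAULT_CHUNK_SIZE
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP


# A record stored before this change falls back to the defaults.
old_record = IndexedDir(path="docs/")
print(old_record.chunk_size, old_record.chunk_overlap)  # 2000 100

# A record stored after `/learn -c 500 -o 50 docs/` keeps the custom values.
custom = IndexedDir(path="docs/", chunk_size=500, chunk_overlap=50)
print(custom.chunk_size, custom.chunk_overlap)  # 500 50
```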
