Skip to content

Commit

Permalink
llm embed-multi --batch-size option, closes #273
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Sep 13, 2023
1 parent b9478e6 commit 33dee47
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/embeddings/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ All three mechanisms support these options:
- `-d database.db` to specify a different database file to store the embeddings in
- `--store` to store the original content in the embeddings table in addition to the embedding vector
- `--prefix` to prepend a prefix to the stored ID of each item
- `--batch-size SIZE` to process embeddings in batches of the specified size

(embeddings-cli-embed-multi-csv-etc)=
### Embedding data from a CSV, TSV or JSON file
Expand Down
1 change: 1 addition & 0 deletions docs/help.md
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ Options:
--sql TEXT Read input using this SQL query
--attach <TEXT FILE>... Additional databases to attach - specify alias
and file path
--batch-size INTEGER Batch size to use when running embeddings
--prefix TEXT Prefix to add to the IDs
-m, --model TEXT Embedding model to use
--store Store the text itself in the database
Expand Down
9 changes: 8 additions & 1 deletion llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,9 @@ def get_db():
multiple=True,
help="Additional databases to attach - specify alias and file path",
)
@click.option(
"--batch-size", type=int, help="Batch size to use when running embeddings"
)
@click.option("--prefix", help="Prefix to add to the IDs", default="")
@click.option("-m", "--model", help="Embedding model to use")
@click.option("--store", is_flag=True, help="Store the text itself in the database")
Expand All @@ -1200,6 +1203,7 @@ def embed_multi(
binary,
sql,
attach,
batch_size,
prefix,
model,
store,
Expand Down Expand Up @@ -1324,7 +1328,10 @@ def tuples() -> Iterable[Tuple[str, Union[bytes, str]]]:
else:
yield id, " ".join(v or "" for v in values[1:])

collection_obj.embed_multi(tuples(), store=store)
embed_kwargs = {"store": store}
if batch_size:
embed_kwargs["batch_size"] = batch_size
collection_obj.embed_multi(tuples(), **embed_kwargs)


@cli.command()
Expand Down
34 changes: 34 additions & 0 deletions tests/test_embed_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,40 @@ def test_embed_multi_sql(tmpdir, use_other_db, prefix):
]


def test_embed_multi_batch_size(embed_demo, tmpdir):
db_path = str(tmpdir / "data.db")
runner = CliRunner()
sql = """
with recursive cte (id) as (
select 1
union all
select id+1 from cte where id < 100
)
select id, 'Row ' || cast(id as text) as value from cte
"""
assert getattr(embed_demo, "batch_count", 0) == 0
result = runner.invoke(
cli,
[
"embed-multi",
"rows",
"--sql",
sql,
"-d",
db_path,
"-m",
"embed-demo",
"--store",
"--batch-size",
"8",
],
)
assert result.exit_code == 0
db = sqlite_utils.Database(db_path)
assert db["embeddings"].count == 100
assert embed_demo.batch_count == 13


@pytest.fixture
def multi_files(tmpdir):
db_path = str(tmpdir / "files.db")
Expand Down

0 comments on commit 33dee47

Please sign in to comment.