diff --git a/docs/embeddings/cli.md b/docs/embeddings/cli.md
index ce92f5b6..e7a16c87 100644
--- a/docs/embeddings/cli.md
+++ b/docs/embeddings/cli.md
@@ -148,6 +148,7 @@ All three mechanisms support these options:
 - `-d database.db` to specify a different database file to store the embeddings in
 - `--store` to store the original content in the embeddings table in addition to the embedding vector
 - `--prefix` to prepend a prefix to the stored ID of each item
+- `--batch-size SIZE` to process embeddings in batches of the specified size
 
 (embeddings-cli-embed-multi-csv-etc)=
 ### Embedding data from a CSV, TSV or JSON file
diff --git a/docs/help.md b/docs/help.md
index afe2d59e..6813ef7a 100644
--- a/docs/help.md
+++ b/docs/help.md
@@ -516,6 +516,7 @@ Options:
   --sql TEXT                   Read input using this SQL query
   --attach ...                 Additional databases to attach - specify alias
                                and file path
+  --batch-size INTEGER         Batch size to use when running embeddings
   --prefix TEXT                Prefix to add to the IDs
   -m, --model TEXT             Embedding model to use
   --store                      Store the text itself in the database
diff --git a/llm/cli.py b/llm/cli.py
index 2ea2706c..8e649a6d 100644
--- a/llm/cli.py
+++ b/llm/cli.py
@@ -1182,6 +1182,9 @@ def get_db():
     multiple=True,
     help="Additional databases to attach - specify alias and file path",
 )
+@click.option(
+    "--batch-size", type=int, help="Batch size to use when running embeddings"
+)
 @click.option("--prefix", help="Prefix to add to the IDs", default="")
 @click.option("-m", "--model", help="Embedding model to use")
 @click.option("--store", is_flag=True, help="Store the text itself in the database")
@@ -1200,6 +1203,7 @@ def embed_multi(
     binary,
     sql,
     attach,
+    batch_size,
     prefix,
     model,
     store,
@@ -1324,7 +1328,10 @@ def tuples() -> Iterable[Tuple[str, Union[bytes, str]]]:
             else:
                 yield id, " ".join(v or "" for v in values[1:])
 
-    collection_obj.embed_multi(tuples(), store=store)
+    embed_kwargs = {"store": store}
+    if batch_size:
+        embed_kwargs["batch_size"] = batch_size
+    collection_obj.embed_multi(tuples(), **embed_kwargs)
 
 
 @cli.command()
diff --git a/tests/test_embed_cli.py b/tests/test_embed_cli.py
index 8322bf0d..ab435cf3 100644
--- a/tests/test_embed_cli.py
+++ b/tests/test_embed_cli.py
@@ -369,6 +369,40 @@ def test_embed_multi_sql(tmpdir, use_other_db, prefix):
     ]
 
 
+def test_embed_multi_batch_size(embed_demo, tmpdir):
+    db_path = str(tmpdir / "data.db")
+    runner = CliRunner()
+    sql = """
+    with recursive cte (id) as (
+      select 1
+      union all
+      select id+1 from cte where id < 100
+    )
+    select id, 'Row ' || cast(id as text) as value from cte
+    """
+    assert getattr(embed_demo, "batch_count", 0) == 0
+    result = runner.invoke(
+        cli,
+        [
+            "embed-multi",
+            "rows",
+            "--sql",
+            sql,
+            "-d",
+            db_path,
+            "-m",
+            "embed-demo",
+            "--store",
+            "--batch-size",
+            "8",
+        ],
+    )
+    assert result.exit_code == 0
+    db = sqlite_utils.Database(db_path)
+    assert db["embeddings"].count == 100
+    assert embed_demo.batch_count == 13
+
+
 @pytest.fixture
 def multi_files(tmpdir):
     db_path = str(tmpdir / "files.db")