Skip to content

Commit

Permalink
community[major]: breaking change in some APIs to force users to opt-in for pickling (#18696)
Browse files Browse the repository at this point in the history

This PR adds an `allow_dangerous_deserialization` parameter to force users to explicitly opt in before pickle-based deserialization is used.

This PR is meant to raise user awareness whenever the `pickle` module is involved, since unpickling untrusted data can execute arbitrary code.
  • Loading branch information
eyurtsev authored Mar 6, 2024
1 parent 0e52961 commit 4c25b49
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 7 deletions.
20 changes: 19 additions & 1 deletion libs/community/langchain_community/llms/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ def _pickle_fn_to_hex_string(fn: Callable) -> str:


class Databricks(LLM):

"""Databricks serving endpoint or a cluster driver proxy app for LLM.
It supports two endpoint types:
Expand Down Expand Up @@ -374,6 +373,15 @@ class Databricks(LLM):
If not provided, the task is automatically inferred from the endpoint.
"""

allow_dangerous_deserialization: bool = False
"""Whether to allow dangerous deserialization of the data which
involves loading data using pickle.
If the data has been modified by a malicious actor, it can deliver a
malicious payload that results in execution of arbitrary code on the target
machine.
"""

_client: _DatabricksClientBase = PrivateAttr()

class Config:
Expand Down Expand Up @@ -435,6 +443,16 @@ def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any
return v

def __init__(self, **data: Any):
if not data.get("allow_dangerous_deserialization"):
raise ValueError(
"This code relies on the pickle module. "
"You will need to set allow_dangerous_deserialization=True "
"if you want to opt-in to allow deserialization of data using pickle."
"Data can be compromised by a malicious actor if "
"not handled properly to include "
"a malicious payload that when deserialized with "
"pickle can execute arbitrary code on your machine."
)
if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
data["transform_input_fn"] = _load_pickled_fn_from_hex_string(
data["transform_input_fn"]
Expand Down
15 changes: 15 additions & 0 deletions libs/community/langchain_community/llms/self_hosted.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ def inference_fn(pipeline, prompt, stop = None):
model_reqs: List[str] = ["./", "torch"]
"""Requirements to install on hardware to inference the model."""

allow_dangerous_deserialization: bool = False
"""Allow deserialization using pickle which can be dangerous if
loading compromised data.
"""

class Config:
"""Configuration for this pydantic object."""

Expand All @@ -149,6 +154,16 @@ def __init__(self, **kwargs: Any):
and run on the server, i.e. in a module and not a REPL or closure.
Then, initialize the remote inference function.
"""
if not kwargs.get("allow_dangerous_deserialization"):
raise ValueError(
"SelfHostedPipeline relies on the pickle module. "
"You will need to set allow_dangerous_deserialization=True "
"if you want to opt-in to allow deserialization of data using pickle."
"Data can be compromised by a malicious actor if "
"not handled properly to include "
"a malicious payload that when deserialized with "
"pickle can execute arbitrary code. "
)
super().__init__(**kwargs)
try:
import runhouse as rh
Expand Down
20 changes: 20 additions & 0 deletions libs/community/langchain_community/vectorstores/annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,14 +429,34 @@ def load_local(
cls,
folder_path: str,
embeddings: Embeddings,
*,
allow_dangerous_deserialization: bool = False,
) -> Annoy:
"""Load Annoy index, docstore, and index_to_docstore_id to disk.
Args:
folder_path: folder path to load index, docstore,
and index_to_docstore_id from.
embeddings: Embeddings to use when generating queries.
allow_dangerous_deserialization: whether to allow deserialization
of the data which involves loading a pickle file.
Pickle files can be modified by malicious actors to deliver a
malicious payload that results in execution of
arbitrary code on your machine.
"""
if not allow_dangerous_deserialization:
raise ValueError(
"The de-serialization relies loading a pickle file. "
"Pickle files can be modified to deliver a malicious payload that "
"results in execution of arbitrary code on your machine."
"You will need to set `allow_dangerous_deserialization` to `True` to "
"enable deserialization. If you do this, make sure that you "
"trust the source of the data. For example, if you are loading a "
"file that you created, and no that no one else has modified the file, "
"then this is safe to do. Do not set this to `True` if you are loading "
"a file from an untrusted source (e.g., some random site on the "
"internet.)."
)
path = Path(folder_path)
# load index separately since it is not picklable
annoy = dependable_annoy_import()
Expand Down
20 changes: 20 additions & 0 deletions libs/community/langchain_community/vectorstores/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,8 @@ def load_local(
folder_path: str,
embeddings: Embeddings,
index_name: str = "index",
*,
allow_dangerous_deserialization: bool = False,
**kwargs: Any,
) -> FAISS:
"""Load FAISS index, docstore, and index_to_docstore_id from disk.
Expand All @@ -1102,8 +1104,26 @@ def load_local(
and index_to_docstore_id from.
embeddings: Embeddings to use when generating queries
index_name: for saving with a specific index file name
allow_dangerous_deserialization: whether to allow deserialization
of the data which involves loading a pickle file.
Pickle files can be modified by malicious actors to deliver a
malicious payload that results in execution of
arbitrary code on your machine.
asynchronous: whether to use async version or not
"""
if not allow_dangerous_deserialization:
raise ValueError(
"The de-serialization relies loading a pickle file. "
"Pickle files can be modified to deliver a malicious payload that "
"results in execution of arbitrary code on your machine."
"You will need to set `allow_dangerous_deserialization` to `True` to "
"enable deserialization. If you do this, make sure that you "
"trust the source of the data. For example, if you are loading a "
"file that you created, and no that no one else has modified the file, "
"then this is safe to do. Do not set this to `True` if you are loading "
"a file from an untrusted source (e.g., some random site on the "
"internet.)."
)
path = Path(folder_path)
# load index separately since it is not picklable
faiss = dependable_faiss_import()
Expand Down
20 changes: 20 additions & 0 deletions libs/community/langchain_community/vectorstores/scann.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,8 @@ def load_local(
folder_path: str,
embedding: Embeddings,
index_name: str = "index",
*,
allow_dangerous_deserialization: bool = False,
**kwargs: Any,
) -> ScaNN:
"""Load ScaNN index, docstore, and index_to_docstore_id from disk.
Expand All @@ -469,7 +471,25 @@ def load_local(
and index_to_docstore_id from.
embeddings: Embeddings to use when generating queries
index_name: for saving with a specific index file name
allow_dangerous_deserialization: whether to allow deserialization
of the data which involves loading a pickle file.
Pickle files can be modified by malicious actors to deliver a
malicious payload that results in execution of
arbitrary code on your machine.
"""
if not allow_dangerous_deserialization:
raise ValueError(
"The de-serialization relies loading a pickle file. "
"Pickle files can be modified to deliver a malicious payload that "
"results in execution of arbitrary code on your machine."
"You will need to set `allow_dangerous_deserialization` to `True` to "
"enable deserialization. If you do this, make sure that you "
"trust the source of the data. For example, if you are loading a "
"file that you created, and no that no one else has modified the file, "
"then this is safe to do. Do not set this to `True` if you are loading "
"a file from an untrusted source (e.g., some random site on the "
"internet.)."
)
path = Path(folder_path)
scann_path = path / "{index_name}.scann".format(index_name=index_name)
scann_path.mkdir(exist_ok=True, parents=True)
Expand Down
21 changes: 20 additions & 1 deletion libs/community/langchain_community/vectorstores/tiledb.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,28 @@ def __init__(
docs_array_uri: str = "",
config: Optional[Mapping[str, Any]] = None,
timestamp: Any = None,
allow_dangerous_deserialization: bool = False,
**kwargs: Any,
):
"""Initialize with necessary components."""
"""Initialize with necessary components.
Args:
allow_dangerous_deserialization: whether to allow deserialization
of the data which involves loading data using pickle.
data can be modified by malicious actors to deliver a
malicious payload that results in execution of
arbitrary code on your machine.
"""
if not allow_dangerous_deserialization:
raise ValueError(
"TileDB relies on pickle for serialization and deserialization. "
"This can be dangerous if the data is intercepted and/or modified "
"by malicious actors prior to being de-serialized. "
"If you are sure that the data is safe from modification, you can "
" set allow_dangerous_deserialization=True to proceed. "
"Loading of compromised data using pickle can result in execution of "
"arbitrary code on your machine."
)
self.embedding = embedding
self.embedding_function = embedding.embed_query
self.index_uri = index_uri
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def test_annoy_local_save_load() -> None:

temp_dir = tempfile.TemporaryDirectory()
docsearch.save_local(temp_dir.name)
loaded_docsearch = Annoy.load_local(temp_dir.name, FakeEmbeddings())
loaded_docsearch = Annoy.load_local(
temp_dir.name, FakeEmbeddings(), allow_dangerous_deserialization=True
)

assert docsearch.index_to_docstore_id == loaded_docsearch.index_to_docstore_id
assert docsearch.docstore.__dict__ == loaded_docsearch.docstore.__dict__
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ def test_scann_local_save_load() -> None:
temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
docsearch.save_local(temp_folder)
new_docsearch = ScaNN.load_local(temp_folder, FakeEmbeddings())
new_docsearch = ScaNN.load_local(
temp_folder, FakeEmbeddings(), allow_dangerous_deserialization=True
)
assert new_docsearch.index is not None


Expand Down
3 changes: 2 additions & 1 deletion libs/community/tests/unit_tests/llms/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")

llm = Databricks(
endpoint_name="databricks-mixtral-8x7b-instruct",
endpoint_name="some_end_point_name", # Value should not matter for this test
transform_input_fn=transform_input,
allow_dangerous_deserialization=True,
)
params = llm._default_params
pickled_string = cloudpickle.dumps(transform_input).hex()
Expand Down
8 changes: 6 additions & 2 deletions libs/community/tests/unit_tests/vectorstores/test_faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,9 @@ def test_faiss_local_save_load() -> None:
temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
docsearch.save_local(temp_folder)
new_docsearch = FAISS.load_local(temp_folder, FakeEmbeddings())
new_docsearch = FAISS.load_local(
temp_folder, FakeEmbeddings(), allow_dangerous_deserialization=True
)
assert new_docsearch.index is not None


Expand All @@ -620,7 +622,9 @@ async def test_faiss_async_local_save_load() -> None:
temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
docsearch.save_local(temp_folder)
new_docsearch = FAISS.load_local(temp_folder, FakeEmbeddings())
new_docsearch = FAISS.load_local(
temp_folder, FakeEmbeddings(), allow_dangerous_deserialization=True
)
assert new_docsearch.index is not None


Expand Down

0 comments on commit 4c25b49

Please sign in to comment.