Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add use of zstd compression on compute services #336

Merged
merged 23 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
67f7964
Add zstd compression to set_task_result in compute service
ianmkenney Nov 25, 2024
41ad174
Use request and manually process payload
ianmkenney Dec 2, 2024
54cd184
Add compression module and send bytes from API path as latin-1
ianmkenney Dec 2, 2024
f2da1fc
Use compression module in objectstore
ianmkenney Dec 11, 2024
18131f1
Attempt to decompress object store objects in a try block
ianmkenney Dec 13, 2024
8ee5349
Use fixed filename for compressed object
ianmkenney Dec 18, 2024
b32f62d
Implement test_set_task_result_legacy
ianmkenney Dec 26, 2024
6f3ce3c
Separate executing tasks and pushing results in TestClient
ianmkenney Dec 30, 2024
63f9982
Clear leftover state before testing legacy PDR pull
ianmkenney Dec 31, 2024
6985ca7
Parameterize test_get_transformation_and_network_results
ianmkenney Dec 31, 2024
32bef06
Merge branch 'main' into feature/220-zstd-compression-compute-services
ianmkenney Dec 31, 2024
cda3541
Small docstring adjustment
dotsdl Jan 23, 2025
664cce4
Merge branch 'main' into feature/220-zstd-compression-compute-services
dotsdl Jan 23, 2025
0a15138
Merge fixes
dotsdl Jan 23, 2025
8b65c72
Removed need to decompress protocoldagresult in S3ObjectStore.push_pr…
dotsdl Jan 24, 2025
7d32286
Merge branch 'main' into feature/220-zstd-compression-compute-services
dotsdl Jan 24, 2025
1187666
Some clarity edits
dotsdl Jan 24, 2025
7d3c86c
Black!
dotsdl Jan 24, 2025
e149d09
CI fixes, other edits from review
dotsdl Jan 24, 2025
8473729
Black!
dotsdl Jan 24, 2025
d61d133
Assign value to protocoldagresult
ianmkenney Jan 24, 2025
443e193
Simplification
dotsdl Jan 24, 2025
321f5d7
Revert "simplification"
dotsdl Jan 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions alchemiscale/compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from gufe.tokenization import GufeTokenizable, JSON_HANDLER
import json
import zstandard as zstd


def compress_gufe_zstd(gufe_object: GufeTokenizable) -> bytes:
    """Compress a GufeTokenizable using zstandard compression.

    The GufeTokenizable is first converted to its KeyedChain
    representation, then serialized to JSON with the gufe-provided
    JSON_HANDLER.encoder. The JSON string is utf-8 encoded and the
    resulting bytes are compressed with a zstandard compressor.

    Parameters
    ----------
    gufe_object: GufeTokenizable
        The GufeTokenizable to compress.

    Returns
    -------
    bytes
        Compressed byte form of the GufeTokenizable.
    """
    serialized: bytes = json.dumps(
        gufe_object.to_keyed_chain(), cls=JSON_HANDLER.encoder
    ).encode("utf-8")

    return zstd.ZstdCompressor().compress(serialized)


def decompress_gufe_zstd(compressed_bytes: bytes) -> GufeTokenizable:
    """Decompress a zstandard compressed GufeTokenizable.

    The input bytes are decompressed with a zstandard decompressor,
    decoded as utf-8, and deserialized with the gufe-provided
    JSON_HANDLER.decoder into a KeyedChain representation, from which
    the GufeTokenizable is rebuilt.

    This is the inverse operation of `compress_gufe_zstd`.

    Parameters
    ----------
    compressed_bytes: bytes
        The compressed byte form of a GufeTokenizable.

    Returns
    -------
    GufeTokenizable
        The decompressed GufeTokenizable.
    """
    raw_json: str = zstd.ZstdDecompressor().decompress(compressed_bytes).decode("utf-8")
    keyed_chain = json.loads(raw_json, cls=JSON_HANDLER.decoder)

    return GufeTokenizable.from_keyed_chain(keyed_chain)
33 changes: 21 additions & 12 deletions alchemiscale/compute/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from datetime import datetime, timedelta
import random

from fastapi import FastAPI, APIRouter, Body, Depends
from fastapi import FastAPI, APIRouter, Body, Depends, Request
from fastapi.middleware.gzip import GZipMiddleware
from gufe.tokenization import GufeTokenizable, JSON_HANDLER
import zstandard as zstd
from gufe.protocols import ProtocolDAGResult

from ..base.api import (
Expand All @@ -30,6 +31,7 @@
gufe_to_json,
GzipRoute,
)
from ..compression import decompress_gufe_zstd
from ..settings import (
get_base_api_settings,
get_compute_api_settings,
Expand Down Expand Up @@ -297,18 +299,17 @@ def retrieve_task_transformation(

# we keep this as a string to avoid useless deserialization/reserialization here
try:
pdr: str = s3os.pull_protocoldagresult(
pdr_sk, transformation_sk, return_as="json", ok=True
pdr_bytes: bytes = s3os.pull_protocoldagresult(
pdr_sk, transformation_sk, ok=True
)
except:
# if we fail to get the object with the above, fall back to
# location-based retrieval
pdr: str = s3os.pull_protocoldagresult(
pdr_bytes: bytes = s3os.pull_protocoldagresult(
location=protocoldagresultref.location,
return_as="json",
ok=True,
)

pdr = pdr_bytes.decode("latin-1")
else:
pdr = None

Expand All @@ -317,20 +318,24 @@ def retrieve_task_transformation(

# TODO: support compression performed client-side
@router.post("/tasks/{task_scoped_key}/results", response_model=ScopedKey)
def set_task_result(
async def set_task_result(
task_scoped_key,
*,
protocoldagresult: str = Body(embed=True),
compute_service_id: Optional[str] = Body(embed=True),
request: Request,
n4js: Neo4jStore = Depends(get_n4js_depends),
s3os: S3ObjectStore = Depends(get_s3os_depends),
token: TokenData = Depends(get_token_data_depends),
):
body = await request.body()
body_ = json.loads(body.decode("utf-8"), cls=JSON_HANDLER.decoder)

protocoldagresult_ = body_["protocoldagresult"]
compute_service_id = body_["compute_service_id"]

task_sk = ScopedKey.from_str(task_scoped_key)
validate_scopes(task_sk.scope, token)

pdr = json.loads(protocoldagresult, cls=JSON_HANDLER.decoder)
pdr: ProtocolDAGResult = GufeTokenizable.from_dict(pdr)
pdr: ProtocolDAGResult = decompress_gufe_zstd(protocoldagresult_)

tf_sk, _ = n4js.get_task_transformation(
task=task_scoped_key,
Expand All @@ -339,7 +344,11 @@ def set_task_result(

# push the ProtocolDAGResult to the object store
protocoldagresultref: ProtocolDAGResultRef = s3os.push_protocoldagresult(
pdr, transformation=tf_sk, creator=compute_service_id
protocoldagresult=protocoldagresult_,
protocoldagresult_ok=pdr.ok(),
protocoldagresult_gufekey=pdr.key,
transformation=tf_sk,
creator=compute_service_id,
)

# push the reference to the state store
Expand Down
32 changes: 22 additions & 10 deletions alchemiscale/compute/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import requests
from requests.auth import HTTPBasicAuth

import zstandard as zstd

from gufe.tokenization import GufeTokenizable, JSON_HANDLER
from gufe import Transformation
from gufe.protocols import ProtocolDAGResult
Expand All @@ -22,6 +24,7 @@
AlchemiscaleBaseClientError,
json_to_gufe,
)
from ..compression import compress_gufe_zstd, decompress_gufe_zstd
from ..models import Scope, ScopedKey
from ..storage.models import TaskHub, Task, ComputeServiceID, TaskStatusEnum

Expand Down Expand Up @@ -112,26 +115,35 @@ def get_task_transformation(self, task: ScopedKey) -> ScopedKey:

def retrieve_task_transformation(
self, task: ScopedKey
) -> Tuple[Transformation, Optional[ProtocolDAGResult]]:
transformation, protocoldagresult = self._get_resource(
) -> tuple[Transformation, ProtocolDAGResult | None]:
transformation_json, protocoldagresult_latin1 = self._get_resource(
f"/tasks/{task}/transformation/gufe"
)

return (
json_to_gufe(transformation),
json_to_gufe(protocoldagresult) if protocoldagresult is not None else None,
)
if protocoldagresult is not None:

protocoldagresult_bytes = protocoldagresult_latin1.encode("latin-1")

try:
# Attempt to decompress the ProtocolDAGResult object
protocoldagresult = decompress_gufe_zstd(protocoldagresult_bytes)
except zstd.ZstdError:
# If decompression fails, assume it's a UTF-8 encoded JSON string
protocoldagresult = json_to_gufe(
protocoldagresult_bytes.decode("utf-8")
)

return json_to_gufe(transformation_json), protocoldagresult

def set_task_result(
self,
task: ScopedKey,
protocoldagresult: ProtocolDAGResult,
compute_service_id=Optional[ComputeServiceID],
compute_service_id: Optional[ComputeServiceID] = None,
) -> ScopedKey:

data = dict(
protocoldagresult=json.dumps(
protocoldagresult.to_dict(), cls=JSON_HANDLER.encoder
),
protocoldagresult=compress_gufe_zstd(protocoldagresult),
compute_service_id=str(compute_service_id),
)

Expand Down
8 changes: 3 additions & 5 deletions alchemiscale/interface/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,17 +1140,15 @@ def get_protocoldagresult(
# we leave each ProtocolDAGResult in string form to avoid
# deserializing/reserializing here; just passing through to client
try:
pdr: str = s3os.pull_protocoldagresult(
pdr_sk, transformation_sk, return_as="json", ok=ok
)
pdr_bytes: str = s3os.pull_protocoldagresult(pdr_sk, transformation_sk, ok=ok)
except Exception:
# if we fail to get the object with the above, fall back to
# location-based retrieval
pdr: str = s3os.pull_protocoldagresult(
pdr_bytes: str = s3os.pull_protocoldagresult(
location=protocoldagresultref.location,
return_as="json",
ok=ok,
)
pdr = pdr_bytes.decode("latin-1")

return [pdr]

Expand Down
16 changes: 12 additions & 4 deletions alchemiscale/interface/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
from gufe import AlchemicalNetwork, Transformation, ChemicalSystem
from gufe.tokenization import GufeTokenizable, JSON_HANDLER, KeyedChain
from gufe.protocols import ProtocolResult, ProtocolDAGResult
import zstandard as zstd


from ..base.client import (
AlchemiscaleBaseClient,
AlchemiscaleBaseClientError,
json_to_gufe,
use_session,
)
from ..compression import decompress_gufe_zstd
from ..models import Scope, ScopedKey
from ..storage.models import (
TaskStatusEnum,
Expand Down Expand Up @@ -1352,14 +1355,19 @@ def get_tasks_priority(
async def _async_get_protocoldagresult(
self, protocoldagresultref, transformation, route, compress
):
pdr_json = await self._get_resource_async(
pdr_latin1_decoded = await self._get_resource_async(
f"/transformations/{transformation}/{route}/{protocoldagresultref}",
compress=compress,
)

pdr = GufeTokenizable.from_dict(
json.loads(pdr_json[0], cls=JSON_HANDLER.decoder)
)
pdr_bytes = pdr_latin1_decoded[0].encode("latin-1")

try:
# Attempt to decompress the ProtocolDAGResult object
pdr = decompress_gufe_zstd(pdr_bytes)
except zstd.ZstdError:
# If decompression fails, assume it's a UTF-8 encoded JSON string
pdr = json_to_gufe(pdr_bytes.decode("utf-8"))

return pdr

Expand Down
1 change: 0 additions & 1 deletion alchemiscale/security/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
"""

import secrets
import base64
import hashlib
from datetime import datetime, timedelta
from typing import Optional, Union
Expand Down
Loading
Loading