User/tom/fix/s2 pctasks perf #291

Merged · 16 commits · May 28, 2024
8 changes: 8 additions & 0 deletions datasets/sentinel-2/README.md
@@ -13,3 +13,11 @@
```shell
az acr build -r {the registry} --subscription {the subscription} -t pctasks-sentinel-2:latest -t pctasks-sentinel-2:{date}.{count} -f datasets/sentinel-2/Dockerfile .
```

## Update Workflow

Created with

```
pctasks dataset process-items --is-update-workflow sentinel-2-l2a-update -d datasets/sentinel-2/dataset.yaml
```
49 changes: 32 additions & 17 deletions pctasks/core/pctasks/core/storage/blob.py
@@ -1,3 +1,4 @@
import contextlib
import logging
import multiprocessing
import os
@@ -252,6 +253,7 @@ def __init__(
self.storage_account_name = storage_account_name
self.container_name = container_name
self.prefix = prefix.strip("/") if prefix is not None else prefix
self._container_client_wrapper: Optional[ContainerClientWrapper] = None

def __repr__(self) -> str:
prefix_part = "" if self.prefix is None else f"/{self.prefix}"
@@ -261,14 +263,17 @@ def __repr__(self) -> str:
)

def _get_client(self) -> ContainerClientWrapper:
-        account_client = BlobServiceClient(
-            account_url=self.account_url,
-            credential=self._blob_creds,
-        )
-
-        container_client = account_client.get_container_client(self.container_name)
-
-        return ContainerClientWrapper(account_client, container_client)
+        if self._container_client_wrapper is None:
+            account_client = BlobServiceClient(
+                account_url=self.account_url,
+                credential=self._blob_creds,
+            )
+
+            container_client = account_client.get_container_client(self.container_name)
+            self._container_client_wrapper = ContainerClientWrapper(
+                account_client, container_client
+            )
+        return self._container_client_wrapper

def _get_name_starts_with(
self, additional_prefix: Optional[str] = None
@@ -388,7 +393,8 @@ def get_path(self, uri: str) -> str:
return blob_uri.blob_name or ""

def get_file_info(self, file_path: str) -> StorageFileInfo:
-        with self._get_client() as client:
+        client = self._get_client()
+        with contextlib.nullcontext():
with client.container.get_blob_client(self._add_prefix(file_path)) as blob:
try:
props = with_backoff(lambda: blob.get_blob_properties())
@@ -397,7 +403,8 @@ def get_file_info(self, file_path: str) -> StorageFileInfo:
return StorageFileInfo(size=cast(int, props.size))

def file_exists(self, file_path: str) -> bool:
-        with self._get_client() as client:
+        client = self._get_client()
+        with contextlib.nullcontext():
with client.container.get_blob_client(self._add_prefix(file_path)) as blob:
return with_backoff(lambda: blob.exists())

@@ -450,7 +457,8 @@ def fetch_blobs() -> Iterable[str]:
for blob_name in page:
yield blob_name

-        with self._get_client() as client:
+        client = self._get_client()
+        with contextlib.nullcontext():
return with_backoff(fetch_blobs)

def walk(
@@ -514,7 +522,8 @@ def _get_prefix_content(
limit_break = False

full_prefixes: List[str] = [self._get_name_starts_with(name_starts_with) or ""]
-        with self._get_client() as client:
+        client = self._get_client()
+        with contextlib.nullcontext():
while full_prefixes:
if walk_limit and walk_count >= walk_limit:
break
@@ -570,7 +579,8 @@ def download_file(
if timeout_seconds is not None:
kwargs["timeout"] = timeout_seconds

-        with self._get_client() as client:
+        client = self._get_client()
+        with contextlib.nullcontext():
with client.container.get_blob_client(self._add_prefix(file_path)) as blob:
with open(output_path, "wb" if is_binary else "w") as f:
try:
@@ -585,7 +595,8 @@ def upload_bytes(
target_path: str,
overwrite: bool = True,
) -> None:
-        with self._get_client() as client:
+        client = self._get_client()
+        with contextlib.nullcontext():
with client.container.get_blob_client(
self._add_prefix(target_path)
) as blob:
@@ -615,7 +626,8 @@ def upload_file(
kwargs = {}
if content_type:
kwargs["content_settings"] = ContentSettings(content_type=content_type)
-        with self._get_client() as client:
+        client = self._get_client()
+        with contextlib.nullcontext():
with client.container.get_blob_client(
self._add_prefix(target_path)
) as blob:
@@ -629,7 +641,8 @@ def _upload() -> None:
def read_bytes(self, file_path: str) -> bytes:
try:
blob_path = self._add_prefix(file_path)
-        with self._get_client() as client:
+        client = self._get_client()
+        with contextlib.nullcontext():
with client.container.get_blob_client(blob_path) as blob:
blob_data = with_backoff(
lambda: blob.download_blob(
Expand All @@ -649,7 +662,8 @@ def read_bytes(self, file_path: str) -> bytes:

def write_bytes(self, file_path: str, data: bytes, overwrite: bool = True) -> None:
full_path = self._add_prefix(file_path)
-        with self._get_client() as client:
+        client = self._get_client()
+        with contextlib.nullcontext():
with client.container.get_blob_client(full_path) as blob:
with_backoff(
lambda: blob.upload_blob(data, overwrite=overwrite) # type: ignore
@@ -660,7 +674,8 @@ def delete_folder(self, folder_path: Optional[str] = None) -> None:
self.delete_file(file_path)

def delete_file(self, file_path: str) -> None:
-        with self._get_client() as client:
+        client = self._get_client()
+        with contextlib.nullcontext():
with client.container.get_blob_client(self._add_prefix(file_path)) as blob:
try:
with_backoff(lambda: blob.delete_blob())
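For context on the change above: every storage call previously built a fresh `BlobServiceClient`, and this diff caches a single `ContainerClientWrapper` instead, swapping `with self._get_client() as client:` for a plain call plus a `contextlib.nullcontext()` placeholder that keeps the existing indentation. A minimal, self-contained sketch of that pattern, with illustrative class and method names that are not from the repository:

```python
import contextlib
from typing import Optional


class _FakeContainerClient:
    """Stand-in for the Azure ContainerClientWrapper; illustrative only."""

    def blob_exists(self, path: str) -> bool:
        return path.endswith(".json")


class CachedClientStorage:
    """Sketch of the lazy client-caching pattern used in blob.py above."""

    def __init__(self) -> None:
        self._client: Optional[_FakeContainerClient] = None

    def _get_client(self) -> _FakeContainerClient:
        # Build the client once and reuse it for every later call,
        # instead of constructing a new service client per operation.
        if self._client is None:
            self._client = _FakeContainerClient()
        return self._client

    def file_exists(self, file_path: str) -> bool:
        client = self._get_client()
        # nullcontext() preserves the original `with ...:` block shape while
        # no longer closing the now-shared client on every call.
        with contextlib.nullcontext():
            return client.blob_exists(file_path)


storage = CachedClientStorage()
assert storage.file_exists("a/asset-a-1.json")
assert storage._get_client() is storage._get_client()  # same cached instance
```

One consequence of this approach, visible in the diff, is that the cached client is no longer closed after each operation; the `nullcontext()` placeholder makes that explicit while keeping the surrounding code structure unchanged.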
28 changes: 13 additions & 15 deletions pctasks/core/tests/storage/test_blob.py
@@ -88,21 +88,19 @@ def test_blob_download_timeout():
with temp_azurite_blob_storage(
HERE / ".." / "data-files" / "simple-assets"
) as storage:
-        with storage._get_client() as client:
-            with client.container.get_blob_client(
-                storage._add_prefix("a/asset-a-1.json")
-            ) as blob:
-                storage_stream_downloader = blob.download_blob(timeout=TIMEOUT_SECONDS)
-                assert (
-                    storage_stream_downloader._request_options["timeout"]
-                    == TIMEOUT_SECONDS
-                )
-
-                storage_stream_downloader = blob.download_blob()
-                assert (
-                    storage_stream_downloader._request_options.pop("timeout", None)
-                    is None
-                )
+        client = storage._get_client()
+        with client.container.get_blob_client(
+            storage._add_prefix("a/asset-a-1.json")
+        ) as blob:
+            storage_stream_downloader = blob.download_blob(timeout=TIMEOUT_SECONDS)
+            assert (
+                storage_stream_downloader._request_options["timeout"] == TIMEOUT_SECONDS
+            )
+
+            storage_stream_downloader = blob.download_blob()
+            assert (
+                storage_stream_downloader._request_options.pop("timeout", None) is None
+            )


@pytest.mark.parametrize(
6 changes: 6 additions & 0 deletions pctasks/run/pctasks/run/argo/client.py
@@ -9,6 +9,8 @@
from argo_workflows.api import workflow_service_api
from argo_workflows.exceptions import NotFoundException
from argo_workflows.model.container import Container
from argo_workflows.model.security_context import SecurityContext
from argo_workflows.model.capabilities import Capabilities
from argo_workflows.model.env_var import EnvVar
from argo_workflows.model.io_argoproj_workflow_v1alpha1_template import (
IoArgoprojWorkflowV1alpha1Template,
@@ -174,6 +176,10 @@ def submit_workflow(
image_pull_policy=get_pull_policy(runner_image),
command=["pctasks"],
env=env,
security_context=SecurityContext(
# Enables tools like py-spy for debugging
capabilities=Capabilities(add=["SYS_PTRACE"])
),
args=[
"-v",
"run",
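The `SYS_PTRACE` capability added above corresponds to the standard Kubernetes container `securityContext`; profilers such as py-spy (`py-spy dump --pid <pid>`) need ptrace permission to attach to another process in the container. A rough sketch of the equivalent pod-spec fragment as plain data, with hypothetical container and image names and field names following the Kubernetes API rather than pctasks:

```python
# Illustrative only: the equivalent container securityContext as a plain dict.
container_spec = {
    "name": "pctasks-runner",          # hypothetical container name
    "image": "pctasks-runner:latest",  # hypothetical image
    "securityContext": {
        "capabilities": {
            # Grants ptrace so debugging tools like py-spy can attach.
            "add": ["SYS_PTRACE"],
        },
    },
}

print(container_spec["securityContext"]["capabilities"]["add"])
```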
1 change: 1 addition & 0 deletions pctasks/run/pctasks/run/settings.py
@@ -52,6 +52,7 @@ class RunSettings(PCTasksSettings):
def section_name(cls) -> str:
return "run"

max_concurrent_workflow_tasks: int = 120
remote_runner_threads: int = 50
default_task_wait_seconds: int = 60
max_wait_retries: int = 10
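`max_concurrent_workflow_tasks` adds an upper bound on in-flight workflow tasks alongside the existing `remote_runner_threads` setting. How the executor enforces the bound is not shown in this diff; a sketch of one common way such a cap can be applied, using an `asyncio.Semaphore`:

```python
import asyncio

max_concurrent_workflow_tasks = 120  # mirrors the new default in RunSettings


async def submit_task(task_id: int, limit: asyncio.Semaphore) -> None:
    async with limit:  # at most `max_concurrent_workflow_tasks` run at once
        await asyncio.sleep(0.01)  # stand-in for submitting and polling a task


async def main() -> None:
    limit = asyncio.Semaphore(max_concurrent_workflow_tasks)
    await asyncio.gather(*(submit_task(i, limit) for i in range(500)))


asyncio.run(main())
```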
4 changes: 4 additions & 0 deletions pctasks/run/pctasks/run/workflow/executor/models.py
@@ -372,6 +372,10 @@ def partition_id(self) -> str:
def run_id(self) -> str:
return self.job_part_submit_msg.run_id

@property
def has_next_task(self) -> bool:
return bool(self.task_queue)

def prepare_next_task(self, settings: RunSettings) -> None:
next_task_config = next(iter(self.task_queue), None)
if next_task_config:
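The new `has_next_task` property simply reports whether any task configs remain in the partition's queue (`bool()` of an empty collection is `False`). A toy sketch of how a driver loop might consume such a queue; everything except the `has_next_task` property body is illustrative and not taken from the repository:

```python
from collections import deque
from typing import Deque, Optional


class JobPartitionState:
    """Toy stand-in for the executor's per-partition state."""

    def __init__(self, task_configs: Deque[str]) -> None:
        self.task_queue: Deque[str] = task_configs
        self.current_task: Optional[str] = None

    @property
    def has_next_task(self) -> bool:
        # Empty deque -> False, otherwise True.
        return bool(self.task_queue)

    def pop_next_task(self) -> None:
        # Take the next task config and make it the current task.
        if self.has_next_task:
            self.current_task = self.task_queue.popleft()


state = JobPartitionState(deque(["create-splits", "create-chunks", "process-chunk"]))
while state.has_next_task:
    state.pop_next_task()
    print("submitting", state.current_task)
```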