Skip to content

Commit

Permalink
Fix Azure storage for assets
Browse files Browse the repository at this point in the history
  • Loading branch information
jtnicholl-cosairus committed Jun 27, 2024
1 parent 9a0724d commit 26e4208
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions weasel/util/remote.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import shutil
import os
import sys
from wasabi import msg
from pathlib import Path
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Union, Optional, Any

if TYPE_CHECKING:
from cloudpathlib import CloudPath
Expand All @@ -20,7 +23,8 @@ def upload_file(src: Path, dest: Union[str, "CloudPath"]) -> None:
dest.parent.mkdir(parents=True)

dest = str(dest)
with smart_open.open(dest, mode="wb") as output_file:
transport_params = _transport_params(dest)
with smart_open.open(dest, mode="wb", transport_params=transport_params) as output_file:
with src.open(mode="rb") as input_file:
output_file.write(input_file.read())

Expand All @@ -40,6 +44,21 @@ def download_file(
if dest.exists() and not force:
return None
src = str(src)
with smart_open.open(src, mode="rb", compression="disable") as input_file:
transport_params = _transport_params(src)
with smart_open.open(src, mode="rb", compression="disable", transport_params=transport_params) as input_file:
with dest.open(mode="wb") as output_file:
shutil.copyfileobj(input_file, output_file)


def _transport_params(url: str) -> Optional[dict[str, Any]]:
if url.startswith("azure://"):
connection_string = os.environ.get("AZURE_STORAGE_CONNECTION_STRING")
if not connection_string:
msg.fail(
"Azure storage requires a connection string, which was not provided.",
"Assign it to the environment variable AZURE_STORAGE_CONNECTION_STRING."
)
sys.exit(1)
from azure.storage.blob import BlobServiceClient
return {"client": BlobServiceClient.from_connection_string(connection_string)}
return None

0 comments on commit 26e4208

Please sign in to comment.