Skip to content

Commit

Permalink
Update mp for fast pull (#201)
Browse files (browse the repository at this point in the history)
* tests with threads work
* use 'asyncio_aws_s3' internally, fixed tests, fixed bugs
* isort/black
  • Loading branch information
kyocum authored Apr 23, 2024
1 parent caaee36 commit 32d8b9e
Show file tree
Hide file tree
Showing 10 changed files with 267 additions and 167 deletions.
2 changes: 1 addition & 1 deletion disdat/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from disdat import logger as _logger
from disdat.data_context import DataContext
from disdat.hyperframe import HyperFrameRecord, LineageRecord, parse_return_val
from disdat.utility.aws_s3 import s3_path_exists
from disdat.utility.asyncio_aws_s3 import s3_path_exists

PROC_ID_TRUNCATE_HASH = 10 # 10 ls hex digits

Expand Down
4 changes: 2 additions & 2 deletions disdat/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
# limitations under the License.
#

import configparser
import importlib
import logging
import os
import shutil
import sys
import urllib
import uuid

from six.moves import configparser, urllib

import disdat.config
from disdat import logger as _logger
from disdat import resource
Expand Down
2 changes: 1 addition & 1 deletion disdat/data_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import disdat.constants as constants
import disdat.hyperframe as hyperframe
import disdat.hyperframe_pb2 as hyperframe_pb2
import disdat.utility.aws_s3 as aws_s3
import disdat.utility.asyncio_aws_s3 as aws_s3
from disdat import logger as _logger
from disdat.common import DisdatConfig

Expand Down
50 changes: 24 additions & 26 deletions disdat/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

import disdat.common as common
import disdat.hyperframe as hyperframe
import disdat.utility.aws_s3 as aws_s3
import disdat.utility.asyncio_aws_s3 as aws_s3
from disdat import logger as _logger
from disdat.common import CatNoBundleError, DisdatConfig
from disdat.data_context import DataContext
Expand Down Expand Up @@ -84,15 +84,14 @@ def _run_git_cmd(git_dir, git_cmd, get_output=False):

return output


def determine_pipe_version(pipe_root):
"""
def determine_pipe_version(pipe_root):
Given a pipe file path, return the repo status. If they are set, use the environment variables,
otherwise run the git commands.
Args:
pipe_root: path to the root of the pipeline
Returns:
CodeVersion: populated object with the git hash, branch, fetch url, last updated date
and "dirty" status. A pipeline is considered to be dirty if there are modified files
Expand Down Expand Up @@ -551,7 +550,7 @@ def ensure_data_context(self, data_context):
if data_context is None:
data_context = self.curr_context
if data_context is None:
print("No current context. `dsdt switch <othercontext>`")
_logger.info("No current context. `dsdt switch <othercontext>`")
return None
return data_context

Expand Down Expand Up @@ -728,7 +727,7 @@ def cat(self, human_name, uuid=None, tags=None, file=None, data_context=None):
other = data_context.present_hfr(hfr)
if file is not None:
df = data_context.convert_hfr2df(hfr)
print("Saving to file {}".format(file))
_logger.info("Saving to file {}".format(file))
df.to_csv(file, sep=",", index=False)
return other
else:
Expand Down Expand Up @@ -899,7 +898,7 @@ def delete_context(self, fq_context_name, remote, force):
if self.curr_context is not None and (
fq_context_name == self.curr_context_name
):
print(
_logger.info(
"Disdat deleting the current context {}, remember to 'dsdt switch <otherbranch>' afterwords!".format(
fq_context_name
)
Expand Down Expand Up @@ -983,9 +982,9 @@ def switch(self, local_context_name):

if new_context is not None:
self.curr_context = new_context
print("Switched to context {}".format(self.curr_context_name))
_logger.info("Switched to context {}".format(self.curr_context_name))
else:
print("In context {}".format(self.curr_context_name))
_logger.info("In context {}".format(self.curr_context_name))

def commit(self, bundle_name, input_tags, uuid=None, data_context=None):
"""Commit indicates that this is a primary version of this bundle.
Expand Down Expand Up @@ -1025,13 +1024,13 @@ def commit(self, bundle_name, input_tags, uuid=None, data_context=None):
data_context=data_context,
)
else:
print(
_logger.warn(
"Push requires either a human name or a uuid to identify the hyperframe."
)
return None

if hfr is None:
print(
_logger.info(
"No bundle with human name [{}] or uuid [{}] found.".format(
bundle_name, uuid
)
Expand All @@ -1040,7 +1039,7 @@ def commit(self, bundle_name, input_tags, uuid=None, data_context=None):

commit_tag = hfr.get_tag("committed")
if commit_tag is not None and commit_tag == "True":
print(
_logger.info(
"Bundle human name [{}] uuid [{}] already committed.".format(
hfr.pb.human_name, hfr.pb.uuid
)
Expand Down Expand Up @@ -1222,7 +1221,7 @@ def push(
return None

if data_context.remote_ctxt_url is None:
print(
_logger.info(
"Push cannot execute. Local context {} on remote {} not bound.".format(
data_context.local_ctxt, data_context.remote_ctxt
)
Expand All @@ -1245,13 +1244,13 @@ def push(
human_name, tags=tags, data_context=data_context
)
else:
print(
_logger.info(
"Push requires either a human name or a uuid to identify the hyperframe."
)
return None

if hfr is None:
print(
_logger.info(
"Push unable to find committed bundle name [{}] uuid [{}]".format(
human_name, uuid
)
Expand All @@ -1267,21 +1266,21 @@ def push(
if not hyperframe.is_hyperframe_pb_file(src):
to_delete.append(urllib.parse.urlparse(src).path)
except Exception as e:
print("Push unable to copy bundle to branch: {}".format(e))
_logger.warn("Push unable to copy bundle to branch: {}".format(e))
return None

if delocalize:
for f in to_delete:
try:
os.remove(f)
except IOError as e:
print(
_logger.warn(
"fast_push: during delocalization, unable to remove {} due to {}".format(
f, e
)
)

print(
_logger.info(
"Pushed committed bundle {} uuid {} to remote {}".format(
human_name, hfr.pb.uuid, data_context.remote_ctxt_url
)
Expand Down Expand Up @@ -1356,7 +1355,7 @@ def fast_push(self, data_context, delocalize):
try:
os.remove(f)
except IOError as e:
print(
_logger.warn(
"fast_push: during delocalization, unable to remove {} due to {}".format(
f, e
)
Expand Down Expand Up @@ -1453,13 +1452,12 @@ def pull(self, human_name=None, uuid=None, localize=False, data_context=None):
Raise:
UserWarning: If we are not in a valid context.
"""

data_context = self.ensure_data_context(data_context)
if data_context is None:
return

if data_context.remote_ctxt_url is None:
print(
_logger.error(
"Pull cannot execute. Local context {} on remote {} not bound.".format(
data_context.local_ctxt, data_context.remote_ctxt
)
Expand Down Expand Up @@ -1506,7 +1504,7 @@ def pull(self, human_name=None, uuid=None, localize=False, data_context=None):
if s3_uuid != uuid:
continue
else:
print(
_logger.info(
"Found remote bundle with UUID {}, checking local context for duplicates ...".format(
uuid
)
Expand All @@ -1525,7 +1523,7 @@ def pull(self, human_name=None, uuid=None, localize=False, data_context=None):
if human_name != local_hfr.pb.human_name:
continue
else:
print(
_logger.info(
"Found remote bundle with human name {}, uuid {} localizing ...".format(
local_hfr.pb.human_name, local_hfr.pb.uuid
)
Expand All @@ -1543,7 +1541,7 @@ def pull(self, human_name=None, uuid=None, localize=False, data_context=None):
if human_name != hfr_test.pb.human_name:
continue
else:
print(
_logger.info(
"Found remote bundle with human name {}, uuid {} ...".format(
hfr_test.pb.human_name, hfr_test.pb.uuid
)
Expand All @@ -1556,7 +1554,7 @@ def pull(self, human_name=None, uuid=None, localize=False, data_context=None):
local_uuid_dir = os.path.join(data_context.get_object_dir(), s3_uuid)
local_hfr_path = os.path.join(local_uuid_dir, hfr_basename)
if os.path.exists(local_uuid_dir):
print(
_logger.info(
"Pull found existing data in local disdat db at UUID {}, overwriting . . .".format(
s3_uuid
)
Expand Down Expand Up @@ -1687,7 +1685,7 @@ def _parse_date(date_string, throw=False):
else:
date = datetime.strptime(date_string, "%m-%d-%Y")
except ValueError as ve:
print(
_logger.info(
"Unable to parse date, must be like '12-1-2008' or '\"12-1-2008 13:12:05\"'"
)
if not throw:
Expand Down
2 changes: 1 addition & 1 deletion disdat/hyperframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
import disdat.common as common
from disdat import hyperframe_pb2
from disdat import logger as _logger
from disdat.utility.aws_s3 import s3_path_exists
from disdat.utility.asyncio_aws_s3 import s3_path_exists

HyperFrameTuple = namedtuple("HyperFrameTuple", "columns, links, uuid, tags")

Expand Down
Loading

0 comments on commit 32d8b9e

Please sign in to comment.