Skip to content

Commit

Permalink
Merge pull request #1050 from ttngu207/populate_success_count
Browse files Browse the repository at this point in the history
Returning success count from the `.populate()` call
  • Loading branch information
dimitri-yatsenko authored Oct 9, 2023
2 parents 2a11279 + 18fd619 commit 10511e7
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 79 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Changed - Migrate docs from `https://docs.datajoint.org/python` to `https://datajoint.com/docs/core/datajoint-python`
- Fixed - Updated set_password to work on MySQL 8 - PR [#1106](https://github.com/datajoint/datajoint-python/pull/1106)
- Added - Missing tests for set_password - PR [#1106](https://github.com/datajoint/datajoint-python/pull/1106)
- Changed - Returning success count after the .populate() call - PR [#1050](https://github.com/datajoint/datajoint-python/pull/1050)

### 0.14.1 -- Jun 02, 2023
- Fixed - Fix altering a part table that uses the "master" keyword - PR [#991](https://github.com/datajoint/datajoint-python/pull/991)
Expand Down
179 changes: 100 additions & 79 deletions datajoint/autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ def populate(
to be passed down to each ``make()`` call. Computation arguments should be
specified within the pipeline e.g. using a `dj.Lookup` table.
:type make_kwargs: dict, optional
:return: a dict with two keys
"success_count": the count of successful ``make()`` calls in this ``populate()`` call
"error_list": the error list that is filled if `suppress_errors` is True
"""
if self.connection.in_transaction:
raise DataJointError("Populate cannot be called during a transaction.")
Expand Down Expand Up @@ -222,49 +225,62 @@ def handler(signum, frame):

keys = keys[:max_calls]
nkeys = len(keys)
if not nkeys:
return

processes = min(_ for _ in (processes, nkeys, mp.cpu_count()) if _)

error_list = []
populate_kwargs = dict(
suppress_errors=suppress_errors,
return_exception_objects=return_exception_objects,
make_kwargs=make_kwargs,
)
success_list = []

if processes == 1:
for key in (
tqdm(keys, desc=self.__class__.__name__) if display_progress else keys
):
error = self._populate1(key, jobs, **populate_kwargs)
if error is not None:
error_list.append(error)
else:
# spawn multiple processes
self.connection.close() # disconnect parent process from MySQL server
del self.connection._conn.ctx # SSLContext is not pickleable
with mp.Pool(
processes, _initialize_populate, (self, jobs, populate_kwargs)
) as pool, (
tqdm(desc="Processes: ", total=nkeys)
if display_progress
else contextlib.nullcontext()
) as progress_bar:
for error in pool.imap(_call_populate1, keys, chunksize=1):
if error is not None:
error_list.append(error)
if display_progress:
progress_bar.update()
self.connection.connect() # reconnect parent process to MySQL server
if nkeys:
processes = min(_ for _ in (processes, nkeys, mp.cpu_count()) if _)

populate_kwargs = dict(
suppress_errors=suppress_errors,
return_exception_objects=return_exception_objects,
make_kwargs=make_kwargs,
)

if processes == 1:
for key in (
tqdm(keys, desc=self.__class__.__name__)
if display_progress
else keys
):
status = self._populate1(key, jobs, **populate_kwargs)
if status is True:
success_list.append(1)
elif isinstance(status, tuple):
error_list.append(status)
else:
assert status is False
else:
# spawn multiple processes
self.connection.close() # disconnect parent process from MySQL server
del self.connection._conn.ctx # SSLContext is not pickleable
with mp.Pool(
processes, _initialize_populate, (self, jobs, populate_kwargs)
) as pool, (
tqdm(desc="Processes: ", total=nkeys)
if display_progress
else contextlib.nullcontext()
) as progress_bar:
for status in pool.imap(_call_populate1, keys, chunksize=1):
if status is True:
success_list.append(1)
elif isinstance(status, tuple):
error_list.append(status)
else:
assert status is False
if display_progress:
progress_bar.update()
self.connection.connect() # reconnect parent process to MySQL server

# restore original signal handler:
if reserve_jobs:
signal.signal(signal.SIGTERM, old_handler)

if suppress_errors:
return error_list
return {
"success_count": sum(success_list),
"error_list": error_list,
}

def _populate1(
self, key, jobs, suppress_errors, return_exception_objects, make_kwargs=None
Expand All @@ -275,55 +291,60 @@ def _populate1(
:param key: dict specifying job to populate
:param suppress_errors: bool if errors should be suppressed and returned
:param return_exception_objects: if True, errors must be returned as objects
:return: (key, error) when suppress_errors=True, otherwise None
:return: (key, error) when suppress_errors=True,
True if successfully invoke one `make()` call, otherwise False
"""
make = self._make_tuples if hasattr(self, "_make_tuples") else self.make

if jobs is None or jobs.reserve(self.target.table_name, self._job_key(key)):
self.connection.start_transaction()
if key in self.target: # already populated
if jobs is not None and not jobs.reserve(
self.target.table_name, self._job_key(key)
):
return False

self.connection.start_transaction()
if key in self.target: # already populated
self.connection.cancel_transaction()
if jobs is not None:
jobs.complete(self.target.table_name, self._job_key(key))
return False

logger.debug(f"Making {key} -> {self.target.full_table_name}")
self.__class__._allow_insert = True
try:
make(dict(key), **(make_kwargs or {}))
except (KeyboardInterrupt, SystemExit, Exception) as error:
try:
self.connection.cancel_transaction()
if jobs is not None:
jobs.complete(self.target.table_name, self._job_key(key))
except LostConnectionError:
pass
error_message = "{exception}{msg}".format(
exception=error.__class__.__name__,
msg=": " + str(error) if str(error) else "",
)
logger.debug(
f"Error making {key} -> {self.target.full_table_name} - {error_message}"
)
if jobs is not None:
# show error name and error message (if any)
jobs.error(
self.target.table_name,
self._job_key(key),
error_message=error_message,
error_stack=traceback.format_exc(),
)
if not suppress_errors or isinstance(error, SystemExit):
raise
else:
logger.debug(f"Making {key} -> {self.target.full_table_name}")
self.__class__._allow_insert = True
try:
make(dict(key), **(make_kwargs or {}))
except (KeyboardInterrupt, SystemExit, Exception) as error:
try:
self.connection.cancel_transaction()
except LostConnectionError:
pass
error_message = "{exception}{msg}".format(
exception=error.__class__.__name__,
msg=": " + str(error) if str(error) else "",
)
logger.debug(
f"Error making {key} -> {self.target.full_table_name} - {error_message}"
)
if jobs is not None:
# show error name and error message (if any)
jobs.error(
self.target.table_name,
self._job_key(key),
error_message=error_message,
error_stack=traceback.format_exc(),
)
if not suppress_errors or isinstance(error, SystemExit):
raise
else:
logger.error(error)
return key, error if return_exception_objects else error_message
else:
self.connection.commit_transaction()
logger.debug(
f"Success making {key} -> {self.target.full_table_name}"
)
if jobs is not None:
jobs.complete(self.target.table_name, self._job_key(key))
finally:
self.__class__._allow_insert = False
logger.error(error)
return key, error if return_exception_objects else error_message
else:
self.connection.commit_transaction()
logger.debug(f"Success making {key} -> {self.target.full_table_name}")
if jobs is not None:
jobs.complete(self.target.table_name, self._job_key(key))
return True
finally:
self.__class__._allow_insert = False

def progress(self, *restrictions, display=False):
"""
Expand Down
17 changes: 17 additions & 0 deletions tests_old/test_autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,23 @@ def test_populate(self):
assert_true(self.ephys)
assert_true(self.channel)

def test_populate_with_success_count(self):
# test simple populate
assert_true(self.subject, "root tables are empty")
assert_false(self.experiment, "table already filled?")
ret = self.experiment.populate()
success_count = ret["success_count"]
assert_equal(len(self.experiment.key_source & self.experiment), success_count)

# test restricted populate
assert_false(self.trial, "table already filled?")
restriction = self.subject.proj(animal="subject_id").fetch("KEY")[0]
d = self.trial.connection.dependencies
d.load()
ret = self.trial.populate(restriction, suppress_errors=True)
success_count = ret["success_count"]
assert_equal(len(self.trial.key_source & self.trial), success_count)

def test_populate_exclude_error_and_ignore_jobs(self):
# test simple populate
assert_true(self.subject, "root tables are empty")
Expand Down

0 comments on commit 10511e7

Please sign in to comment.