Skip to content

Commit

Permalink
Handle report in transition
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham committed Dec 10, 2020
1 parent 1580aa5 commit 75e18ec
Showing 1 changed file with 58 additions and 42 deletions.
100 changes: 58 additions & 42 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4578,6 +4578,7 @@ def transition_released_waiting(self, key):
ts: TaskState = self.tasks[key]
dts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert ts._run_spec
Expand All @@ -4587,7 +4588,7 @@ def transition_released_waiting(self, key):
assert not any([dts.state == "forgotten" for dts in ts._dependencies])

if ts._has_lost_dependencies:
return {key: "forgotten"}, worker_msgs, False
return {key: "forgotten"}, worker_msgs, False, report_msg

ts.state = "waiting"

Expand All @@ -4598,7 +4599,7 @@ def transition_released_waiting(self, key):
if dts._exception_blame:
ts._exception_blame = dts._exception_blame
recommendations[key] = "erred"
return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg

for dts in ts._dependencies:
dep = dts._key
Expand All @@ -4618,7 +4619,7 @@ def transition_released_waiting(self, key):
self.unrunnable.add(ts)
ts.state = "no-worker"

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -4632,6 +4633,7 @@ def transition_no_worker_waiting(self, key):
ts: TaskState = self.tasks[key]
dts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert ts in self.unrunnable
Expand All @@ -4642,7 +4644,7 @@ def transition_no_worker_waiting(self, key):
self.unrunnable.remove(ts)

if ts._has_lost_dependencies:
return {key: "forgotten"}, worker_msgs, False
return {key: "forgotten"}, worker_msgs, False, report_msg

recommendations = {}

Expand All @@ -4664,7 +4666,7 @@ def transition_no_worker_waiting(self, key):
self.unrunnable.add(ts)
ts.state = "no-worker"

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand Down Expand Up @@ -4719,6 +4721,7 @@ def transition_waiting_processing(self, key):
ts: TaskState = self.tasks[key]
dts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert not ts._waiting_on
Expand All @@ -4731,7 +4734,7 @@ def transition_waiting_processing(self, key):

ws: WorkerState = self.decide_worker(ts)
if ws is None:
return {}, worker_msgs, False
return {}, worker_msgs, False, report_msg
worker = ws._address

duration = self.get_task_duration(ts)
Expand All @@ -4754,7 +4757,7 @@ def transition_waiting_processing(self, key):

worker_msgs[worker] = self.task_to_msg(ts, duration)

return {}, worker_msgs, False
return {}, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -4768,6 +4771,7 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs):
ws: WorkerState = self.workers[worker]
ts: TaskState = self.tasks[key]
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert not ts._processing_on
Expand All @@ -4783,15 +4787,14 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs):

recommendations = {}

msg = self._add_to_memory(ts, ws, recommendations, **kwargs)
self.report(msg)
report_msg = self._add_to_memory(ts, ws, recommendations, **kwargs)

if self.validate:
assert not ts._processing_on
assert not ts._waiting_on
assert ts._who_has

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -4813,6 +4816,7 @@ def transition_processing_memory(
ws: WorkerState
wws: WorkerState
worker_msgs: dict = {}
report_msg: dict = {}
try:
ts: TaskState = self.tasks[key]
assert worker
Expand All @@ -4829,7 +4833,7 @@ def transition_processing_memory(

ws = self.workers.get(worker)
if ws is None:
return {key: "released"}, worker_msgs, False
return {key: "released"}, worker_msgs, False, report_msg

if ws != ts._processing_on: # someone else has this task
logger.info(
Expand All @@ -4839,7 +4843,7 @@ def transition_processing_memory(
ws,
key,
)
return {}, worker_msgs, False
return {}, worker_msgs, False, report_msg

if startstops:
L = list()
Expand Down Expand Up @@ -4902,7 +4906,7 @@ def transition_processing_memory(
assert not ts._processing_on
assert not ts._waiting_on

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -4917,6 +4921,7 @@ def transition_memory_released(self, key, safe=False):
ts: TaskState = self.tasks[key]
dts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert not ts._waiting_on
Expand All @@ -4934,6 +4939,7 @@ def transition_memory_released(self, key, safe=False):
{ts._key: "erred"},
worker_msgs,
False,
report_msg,
) # don't try to recreate

recommendations = {}
Expand All @@ -4959,7 +4965,7 @@ def transition_memory_released(self, key, safe=False):

ts.state = "released"

self.report({"op": "lost-data", "key": key})
report_msg = {"op": "lost-data", "key": key}

if not ts._run_spec: # pure data
recommendations[key] = "forgotten"
Expand All @@ -4971,7 +4977,7 @@ def transition_memory_released(self, key, safe=False):
if self.validate:
assert not ts._waiting_on

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -4986,6 +4992,7 @@ def transition_released_erred(self, key):
dts: TaskState
failing_ts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
with log_errors(pdb=LOG_PDB):
Expand All @@ -5003,19 +5010,17 @@ def transition_released_erred(self, key):
if not dts._who_has:
recommendations[dts._key] = "erred"

self.report(
{
"op": "task-erred",
"key": key,
"exception": failing_ts._exception,
"traceback": failing_ts._traceback,
}
)
report_msg = {
"op": "task-erred",
"key": key,
"exception": failing_ts._exception,
"traceback": failing_ts._traceback,
}

ts.state = "erred"

# TODO: waiting data?
return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -5029,6 +5034,7 @@ def transition_erred_released(self, key):
ts: TaskState = self.tasks[key]
dts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
with log_errors(pdb=LOG_PDB):
Expand All @@ -5048,10 +5054,10 @@ def transition_erred_released(self, key):
if dts.state == "erred":
recommendations[dts._key] = "waiting"

self.report({"op": "task-retried", "key": key})
report_msg = {"op": "task-retried", "key": key}
ts.state = "released"

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -5064,6 +5070,7 @@ def transition_waiting_released(self, key):
try:
ts: TaskState = self.tasks[key]
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert not ts._who_has
Expand All @@ -5089,7 +5096,7 @@ def transition_waiting_released(self, key):
else:
ts._waiters.clear()

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -5103,6 +5110,7 @@ def transition_processing_released(self, key):
ts: TaskState = self.tasks[key]
dts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert ts._processing_on
Expand Down Expand Up @@ -5135,7 +5143,7 @@ def transition_processing_released(self, key):
if self.validate:
assert not ts._processing_on

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -5153,6 +5161,7 @@ def transition_processing_erred(
dts: TaskState
failing_ts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert cause or ts._exception_blame
Expand Down Expand Up @@ -5192,14 +5201,12 @@ def transition_processing_erred(

ts.state = "erred"

self.report(
{
"op": "task-erred",
"key": key,
"exception": failing_ts._exception,
"traceback": failing_ts._traceback,
}
)
report_msg = {
"op": "task-erred",
"key": key,
"exception": failing_ts._exception,
"traceback": failing_ts._traceback,
}

cs: ClientState = self.clients["fire-and-forget"]
if ts in cs._wants_what:
Expand All @@ -5208,7 +5215,7 @@ def transition_processing_erred(
if self.validate:
assert not ts._processing_on

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -5222,6 +5229,7 @@ def transition_no_worker_released(self, key):
ts: TaskState = self.tasks[key]
dts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert self.tasks[key].state == "no-worker"
Expand All @@ -5236,7 +5244,7 @@ def transition_no_worker_released(self, key):

ts._waiters.clear()

return {}, worker_msgs, False
return {}, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand Down Expand Up @@ -5302,6 +5310,7 @@ def transition_memory_forgotten(self, key):
try:
ts: TaskState = self.tasks[key]
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert ts.state == "memory"
Expand Down Expand Up @@ -5329,7 +5338,7 @@ def transition_memory_forgotten(self, key):

self.remove_key(key)

return recommendations, worker_msgs, True
return recommendations, worker_msgs, True, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -5342,6 +5351,7 @@ def transition_released_forgotten(self, key):
try:
ts: TaskState = self.tasks[key]
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert ts.state in ("released", "erred")
Expand All @@ -5365,7 +5375,7 @@ def transition_released_forgotten(self, key):

self.remove_key(key)

return recommendations, worker_msgs, True
return recommendations, worker_msgs, True, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand Down Expand Up @@ -5393,6 +5403,7 @@ def transition(self, key, finish, *args, **kwargs):
ts: TaskState
worker_msgs: dict
report_key: bool
report_msg: dict
try:
try:
ts = self.tasks[key]
Expand All @@ -5408,16 +5419,19 @@ def transition(self, key, finish, *args, **kwargs):

worker_msgs = {}
report_key = False
report_msg = {}
if (start, finish) in self._transitions:
func = self._transitions[start, finish]
recommendations, worker_msgs, report_key = func(key, *args, **kwargs)
recommendations, worker_msgs, report_key, report_msg = func(
key, *args, **kwargs
)
elif "released" not in (start, finish):
func = self._transitions["released", finish]
assert not args and not kwargs
a = self.transition(key, "released")
if key in a:
func = self._transitions["released", a[key]]
b, worker_msgs, report_key = func(key)
b, worker_msgs, report_key, report_msg = func(key)
a = a.copy()
a.update(b)
recommendations = a
Expand All @@ -5431,6 +5445,8 @@ def transition(self, key, finish, *args, **kwargs):
self.worker_send(worker, msg)
if report_key:
self.report_on_key(ts=ts)
if report_msg:
self.report(report_msg)

finish2 = ts.state
self.transition_log.append((key, start, finish2, recommendations, time()))
Expand Down

0 comments on commit 75e18ec

Please sign in to comment.