Skip to content

Commit

Permalink
Handle report in transition
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham committed Dec 10, 2020
1 parent e42baf8 commit 9e5f2f3
Showing 1 changed file with 58 additions and 42 deletions.
100 changes: 58 additions & 42 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4592,6 +4592,7 @@ def transition_released_waiting(self, key):
ts: TaskState = tasks[key]
dts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert ts._run_spec
Expand All @@ -4601,7 +4602,7 @@ def transition_released_waiting(self, key):
assert not any([dts._state == "forgotten" for dts in ts._dependencies])

if ts._has_lost_dependencies:
return {key: "forgotten"}, worker_msgs, False
return {key: "forgotten"}, worker_msgs, False, report_msg

ts.state = "waiting"

Expand All @@ -4612,7 +4613,7 @@ def transition_released_waiting(self, key):
if dts._exception_blame:
ts._exception_blame = dts._exception_blame
recommendations[key] = "erred"
return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg

for dts in ts._dependencies:
dep = dts._key
Expand All @@ -4632,7 +4633,7 @@ def transition_released_waiting(self, key):
self.unrunnable.add(ts)
ts.state = "no-worker"

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -4647,6 +4648,7 @@ def transition_no_worker_waiting(self, key):
ts: TaskState = tasks[key]
dts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert ts in self.unrunnable
Expand All @@ -4657,7 +4659,7 @@ def transition_no_worker_waiting(self, key):
self.unrunnable.remove(ts)

if ts._has_lost_dependencies:
return {key: "forgotten"}, worker_msgs, False
return {key: "forgotten"}, worker_msgs, False, report_msg

recommendations: dict = {}

Expand All @@ -4679,7 +4681,7 @@ def transition_no_worker_waiting(self, key):
self.unrunnable.add(ts)
ts.state = "no-worker"

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand Down Expand Up @@ -4735,6 +4737,7 @@ def transition_waiting_processing(self, key):
ts: TaskState = tasks[key]
dts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert not ts._waiting_on
Expand All @@ -4747,7 +4750,7 @@ def transition_waiting_processing(self, key):

ws: WorkerState = self.decide_worker(ts)
if ws is None:
return {}, worker_msgs, False
return {}, worker_msgs, False, report_msg
worker = ws._address

duration = self.get_task_duration(ts)
Expand All @@ -4770,7 +4773,7 @@ def transition_waiting_processing(self, key):

worker_msgs[worker] = self.task_to_msg(ts, duration)

return {}, worker_msgs, False
return {}, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -4785,6 +4788,7 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs):
tasks: dict = self.tasks
ts: TaskState = tasks[key]
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert not ts._processing_on
Expand All @@ -4800,15 +4804,14 @@ def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs):

recommendations: dict = {}

msg = self._add_to_memory(ts, ws, recommendations, **kwargs)
self.report(msg)
report_msg = self._add_to_memory(ts, ws, recommendations, **kwargs)

if self.validate:
assert not ts._processing_on
assert not ts._waiting_on
assert ts._who_has

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -4830,6 +4833,7 @@ def transition_processing_memory(
ws: WorkerState
wws: WorkerState
worker_msgs: dict = {}
report_msg: dict = {}
try:
tasks: dict = self.tasks
ts: TaskState = tasks[key]
Expand All @@ -4847,7 +4851,7 @@ def transition_processing_memory(

ws = self.workers.get(worker)
if ws is None:
return {key: "released"}, worker_msgs, False
return {key: "released"}, worker_msgs, False, report_msg

if ws != ts._processing_on: # someone else has this task
logger.info(
Expand All @@ -4857,7 +4861,7 @@ def transition_processing_memory(
ws,
key,
)
return {}, worker_msgs, False
return {}, worker_msgs, False, report_msg

if startstops:
L = list()
Expand Down Expand Up @@ -4920,7 +4924,7 @@ def transition_processing_memory(
assert not ts._processing_on
assert not ts._waiting_on

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -4936,6 +4940,7 @@ def transition_memory_released(self, key, safe=False):
ts: TaskState = tasks[key]
dts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert not ts._waiting_on
Expand All @@ -4953,6 +4958,7 @@ def transition_memory_released(self, key, safe=False):
{ts._key: "erred"},
worker_msgs,
False,
report_msg,
) # don't try to recreate

recommendations: dict = {}
Expand All @@ -4978,7 +4984,7 @@ def transition_memory_released(self, key, safe=False):

ts.state = "released"

self.report({"op": "lost-data", "key": key})
report_msg = {"op": "lost-data", "key": key}

if not ts._run_spec: # pure data
recommendations[key] = "forgotten"
Expand All @@ -4990,7 +4996,7 @@ def transition_memory_released(self, key, safe=False):
if self.validate:
assert not ts._waiting_on

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -5006,6 +5012,7 @@ def transition_released_erred(self, key):
dts: TaskState
failing_ts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
with log_errors(pdb=LOG_PDB):
Expand All @@ -5023,19 +5030,17 @@ def transition_released_erred(self, key):
if not dts._who_has:
recommendations[dts._key] = "erred"

self.report(
{
"op": "task-erred",
"key": key,
"exception": failing_ts._exception,
"traceback": failing_ts._traceback,
}
)
report_msg = {
"op": "task-erred",
"key": key,
"exception": failing_ts._exception,
"traceback": failing_ts._traceback,
}

ts.state = "erred"

# TODO: waiting data?
return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -5050,6 +5055,7 @@ def transition_erred_released(self, key):
ts: TaskState = tasks[key]
dts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
with log_errors(pdb=LOG_PDB):
Expand All @@ -5069,10 +5075,10 @@ def transition_erred_released(self, key):
if dts._state == "erred":
recommendations[dts._key] = "waiting"

self.report({"op": "task-retried", "key": key})
report_msg = {"op": "task-retried", "key": key}
ts.state = "released"

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -5086,6 +5092,7 @@ def transition_waiting_released(self, key):
tasks: dict = self.tasks
ts: TaskState = tasks[key]
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert not ts._who_has
Expand All @@ -5111,7 +5118,7 @@ def transition_waiting_released(self, key):
else:
ts._waiters.clear()

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -5126,6 +5133,7 @@ def transition_processing_released(self, key):
ts: TaskState = tasks[key]
dts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert ts._processing_on
Expand Down Expand Up @@ -5158,7 +5166,7 @@ def transition_processing_released(self, key):
if self.validate:
assert not ts._processing_on

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -5177,6 +5185,7 @@ def transition_processing_erred(
dts: TaskState
failing_ts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert cause or ts._exception_blame
Expand Down Expand Up @@ -5216,14 +5225,12 @@ def transition_processing_erred(

ts.state = "erred"

self.report(
{
"op": "task-erred",
"key": key,
"exception": failing_ts._exception,
"traceback": failing_ts._traceback,
}
)
report_msg = {
"op": "task-erred",
"key": key,
"exception": failing_ts._exception,
"traceback": failing_ts._traceback,
}

cs: ClientState = self.clients["fire-and-forget"]
if ts in cs._wants_what:
Expand All @@ -5232,7 +5239,7 @@ def transition_processing_erred(
if self.validate:
assert not ts._processing_on

return recommendations, worker_msgs, False
return recommendations, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -5247,6 +5254,7 @@ def transition_no_worker_released(self, key):
ts: TaskState = tasks[key]
dts: TaskState
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert self.tasks[key].state == "no-worker"
Expand All @@ -5261,7 +5269,7 @@ def transition_no_worker_released(self, key):

ts._waiters.clear()

return {}, worker_msgs, False
return {}, worker_msgs, False, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand Down Expand Up @@ -5330,6 +5338,7 @@ def transition_memory_forgotten(self, key):
tasks = self.tasks
ts: TaskState = tasks[key]
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert ts._state == "memory"
Expand Down Expand Up @@ -5357,7 +5366,7 @@ def transition_memory_forgotten(self, key):

self.remove_key(key)

return recommendations, worker_msgs, True
return recommendations, worker_msgs, True, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -5371,6 +5380,7 @@ def transition_released_forgotten(self, key):
tasks: dict = self.tasks
ts: TaskState = tasks[key]
worker_msgs: dict = {}
report_msg: dict = {}

if self.validate:
assert ts._state in ("released", "erred")
Expand All @@ -5394,7 +5404,7 @@ def transition_released_forgotten(self, key):

self.remove_key(key)

return recommendations, worker_msgs, True
return recommendations, worker_msgs, True, report_msg
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand Down Expand Up @@ -5422,6 +5432,7 @@ def transition(self, key, finish, *args, **kwargs):
ts: TaskState
worker_msgs: dict
report_key: bool
report_msg: dict
try:
try:
ts = self.tasks[key]
Expand All @@ -5438,16 +5449,19 @@ def transition(self, key, finish, *args, **kwargs):
recommendations: dict = {}
worker_msgs = {}
report_key = False
report_msg = {}
if (start, finish) in self._transitions:
func = self._transitions[start, finish]
recommendations, worker_msgs, report_key = func(key, *args, **kwargs)
recommendations, worker_msgs, report_key, report_msg = func(
key, *args, **kwargs
)
elif "released" not in (start, finish):
func = self._transitions["released", finish]
assert not args and not kwargs
a = self.transition(key, "released")
if key in a:
func = self._transitions["released", a[key]]
b, worker_msgs, report_key = func(key)
b, worker_msgs, report_key, report_msg = func(key)
a = a.copy()
a.update(b)
recommendations = a
Expand All @@ -5461,6 +5475,8 @@ def transition(self, key, finish, *args, **kwargs):
self.worker_send(worker, msg)
if report_key:
self.report_on_key(ts=ts)
if report_msg:
self.report(report_msg)

finish2 = ts._state
self.transition_log.append((key, start, finish2, recommendations, time()))
Expand Down

0 comments on commit 9e5f2f3

Please sign in to comment.