Skip to content

Commit

Permalink
Add support for colors.
Browse files Browse the repository at this point in the history
  • Loading branch information
syamajala committed Sep 24, 2018
1 parent 87edead commit 3d8b6dd
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 10 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,5 @@ docs/_build/

# PyBuilder
target/

.pytest_cache
5 changes: 3 additions & 2 deletions graphkit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(self, **kwargs):
self.needs = kwargs.get('needs')
self.provides = kwargs.get('provides')
self.params = kwargs.get('params', {})
self.color = kwargs.get('color', None)

# call _after_init as final step of initialization
self._after_init()
Expand Down Expand Up @@ -151,8 +152,8 @@ def __init__(self, **kwargs):
self.net = kwargs.pop('net')
Operation.__init__(self, **kwargs)

def _compute(self, named_inputs, outputs=None):
return self.net.compute(outputs, named_inputs)
def _compute(self, named_inputs, outputs=None, color=None):
return self.net.compute(outputs, named_inputs, color)

def __call__(self, *args, **kwargs):
return self._compute(*args, **kwargs)
Expand Down
6 changes: 6 additions & 0 deletions graphkit/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ class operation(Operation):
A dict of key/value pairs representing constant parameters
associated with your operation. These can correspond to either
``args`` or ``kwargs`` of ``fn`.
:param str color:
A color for the node in the computation graph.
"""

def __init__(self, fn=None, **kwargs):
Expand All @@ -93,6 +96,9 @@ def _normalize_kwargs(self, kwargs):
if type(kwargs['params']) is not dict:
kwargs['params'] = {}

if 'color' in kwargs and type(kwargs['color']) == str:
assert kwargs['color'], "empty string provided for `color` parameters"

return kwargs

def __call__(self, fn=None, **kwargs):
Expand Down
33 changes: 25 additions & 8 deletions graphkit/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def add_op(self, operation):
for n in operation.needs:
self.graph.add_edge(DataPlaceholderNode(n), operation)

if operation.color:
self.graph.nodes[operation]['color'] = operation.color

# add nodes and edges to graph describing what this layer provides
for p in operation.provides:
self.graph.add_edge(operation, DataPlaceholderNode(p))
Expand All @@ -93,6 +96,7 @@ def show_layers(self):
print("layer_name: ", name)
print("\t", "needs: ", step.needs)
print("\t", "provides: ", step.provides)
print("\t", "color: ", step.color)
print("")

def compile(self):
Expand Down Expand Up @@ -136,7 +140,7 @@ def compile(self):
else:
raise TypeError("Unrecognized network graph node")

def _find_necessary_steps(self, outputs, inputs):
def _find_necessary_steps(self, outputs, inputs, color=None):
"""
Determines what graph steps need to be run to get to the requested
outputs from the provided inputs. Eliminates steps that come before
Expand All @@ -152,6 +156,9 @@ def _find_necessary_steps(self, outputs, inputs):
:param dict inputs:
A dictionary mapping names to values for all provided inputs.
:param str color:
A color to filter nodes by.
:returns:
Returns a list of all the steps that need to be run for the
provided inputs and requested outputs.
Expand All @@ -160,7 +167,7 @@ def _find_necessary_steps(self, outputs, inputs):
# return steps if it has already been computed before for this set of inputs and outputs
outputs = tuple(sorted(outputs)) if isinstance(outputs, (list, set)) else outputs
inputs_keys = tuple(sorted(inputs.keys()))
cache_key = (inputs_keys, outputs)
cache_key = (inputs_keys, outputs, color)
if cache_key in self._necessary_steps_cache:
return self._necessary_steps_cache[cache_key]

Expand Down Expand Up @@ -199,15 +206,23 @@ def _find_necessary_steps(self, outputs, inputs):
# Get rid of the unnecessary nodes from the set of necessary ones.
necessary_nodes -= unnecessary_nodes

necessary_steps = [step for step in self.steps if step in necessary_nodes]
necessary_steps = []

for step in self.steps:
if isinstance(step, Operation):
if step.color == color and step in necessary_nodes:
necessary_steps.append(step)
else:
if step in necessary_nodes:
necessary_steps.append(step)

# save this result in a precomputed cache for future lookup
self._necessary_steps_cache[cache_key] = necessary_steps

# Return an ordered list of the needed steps.
return necessary_steps

def compute(self, outputs, named_inputs):
def compute(self, outputs, named_inputs, color=None):
"""
This method runs the graph one operation at a time in a single thread
Any inputs to the network must be passed in by name.
Expand All @@ -222,6 +237,8 @@ def compute(self, outputs, named_inputs):
and the values are the concrete values you
want to set for the data node.
:param str color: Only the subgraph of nodes with color will be evaluted.
:returns: a dictionary of output data objects, keyed by name.
"""

Expand All @@ -238,7 +255,7 @@ def compute(self, outputs, named_inputs):

# Find the subset of steps we need to run to get to the requested
# outputs from the provided inputs.
all_steps = self._find_necessary_steps(outputs, named_inputs)
all_steps = self._find_necessary_steps(outputs, named_inputs, color)

self.times = {}
for step in all_steps:
Expand Down Expand Up @@ -281,9 +298,9 @@ def compute(self, outputs, named_inputs):
raise TypeError("Unrecognized instruction.")

if not outputs:
# Return the whole cache as output, including input and
# intermediate data nodes.
return cache
# Return cache as output including intermediate data nodes,
# but excluding input.
return {k: cache[k] for k in set(cache) - set(named_inputs)}

else:
# Filter outputs to just return what's needed.
Expand Down

0 comments on commit 3d8b6dd

Please sign in to comment.