Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add states/nodes to MarkovChain/DiGraph #237

Merged
merged 23 commits into from
Apr 12, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 87 additions & 12 deletions quantecon/graph_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@
from fractions import gcd


# Decorator for *_components properties
def annotate_nodes(func):
def new_func(self):
list_of_components = func(self)
if self.node_labels is not None:
return [self.node_labels[c] for c in list_of_components]
return list_of_components
return new_func


class DiGraph(object):
r"""
Class for a directed graph. It stores useful information about the
Expand All @@ -27,6 +37,11 @@ class DiGraph(object):
weighted : bool, optional(default=False)
Whether to treat `adj_matrix` as a weighted adjacency matrix.

node_labels : array_like(default=None)
Array_like of length n containing the labels associated with the
nodes, which must be homogeneous in type. If None, the labels
default to integers 0 through n-1.

Attributes
----------
csgraph : scipy.sparse.csr_matrix
Expand All @@ -38,16 +53,26 @@ class DiGraph(object):
num_strongly_connected_components : int
The number of the strongly connected components.

strongly_connected_components : list(ndarray(int))
strongly_connected_components_indices : list(ndarray(int))
List of numpy arrays containing the indices of the strongly
connected components.

strongly_connected_components : list(ndarray)
List of numpy arrays containing the strongly connected
components.
components, where the nodes are annotated with their labels (if
`node_labels` is not None).

num_sink_strongly_connected_components : int
The number of the sink strongly connected components.

sink_strongly_connected_components : list(ndarray(int))
sink_strongly_connected_components_indices : list(ndarray(int))
List of numpy arrays containing the indices of the sink strongly
connected components.

sink_strongly_connected_components : list(ndarray)
List of numpy arrays containing the sink strongly connected
components.
components, where the nodes are annotated with their labels (if
`node_labels` is not None).

is_aperiodic : bool
Indicate whether the digraph is aperiodic.
Expand All @@ -56,8 +81,14 @@ class DiGraph(object):
The period of the digraph. Defined only for a strongly connected
digraph.

cyclic_components : list(ndarray(int))
List of numpy arrays containing the cyclic components.
cyclic_components_indices : list(ndarray(int))
List of numpy arrays containing the indices of the cyclic
components.

cyclic_components : list(ndarray)
List of numpy arrays containing the cyclic components, where the
nodes are annotated with their labels (if `node_labels` is not
None).

References
----------
Expand All @@ -70,7 +101,7 @@ class DiGraph(object):

"""

def __init__(self, adj_matrix, weighted=False):
def __init__(self, adj_matrix, weighted=False, node_labels=None):
if weighted:
dtype = None
else:
Expand All @@ -83,6 +114,9 @@ def __init__(self, adj_matrix, weighted=False):

self.n = n # Number of nodes

# Call the setter method
self.node_labels = node_labels

self._num_scc = None
self._scc_proj = None
self._sink_scc_labels = None
Expand All @@ -95,6 +129,26 @@ def __repr__(self):
def __str__(self):
return "Directed Graph:\n - n(number of nodes): {n}".format(n=self.n)

@property
def node_labels(self):
return self._node_labels

@node_labels.setter
def node_labels(self, values):
if values is None:
self._node_labels = None
else:
values = np.asarray(values)
if (values.ndim < 1) or (values.shape[0] != self.n):
raise ValueError(
'node_labels must be an array_like of length n'
)
if np.issubdtype(values.dtype, np.object_):
raise ValueError(
'data in node_labels must be homogeneous in type'
)
self._node_labels = values

def _find_scc(self):
"""
Set ``self._num_scc`` and ``self._scc_proj``
Expand Down Expand Up @@ -170,21 +224,31 @@ def num_sink_strongly_connected_components(self):
return len(self.sink_scc_labels)

@property
def strongly_connected_components(self):
def strongly_connected_components_indices(self):
if self.is_strongly_connected:
return [np.arange(self.n)]
else:
return [np.where(self.scc_proj == k)[0]
for k in range(self.num_strongly_connected_components)]

@property
def sink_strongly_connected_components(self):
@annotate_nodes
def strongly_connected_components(self):
return self.strongly_connected_components_indices

@property
def sink_strongly_connected_components_indices(self):
if self.is_strongly_connected:
return [np.arange(self.n)]
else:
return [np.where(self.scc_proj == k)[0]
for k in self.sink_scc_labels.tolist()]

@property
@annotate_nodes
def sink_strongly_connected_components(self):
return self.sink_strongly_connected_components_indices

def _compute_period(self):
"""
Set ``self._period`` and ``self._cyclic_components_proj``.
Expand Down Expand Up @@ -256,13 +320,18 @@ def is_aperiodic(self):
return (self.period == 1)

@property
def cyclic_components(self):
def cyclic_components_indices(self):
if self.is_aperiodic:
return [np.arange(self.n)]
else:
return [np.where(self._cyclic_components_proj == k)[0]
for k in range(self.period)]

@property
@annotate_nodes
def cyclic_components(self,):
return self.cyclic_components_indices

def subgraph(self, nodes):
"""
Return the subgraph consisting of the given nodes and edges
Expand All @@ -271,7 +340,7 @@ def subgraph(self, nodes):
Parameters
----------
nodes : array_like(int, ndim=1)
Array of nodes.
Array of node indices.

Returns
-------
Expand All @@ -282,7 +351,13 @@ def subgraph(self, nodes):
adj_matrix = self.csgraph[nodes, :][:, nodes]

weighted = True # To copy the dtype
return DiGraph(adj_matrix, weighted=weighted)

if self.node_labels is not None:
node_labels = self.node_labels[nodes]
else:
node_labels = None

return DiGraph(adj_matrix, weighted=weighted, node_labels=node_labels)


def _csr_matrix_indices(S):
Expand Down
Loading