-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
create a context abstract class Node as Cluster Adding documentation Revert node as cluster Fix multiple out format implementation & test CR Changes Group=Cluster
- Loading branch information
Showing
10 changed files
with
902 additions
and
820 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from graphviz import Digraph | ||
from typing import Optional | ||
from .context import Context | ||
from .utils import setcluster, getcluster, getdiagram | ||
|
||
class Cluster(Context): | ||
__bgcolors = ("#E5F5FD", "#EBF3E7", "#ECE8F6", "#FDF7E3") | ||
|
||
# fmt: off | ||
_default_graph_attrs = { | ||
"shape": "box", | ||
"style": "rounded", | ||
"labeljust": "l", | ||
"pencolor": "#AEB6BE", | ||
"fontname": "Sans-Serif", | ||
"fontsize": "12", | ||
} | ||
|
||
# fmt: on | ||
|
||
# FIXME: | ||
# Cluster direction does not work now. Graphviz couldn't render | ||
# correctly for a subgraph that has a different rank direction. | ||
def __init__( | ||
self, | ||
label: str = "cluster", | ||
direction: str = "LR", | ||
graph_attr: Optional[dict] = None, | ||
): | ||
"""Cluster represents a cluster context. | ||
:param label: Cluster label. | ||
:param direction: Data flow direction. Default is 'left to right'. | ||
:param graph_attr: Provide graph_attr dot config attributes. | ||
""" | ||
if graph_attr is None: | ||
graph_attr = {} | ||
self.label = label | ||
super().__init__("cluster_" + self.label) | ||
|
||
# Set attributes. | ||
for k, v in self._default_graph_attrs.items(): | ||
self.dot.graph_attr[k] = v | ||
self.dot.graph_attr["label"] = self.label | ||
|
||
if not self._validate_direction(direction): | ||
raise ValueError(f'"{direction}" is not a valid direction') | ||
self.dot.graph_attr["rankdir"] = direction | ||
|
||
# Node must be belong to a diagrams. | ||
self._diagram = getdiagram() | ||
if self._diagram is None: | ||
raise EnvironmentError("Global diagrams context not set up") | ||
self._parent = getcluster() | ||
|
||
# Set cluster depth for distinguishing the background color | ||
self.depth = self._parent.depth + 1 if self._parent else 0 | ||
coloridx = self.depth % len(self.__bgcolors) | ||
self.dot.graph_attr["bgcolor"] = self.__bgcolors[coloridx] | ||
|
||
# Merge passed in attributes | ||
self.dot.graph_attr.update(graph_attr) | ||
|
||
def __enter__(self): | ||
setcluster(self) | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_value, traceback): | ||
if self._parent: | ||
self._parent.subgraph(self.dot) | ||
else: | ||
self._diagram.subgraph(self.dot) | ||
setcluster(self._parent) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from graphviz import Digraph | ||
from abc import ABC, abstractmethod | ||
|
||
class Context(ABC): | ||
__directions = ("TB", "BT", "LR", "RL") | ||
|
||
def __init__(self, name, **kwargs): | ||
self.name = name | ||
self.dot = Digraph(self.name, **kwargs) | ||
|
||
@abstractmethod | ||
def __enter__(self): | ||
pass | ||
|
||
@abstractmethod | ||
def __exit__(self, exc_type, exc_value, traceback): | ||
pass | ||
|
||
def _validate_direction(self, direction: str) -> bool: | ||
return direction.upper() in self.__directions | ||
|
||
def node(self, nodeid: str, label: str, **attrs) -> None: | ||
"""Create a new node in the cluster.""" | ||
self.dot.node(nodeid, label=label, **attrs) | ||
|
||
def subgraph(self, dot: Digraph) -> None: | ||
self.dot.subgraph(dot) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import os | ||
from .node import Node | ||
from .edge import Edge | ||
from typing import Optional | ||
from graphviz import Digraph | ||
from .context import Context | ||
from .utils import setdiagram | ||
|
||
|
||
class Diagram(Context): | ||
__curvestyles = ("ortho", "curved") | ||
__outformats = ("png", "jpg", "svg", "pdf", "dot") | ||
|
||
# fmt: off | ||
_default_graph_attrs = { | ||
"pad": "2.0", | ||
"splines": "ortho", | ||
"nodesep": "0.60", | ||
"ranksep": "0.75", | ||
"fontname": "Sans-Serif", | ||
"fontsize": "15", | ||
"fontcolor": "#2D3436", | ||
} | ||
_default_node_attrs = { | ||
"shape": "box", | ||
"style": "rounded", | ||
"fixedsize": "true", | ||
"width": "1.4", | ||
"height": "1.4", | ||
"labelloc": "b", | ||
# imagepos attribute is not backward compatible | ||
# TODO: check graphviz version to see if "imagepos" is available >= 2.40 | ||
# https://github.com/xflr6/graphviz/blob/master/graphviz/backend.py#L248 | ||
# "imagepos": "tc", | ||
"imagescale": "true", | ||
"fontname": "Sans-Serif", | ||
"fontsize": "13", | ||
"fontcolor": "#2D3436", | ||
} | ||
_default_edge_attrs = { | ||
"color": "#7B8894", | ||
} | ||
|
||
# fmt: on | ||
|
||
# TODO: Label position option | ||
# TODO: Save directory option (filename + directory?) | ||
def __init__( | ||
self, | ||
name: str = "", | ||
filename: str = "", | ||
direction: str = "LR", | ||
curvestyle: str = "ortho", | ||
outformats: list = ["png"], | ||
autolabel: bool = False, | ||
show: bool = True, | ||
strict: bool = False, | ||
graph_attr: Optional[dict] = None, | ||
node_attr: Optional[dict] = None, | ||
edge_attr: Optional[dict] = None, | ||
): | ||
"""Diagram represents a global diagrams context. | ||
:param name: Diagram name. It will be used for output filename if the | ||
filename isn't given. | ||
:param filename: The output filename, without the extension (.png). | ||
If not given, it will be generated from the name. | ||
:param direction: Data flow direction. Default is 'left to right'. | ||
:param curvestyle: Curve bending style. One of "ortho" or "curved". | ||
:param outformats: List of output file formats. Default is ['png']. | ||
:param show: Open generated image after save if true, just only save otherwise. | ||
:param graph_attr: Provide graph_attr dot config attributes. | ||
:param node_attr: Provide node_attr dot config attributes. | ||
:param edge_attr: Provide edge_attr dot config attributes. | ||
:param strict: Rendering should merge multi-edges. | ||
""" | ||
if graph_attr is None: | ||
graph_attr = {} | ||
if node_attr is None: | ||
node_attr = {} | ||
if edge_attr is None: | ||
edge_attr = {} | ||
|
||
if not name and not filename: | ||
filename = "diagrams_image" | ||
elif not filename: | ||
filename = "_".join(name.split()).lower() | ||
self.filename = filename | ||
super().__init__(name, filename=self.filename, strict=strict) | ||
|
||
# Set attributes. | ||
for k, v in self._default_graph_attrs.items(): | ||
self.dot.graph_attr[k] = v | ||
self.dot.graph_attr["label"] = self.name | ||
for k, v in self._default_node_attrs.items(): | ||
self.dot.node_attr[k] = v | ||
for k, v in self._default_edge_attrs.items(): | ||
self.dot.edge_attr[k] = v | ||
|
||
if not self._validate_direction(direction): | ||
raise ValueError(f'"{direction}" is not a valid direction') | ||
self.dot.graph_attr["rankdir"] = direction | ||
|
||
if not self._validate_curvestyle(curvestyle): | ||
raise ValueError(f'"{curvestyle}" is not a valid curvestyle') | ||
self.dot.graph_attr["splines"] = curvestyle | ||
|
||
for outformat in outformats: | ||
if not self._validate_outformat(outformat): | ||
raise ValueError(f'"{outformat}" is not a valid output format') | ||
self.outformats = outformats | ||
|
||
# Merge passed in attributes | ||
self.dot.graph_attr.update(graph_attr) | ||
self.dot.node_attr.update(node_attr) | ||
self.dot.edge_attr.update(edge_attr) | ||
|
||
self.show = show | ||
self.autolabel = autolabel | ||
|
||
def __str__(self) -> str: | ||
return str(self.dot) | ||
|
||
def __enter__(self): | ||
setdiagram(self) | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_value, traceback): | ||
self.render() | ||
# Remove the graphviz file leaving only the image. | ||
os.remove(self.filename) | ||
setdiagram(None) | ||
|
||
def _repr_png_(self): | ||
return self.dot.pipe(format="png") | ||
|
||
def _validate_curvestyle(self, curvestyle: str) -> bool: | ||
return curvestyle.lower() in self.__curvestyles | ||
|
||
def _validate_outformat(self, outformat: str) -> bool: | ||
return outformat.lower() in self.__outformats | ||
|
||
def connect(self, node: "Node", node2: "Node", edge: "Edge") -> None: | ||
"""Connect the two Nodes.""" | ||
self.dot.edge(node.nodeid, node2.nodeid, **edge.attrs) | ||
|
||
def render(self) -> None: | ||
for outformat in self.outformats: | ||
self.dot.render(format=outformat, view=self.show, quiet=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
from . import node | ||
from typing import List, Union, Dict | ||
|
||
class Edge: | ||
"""Edge represents an edge between two nodes.""" | ||
|
||
_default_edge_attrs = { | ||
"fontcolor": "#2D3436", | ||
"fontname": "Sans-Serif", | ||
"fontsize": "13", | ||
} | ||
|
||
def __init__( | ||
self, | ||
n: "node.Node" = None, | ||
forward: bool = False, | ||
reverse: bool = False, | ||
label: str = "", | ||
color: str = "", | ||
style: str = "", | ||
**attrs: Dict, | ||
): | ||
"""Edge represents an edge between two nodes. | ||
:param n: Parent node. | ||
:param forward: Points forward. | ||
:param reverse: Points backward. | ||
:param label: Edge label. | ||
:param color: Edge color. | ||
:param style: Edge style. | ||
:param attrs: Other edge attributes | ||
""" | ||
if n is not None: | ||
assert isinstance(n, node.Node) | ||
|
||
self.node = n | ||
self.forward = forward | ||
self.reverse = reverse | ||
|
||
self._attrs = {} | ||
|
||
# Set attributes. | ||
for k, v in self._default_edge_attrs.items(): | ||
self._attrs[k] = v | ||
|
||
if label: | ||
# Graphviz complaining about using label for edges, so replace it with xlabel. | ||
# Update: xlabel option causes the misaligned label position: https://github.com/mingrammer/diagrams/issues/83 | ||
self._attrs["label"] = label | ||
if color: | ||
self._attrs["color"] = color | ||
if style: | ||
self._attrs["style"] = style | ||
self._attrs.update(attrs) | ||
|
||
def __sub__(self, other: Union["node.Node", "Edge", List["node.Node"]]): | ||
"""Implement Self - Node or Edge and Self - [Nodes]""" | ||
return self.connect(other) | ||
|
||
def __rsub__(self, other: Union[List["node.Node"], List["Edge"]]) -> List["Edge"]: | ||
"""Called for [Nodes] or [Edges] - Self because list don't have __sub__ operators.""" | ||
return self.append(other) | ||
|
||
def __rshift__(self, other: Union["node.Node", "Edge", List["node.Node"]]): | ||
"""Implements Self >> Node or Edge and Self >> [Nodes].""" | ||
self.forward = True | ||
return self.connect(other) | ||
|
||
def __lshift__(self, other: Union["node.Node", "Edge", List["node.Node"]]): | ||
"""Implements Self << Node or Edge and Self << [Nodes].""" | ||
self.reverse = True | ||
return self.connect(other) | ||
|
||
def __rrshift__(self, other: Union[List["node.Node"], List["Edge"]]) -> List["Edge"]: | ||
"""Called for [Nodes] or [Edges] >> Self because list of Edges don't have __rshift__ operators.""" | ||
return self.append(other, forward=True) | ||
|
||
def __rlshift__(self, other: Union[List["node.Node"], List["Edge"]]) -> List["Edge"]: | ||
"""Called for [Nodes] or [Edges] << Self because list of Edges don't have __lshift__ operators.""" | ||
return self.append(other, reverse=True) | ||
|
||
def append(self, other: Union[List["node.Node"], List["Edge"]], forward=None, reverse=None) -> List["Edge"]: | ||
result = [] | ||
for o in other: | ||
if isinstance(o, Edge): | ||
o.forward = forward if forward else o.forward | ||
o.reverse = forward if forward else o.reverse | ||
self._attrs = o.attrs.copy() | ||
result.append(o) | ||
else: | ||
result.append(Edge(o, forward=forward, reverse=reverse, **self._attrs)) | ||
return result | ||
|
||
def connect(self, other: Union["node.Node", "Edge", List["node.Node"]]): | ||
if isinstance(other, list): | ||
for n in other: | ||
self.node.connect(n, self) | ||
return other | ||
elif isinstance(other, Edge): | ||
self._attrs = other._attrs.copy() | ||
return self | ||
else: | ||
if self.node is not None: | ||
return self.node.connect(other, self) | ||
else: | ||
self.node = other | ||
return self | ||
|
||
@property | ||
def attrs(self) -> Dict: | ||
if self.forward and self.reverse: | ||
direction = "both" | ||
elif self.forward: | ||
direction = "forward" | ||
elif self.reverse: | ||
direction = "back" | ||
else: | ||
direction = "none" | ||
return {**self._attrs, "dir": direction} |
Oops, something went wrong.