Skip to content

Commit

Permalink
split classes into files
Browse files Browse the repository at this point in the history
create a context abstract class

Node as Cluster

Adding documentation

Revert node as cluster

Fix multiple out format implementation & test

CR Changes

Group=Cluster
  • Loading branch information
dan-ash committed Jul 23, 2023
1 parent b19b097 commit 6201f5e
Show file tree
Hide file tree
Showing 10 changed files with 902 additions and 820 deletions.
559 changes: 5 additions & 554 deletions diagrams/__init__.py

Large diffs are not rendered by default.

74 changes: 74 additions & 0 deletions diagrams/cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from graphviz import Digraph
from typing import Optional
from .context import Context
from .utils import setcluster, getcluster, getdiagram

class Cluster(Context):
__bgcolors = ("#E5F5FD", "#EBF3E7", "#ECE8F6", "#FDF7E3")

# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "rounded",
"labeljust": "l",
"pencolor": "#AEB6BE",
"fontname": "Sans-Serif",
"fontsize": "12",
}

# fmt: on

# FIXME:
# Cluster direction does not work now. Graphviz couldn't render
# correctly for a subgraph that has a different rank direction.
def __init__(
self,
label: str = "cluster",
direction: str = "LR",
graph_attr: Optional[dict] = None,
):
"""Cluster represents a cluster context.
:param label: Cluster label.
:param direction: Data flow direction. Default is 'left to right'.
:param graph_attr: Provide graph_attr dot config attributes.
"""
if graph_attr is None:
graph_attr = {}
self.label = label
super().__init__("cluster_" + self.label)

# Set attributes.
for k, v in self._default_graph_attrs.items():
self.dot.graph_attr[k] = v
self.dot.graph_attr["label"] = self.label

if not self._validate_direction(direction):
raise ValueError(f'"{direction}" is not a valid direction')
self.dot.graph_attr["rankdir"] = direction

# Node must be belong to a diagrams.
self._diagram = getdiagram()
if self._diagram is None:
raise EnvironmentError("Global diagrams context not set up")
self._parent = getcluster()

# Set cluster depth for distinguishing the background color
self.depth = self._parent.depth + 1 if self._parent else 0
coloridx = self.depth % len(self.__bgcolors)
self.dot.graph_attr["bgcolor"] = self.__bgcolors[coloridx]

# Merge passed in attributes
self.dot.graph_attr.update(graph_attr)

def __enter__(self):
setcluster(self)
return self

def __exit__(self, exc_type, exc_value, traceback):
if self._parent:
self._parent.subgraph(self.dot)
else:
self._diagram.subgraph(self.dot)
setcluster(self._parent)

27 changes: 27 additions & 0 deletions diagrams/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from graphviz import Digraph
from abc import ABC, abstractmethod

class Context(ABC):
__directions = ("TB", "BT", "LR", "RL")

def __init__(self, name, **kwargs):
self.name = name
self.dot = Digraph(self.name, **kwargs)

@abstractmethod
def __enter__(self):
pass

@abstractmethod
def __exit__(self, exc_type, exc_value, traceback):
pass

def _validate_direction(self, direction: str) -> bool:
return direction.upper() in self.__directions

def node(self, nodeid: str, label: str, **attrs) -> None:
"""Create a new node in the cluster."""
self.dot.node(nodeid, label=label, **attrs)

def subgraph(self, dot: Digraph) -> None:
self.dot.subgraph(dot)
149 changes: 149 additions & 0 deletions diagrams/diagram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import os
from .node import Node
from .edge import Edge
from typing import Optional
from graphviz import Digraph
from .context import Context
from .utils import setdiagram


class Diagram(Context):
__curvestyles = ("ortho", "curved")
__outformats = ("png", "jpg", "svg", "pdf", "dot")

# fmt: off
_default_graph_attrs = {
"pad": "2.0",
"splines": "ortho",
"nodesep": "0.60",
"ranksep": "0.75",
"fontname": "Sans-Serif",
"fontsize": "15",
"fontcolor": "#2D3436",
}
_default_node_attrs = {
"shape": "box",
"style": "rounded",
"fixedsize": "true",
"width": "1.4",
"height": "1.4",
"labelloc": "b",
# imagepos attribute is not backward compatible
# TODO: check graphviz version to see if "imagepos" is available >= 2.40
# https://github.com/xflr6/graphviz/blob/master/graphviz/backend.py#L248
# "imagepos": "tc",
"imagescale": "true",
"fontname": "Sans-Serif",
"fontsize": "13",
"fontcolor": "#2D3436",
}
_default_edge_attrs = {
"color": "#7B8894",
}

# fmt: on

# TODO: Label position option
# TODO: Save directory option (filename + directory?)
def __init__(
self,
name: str = "",
filename: str = "",
direction: str = "LR",
curvestyle: str = "ortho",
outformats: list = ["png"],
autolabel: bool = False,
show: bool = True,
strict: bool = False,
graph_attr: Optional[dict] = None,
node_attr: Optional[dict] = None,
edge_attr: Optional[dict] = None,
):
"""Diagram represents a global diagrams context.
:param name: Diagram name. It will be used for output filename if the
filename isn't given.
:param filename: The output filename, without the extension (.png).
If not given, it will be generated from the name.
:param direction: Data flow direction. Default is 'left to right'.
:param curvestyle: Curve bending style. One of "ortho" or "curved".
:param outformats: List of output file formats. Default is ['png'].
:param show: Open generated image after save if true, just only save otherwise.
:param graph_attr: Provide graph_attr dot config attributes.
:param node_attr: Provide node_attr dot config attributes.
:param edge_attr: Provide edge_attr dot config attributes.
:param strict: Rendering should merge multi-edges.
"""
if graph_attr is None:
graph_attr = {}
if node_attr is None:
node_attr = {}
if edge_attr is None:
edge_attr = {}

if not name and not filename:
filename = "diagrams_image"
elif not filename:
filename = "_".join(name.split()).lower()
self.filename = filename
super().__init__(name, filename=self.filename, strict=strict)

# Set attributes.
for k, v in self._default_graph_attrs.items():
self.dot.graph_attr[k] = v
self.dot.graph_attr["label"] = self.name
for k, v in self._default_node_attrs.items():
self.dot.node_attr[k] = v
for k, v in self._default_edge_attrs.items():
self.dot.edge_attr[k] = v

if not self._validate_direction(direction):
raise ValueError(f'"{direction}" is not a valid direction')
self.dot.graph_attr["rankdir"] = direction

if not self._validate_curvestyle(curvestyle):
raise ValueError(f'"{curvestyle}" is not a valid curvestyle')
self.dot.graph_attr["splines"] = curvestyle

for outformat in outformats:
if not self._validate_outformat(outformat):
raise ValueError(f'"{outformat}" is not a valid output format')
self.outformats = outformats

# Merge passed in attributes
self.dot.graph_attr.update(graph_attr)
self.dot.node_attr.update(node_attr)
self.dot.edge_attr.update(edge_attr)

self.show = show
self.autolabel = autolabel

def __str__(self) -> str:
return str(self.dot)

def __enter__(self):
setdiagram(self)
return self

def __exit__(self, exc_type, exc_value, traceback):
self.render()
# Remove the graphviz file leaving only the image.
os.remove(self.filename)
setdiagram(None)

def _repr_png_(self):
return self.dot.pipe(format="png")

def _validate_curvestyle(self, curvestyle: str) -> bool:
return curvestyle.lower() in self.__curvestyles

def _validate_outformat(self, outformat: str) -> bool:
return outformat.lower() in self.__outformats

def connect(self, node: "Node", node2: "Node", edge: "Edge") -> None:
"""Connect the two Nodes."""
self.dot.edge(node.nodeid, node2.nodeid, **edge.attrs)

def render(self) -> None:
for outformat in self.outformats:
self.dot.render(format=outformat, view=self.show, quiet=True)
119 changes: 119 additions & 0 deletions diagrams/edge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from . import node
from typing import List, Union, Dict

class Edge:
"""Edge represents an edge between two nodes."""

_default_edge_attrs = {
"fontcolor": "#2D3436",
"fontname": "Sans-Serif",
"fontsize": "13",
}

def __init__(
self,
n: "node.Node" = None,
forward: bool = False,
reverse: bool = False,
label: str = "",
color: str = "",
style: str = "",
**attrs: Dict,
):
"""Edge represents an edge between two nodes.
:param n: Parent node.
:param forward: Points forward.
:param reverse: Points backward.
:param label: Edge label.
:param color: Edge color.
:param style: Edge style.
:param attrs: Other edge attributes
"""
if n is not None:
assert isinstance(n, node.Node)

self.node = n
self.forward = forward
self.reverse = reverse

self._attrs = {}

# Set attributes.
for k, v in self._default_edge_attrs.items():
self._attrs[k] = v

if label:
# Graphviz complaining about using label for edges, so replace it with xlabel.
# Update: xlabel option causes the misaligned label position: https://github.com/mingrammer/diagrams/issues/83
self._attrs["label"] = label
if color:
self._attrs["color"] = color
if style:
self._attrs["style"] = style
self._attrs.update(attrs)

def __sub__(self, other: Union["node.Node", "Edge", List["node.Node"]]):
"""Implement Self - Node or Edge and Self - [Nodes]"""
return self.connect(other)

def __rsub__(self, other: Union[List["node.Node"], List["Edge"]]) -> List["Edge"]:
"""Called for [Nodes] or [Edges] - Self because list don't have __sub__ operators."""
return self.append(other)

def __rshift__(self, other: Union["node.Node", "Edge", List["node.Node"]]):
"""Implements Self >> Node or Edge and Self >> [Nodes]."""
self.forward = True
return self.connect(other)

def __lshift__(self, other: Union["node.Node", "Edge", List["node.Node"]]):
"""Implements Self << Node or Edge and Self << [Nodes]."""
self.reverse = True
return self.connect(other)

def __rrshift__(self, other: Union[List["node.Node"], List["Edge"]]) -> List["Edge"]:
"""Called for [Nodes] or [Edges] >> Self because list of Edges don't have __rshift__ operators."""
return self.append(other, forward=True)

def __rlshift__(self, other: Union[List["node.Node"], List["Edge"]]) -> List["Edge"]:
"""Called for [Nodes] or [Edges] << Self because list of Edges don't have __lshift__ operators."""
return self.append(other, reverse=True)

def append(self, other: Union[List["node.Node"], List["Edge"]], forward=None, reverse=None) -> List["Edge"]:
result = []
for o in other:
if isinstance(o, Edge):
o.forward = forward if forward else o.forward
o.reverse = forward if forward else o.reverse
self._attrs = o.attrs.copy()
result.append(o)
else:
result.append(Edge(o, forward=forward, reverse=reverse, **self._attrs))
return result

def connect(self, other: Union["node.Node", "Edge", List["node.Node"]]):
if isinstance(other, list):
for n in other:
self.node.connect(n, self)
return other
elif isinstance(other, Edge):
self._attrs = other._attrs.copy()
return self
else:
if self.node is not None:
return self.node.connect(other, self)
else:
self.node = other
return self

@property
def attrs(self) -> Dict:
if self.forward and self.reverse:
direction = "both"
elif self.forward:
direction = "forward"
elif self.reverse:
direction = "back"
else:
direction = "none"
return {**self._attrs, "dir": direction}
Loading

0 comments on commit 6201f5e

Please sign in to comment.