Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some reorg in order to help and allow others to contribute more #439

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
559 changes: 5 additions & 554 deletions diagrams/__init__.py

Large diffs are not rendered by default.

74 changes: 74 additions & 0 deletions diagrams/cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from graphviz import Digraph
from typing import Optional
from .context import Context
from .utils import setcluster, getcluster, getdiagram

class Cluster(Context):
__bgcolors = ("#E5F5FD", "#EBF3E7", "#ECE8F6", "#FDF7E3")

# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "rounded",
"labeljust": "l",
"pencolor": "#AEB6BE",
"fontname": "Sans-Serif",
"fontsize": "12",
}

# fmt: on

# FIXME:
# Cluster direction does not work now. Graphviz couldn't render
# correctly for a subgraph that has a different rank direction.
def __init__(
self,
label: str = "cluster",
direction: str = "LR",
graph_attr: Optional[dict] = None,
):
"""Cluster represents a cluster context.

:param label: Cluster label.
:param direction: Data flow direction. Default is 'left to right'.
:param graph_attr: Provide graph_attr dot config attributes.
"""
if graph_attr is None:
graph_attr = {}
self.label = label
super().__init__("cluster_" + self.label)

# Set attributes.
for k, v in self._default_graph_attrs.items():
self.dot.graph_attr[k] = v
self.dot.graph_attr["label"] = self.label

if not self._validate_direction(direction):
raise ValueError(f'"{direction}" is not a valid direction')
self.dot.graph_attr["rankdir"] = direction

# Node must be belong to a diagrams.
self._diagram = getdiagram()
if self._diagram is None:
raise EnvironmentError("Global diagrams context not set up")
self._parent = getcluster()

# Set cluster depth for distinguishing the background color
self.depth = self._parent.depth + 1 if self._parent else 0
coloridx = self.depth % len(self.__bgcolors)
self.dot.graph_attr["bgcolor"] = self.__bgcolors[coloridx]

# Merge passed in attributes
self.dot.graph_attr.update(graph_attr)

def __enter__(self):
setcluster(self)
return self

def __exit__(self, exc_type, exc_value, traceback):
if self._parent:
self._parent.subgraph(self.dot)
else:
self._diagram.subgraph(self.dot)
setcluster(self._parent)

27 changes: 27 additions & 0 deletions diagrams/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from graphviz import Digraph
from abc import ABC, abstractmethod

class Context(ABC):
__directions = ("TB", "BT", "LR", "RL")

def __init__(self, name, **kwargs):
self.name = name
self.dot = Digraph(self.name, **kwargs)

@abstractmethod
def __enter__(self):
pass

@abstractmethod
def __exit__(self, exc_type, exc_value, traceback):
pass

def _validate_direction(self, direction: str) -> bool:
return direction.upper() in self.__directions

def node(self, nodeid: str, label: str, **attrs) -> None:
"""Create a new node in the cluster."""
self.dot.node(nodeid, label=label, **attrs)

def subgraph(self, dot: Digraph) -> None:
self.dot.subgraph(dot)
149 changes: 149 additions & 0 deletions diagrams/diagram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import os
from .node import Node
from .edge import Edge
from typing import Optional
from graphviz import Digraph
from .context import Context
from .utils import setdiagram


class Diagram(Context):
__curvestyles = ("ortho", "curved")
__outformats = ("png", "jpg", "svg", "pdf", "dot")

# fmt: off
_default_graph_attrs = {
"pad": "2.0",
"splines": "ortho",
"nodesep": "0.60",
"ranksep": "0.75",
"fontname": "Sans-Serif",
"fontsize": "15",
"fontcolor": "#2D3436",
}
_default_node_attrs = {
"shape": "box",
"style": "rounded",
"fixedsize": "true",
"width": "1.4",
"height": "1.4",
"labelloc": "b",
# imagepos attribute is not backward compatible
# TODO: check graphviz version to see if "imagepos" is available >= 2.40
# https://github.com/xflr6/graphviz/blob/master/graphviz/backend.py#L248
# "imagepos": "tc",
"imagescale": "true",
"fontname": "Sans-Serif",
"fontsize": "13",
"fontcolor": "#2D3436",
}
_default_edge_attrs = {
"color": "#7B8894",
}

# fmt: on

# TODO: Label position option
# TODO: Save directory option (filename + directory?)
def __init__(
self,
name: str = "",
filename: str = "",
direction: str = "LR",
curvestyle: str = "ortho",
outformats: list = ["png"],
autolabel: bool = False,
show: bool = True,
strict: bool = False,
graph_attr: Optional[dict] = None,
node_attr: Optional[dict] = None,
edge_attr: Optional[dict] = None,
):
"""Diagram represents a global diagrams context.

:param name: Diagram name. It will be used for output filename if the
filename isn't given.
:param filename: The output filename, without the extension (.png).
If not given, it will be generated from the name.
:param direction: Data flow direction. Default is 'left to right'.
:param curvestyle: Curve bending style. One of "ortho" or "curved".
:param outformats: List of output file formats. Default is ['png'].
:param show: Open generated image after save if true, just only save otherwise.
:param graph_attr: Provide graph_attr dot config attributes.
:param node_attr: Provide node_attr dot config attributes.
:param edge_attr: Provide edge_attr dot config attributes.
:param strict: Rendering should merge multi-edges.
"""
if graph_attr is None:
graph_attr = {}
if node_attr is None:
node_attr = {}
if edge_attr is None:
edge_attr = {}

if not name and not filename:
filename = "diagrams_image"
elif not filename:
filename = "_".join(name.split()).lower()
self.filename = filename
super().__init__(name, filename=self.filename, strict=strict)

# Set attributes.
for k, v in self._default_graph_attrs.items():
self.dot.graph_attr[k] = v
self.dot.graph_attr["label"] = self.name
for k, v in self._default_node_attrs.items():
self.dot.node_attr[k] = v
for k, v in self._default_edge_attrs.items():
self.dot.edge_attr[k] = v

if not self._validate_direction(direction):
raise ValueError(f'"{direction}" is not a valid direction')
self.dot.graph_attr["rankdir"] = direction

if not self._validate_curvestyle(curvestyle):
raise ValueError(f'"{curvestyle}" is not a valid curvestyle')
self.dot.graph_attr["splines"] = curvestyle

for outformat in outformats:
if not self._validate_outformat(outformat):
raise ValueError(f'"{outformat}" is not a valid output format')
self.outformats = outformats

# Merge passed in attributes
self.dot.graph_attr.update(graph_attr)
self.dot.node_attr.update(node_attr)
self.dot.edge_attr.update(edge_attr)

self.show = show
self.autolabel = autolabel

def __str__(self) -> str:
return str(self.dot)

def __enter__(self):
setdiagram(self)
return self

def __exit__(self, exc_type, exc_value, traceback):
self.render()
# Remove the graphviz file leaving only the image.
os.remove(self.filename)
setdiagram(None)

def _repr_png_(self):
return self.dot.pipe(format="png")

def _validate_curvestyle(self, curvestyle: str) -> bool:
return curvestyle.lower() in self.__curvestyles

def _validate_outformat(self, outformat: str) -> bool:
return outformat.lower() in self.__outformats

def connect(self, node: "Node", node2: "Node", edge: "Edge") -> None:
"""Connect the two Nodes."""
self.dot.edge(node.nodeid, node2.nodeid, **edge.attrs)

def render(self) -> None:
for outformat in self.outformats:
self.dot.render(format=outformat, view=self.show, quiet=True)
119 changes: 119 additions & 0 deletions diagrams/edge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from . import node
from typing import List, Union, Dict

class Edge:
"""Edge represents an edge between two nodes."""

_default_edge_attrs = {
"fontcolor": "#2D3436",
"fontname": "Sans-Serif",
"fontsize": "13",
}

def __init__(
self,
n: "node.Node" = None,
forward: bool = False,
reverse: bool = False,
label: str = "",
color: str = "",
style: str = "",
**attrs: Dict,
):
"""Edge represents an edge between two nodes.

:param n: Parent node.
:param forward: Points forward.
:param reverse: Points backward.
:param label: Edge label.
:param color: Edge color.
:param style: Edge style.
:param attrs: Other edge attributes
"""
if n is not None:
assert isinstance(n, node.Node)

self.node = n
self.forward = forward
self.reverse = reverse

self._attrs = {}

# Set attributes.
for k, v in self._default_edge_attrs.items():
self._attrs[k] = v

if label:
# Graphviz complaining about using label for edges, so replace it with xlabel.
# Update: xlabel option causes the misaligned label position: https://github.com/mingrammer/diagrams/issues/83
self._attrs["label"] = label
if color:
self._attrs["color"] = color
if style:
self._attrs["style"] = style
self._attrs.update(attrs)

def __sub__(self, other: Union["node.Node", "Edge", List["node.Node"]]):
"""Implement Self - Node or Edge and Self - [Nodes]"""
return self.connect(other)

def __rsub__(self, other: Union[List["node.Node"], List["Edge"]]) -> List["Edge"]:
"""Called for [Nodes] or [Edges] - Self because list don't have __sub__ operators."""
return self.append(other)

def __rshift__(self, other: Union["node.Node", "Edge", List["node.Node"]]):
"""Implements Self >> Node or Edge and Self >> [Nodes]."""
self.forward = True
return self.connect(other)

def __lshift__(self, other: Union["node.Node", "Edge", List["node.Node"]]):
"""Implements Self << Node or Edge and Self << [Nodes]."""
self.reverse = True
return self.connect(other)

def __rrshift__(self, other: Union[List["node.Node"], List["Edge"]]) -> List["Edge"]:
"""Called for [Nodes] or [Edges] >> Self because list of Edges don't have __rshift__ operators."""
return self.append(other, forward=True)

def __rlshift__(self, other: Union[List["node.Node"], List["Edge"]]) -> List["Edge"]:
"""Called for [Nodes] or [Edges] << Self because list of Edges don't have __lshift__ operators."""
return self.append(other, reverse=True)

def append(self, other: Union[List["node.Node"], List["Edge"]], forward=None, reverse=None) -> List["Edge"]:
result = []
for o in other:
if isinstance(o, Edge):
o.forward = forward if forward else o.forward
o.reverse = forward if forward else o.reverse
self._attrs = o.attrs.copy()
result.append(o)
else:
result.append(Edge(o, forward=forward, reverse=reverse, **self._attrs))
return result

def connect(self, other: Union["node.Node", "Edge", List["node.Node"]]):
if isinstance(other, list):
for n in other:
self.node.connect(n, self)
return other
elif isinstance(other, Edge):
self._attrs = other._attrs.copy()
return self
else:
if self.node is not None:
return self.node.connect(other, self)
else:
self.node = other
return self

@property
def attrs(self) -> Dict:
if self.forward and self.reverse:
direction = "both"
elif self.forward:
direction = "forward"
elif self.reverse:
direction = "back"
else:
direction = "none"
return {**self._attrs, "dir": direction}
Loading