diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py index 06c380b5..4f78ad4f 100644 --- a/datajoint/dependencies.py +++ b/datajoint/dependencies.py @@ -1,9 +1,74 @@ import networkx as nx import itertools +import re from collections import defaultdict from .errors import DataJointError +def extract_master(part_table): + """ + given a part table name, return master part. None if not a part table + """ + match = re.match(r"(?P`\w+`.`#?\w+)__\w+`", part_table) + return match["master"] + "`" if match else None + + +def topo_sort(graph): + """ + topological sort of a dependency graph that keeps part tables together with their masters + :return: list of table names in topological order + """ + + graph = nx.DiGraph(graph) # make a copy + + # collapse alias nodes + alias_nodes = [node for node in graph if node.isdigit()] + for node in alias_nodes: + try: + direct_edge = ( + next(x for x in graph.in_edges(node))[0], + next(x for x in graph.out_edges(node))[1], + ) + except StopIteration: + pass # a disconnected alias node + else: + graph.add_edge(*direct_edge) + graph.remove_nodes_from(alias_nodes) + + # Add parts' dependencies to their masters' dependencies + # to ensure correct topological ordering of the masters. + for part in graph: + # find the part's master + if (master := extract_master(part)) in graph: + for edge in graph.in_edges(part): + parent = edge[0] + if parent != master and extract_master(parent) != master: + graph.add_edge(parent, master) + sorted_nodes = list(nx.topological_sort(graph)) + + # bring parts up to their masters + pos = len(sorted_nodes) - 1 + placed = set() + while pos > 1: + part = sorted_nodes[pos] + if not (master := extract_master) or part in placed: + pos -= 1 + else: + placed.add(part) + try: + j = sorted_nodes.index(master) + except ValueError: + # master not found + pass + else: + if pos > j + 1: + # move the part to its master + del sorted_nodes[pos] + sorted_nodes.insert(j + 1, part) + + return sorted_nodes + + class Dependencies(nx.DiGraph): """ The graph of dependencies (foreign keys) between loaded tables. @@ -106,6 +171,10 @@ def load(self, force=True): raise DataJointError("DataJoint can only work with acyclic dependencies") self._loaded = True + def topo_sort(self): + """:return: list of tables names in topological order""" + return topo_sort(self) + def parents(self, table_name, primary=None): """ :param table_name: `schema`.`table` @@ -142,8 +211,8 @@ def descendants(self, full_table_name): :return: all dependent tables sorted in topological order. Self is included. """ self.load(force=False) - nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name)) - return [full_table_name] + list(nx.algorithms.dag.topological_sort(nodes)) + nodes = self.subgraph(nx.descendants(self, full_table_name)) + return [full_table_name] + nodes.topo_sort() def ancestors(self, full_table_name): """ @@ -151,9 +220,5 @@ def ancestors(self, full_table_name): :return: all dependent tables sorted in topological order. Self is included. """ self.load(force=False) - nodes = self.subgraph(nx.algorithms.dag.ancestors(self, full_table_name)) - return list( - reversed( - list(nx.algorithms.dag.topological_sort(nodes)) + [full_table_name] - ) - ) + nodes = self.subgraph(nx.ancestors(self, full_table_name)) + return reversed(nodes.topo_sort() + [full_table_name]) diff --git a/datajoint/diagram.py b/datajoint/diagram.py index 0136ccaf..ca1df82b 100644 --- a/datajoint/diagram.py +++ b/datajoint/diagram.py @@ -5,7 +5,8 @@ import logging import inspect from .table import Table -from .user_tables import Manual, Imported, Computed, Lookup, Part +from .dependencies import topo_sort +from .user_tables import Manual, Imported, Computed, Lookup, Part, _get_tier, _AliasNode from .errors import DataJointError from .table import lookup_class_name @@ -26,29 +27,6 @@ logger = logging.getLogger(__name__.split(".")[0]) -user_table_classes = (Manual, Lookup, Computed, Imported, Part) - - -class _AliasNode: - """ - special class to indicate aliased foreign keys - """ - - pass - - -def _get_tier(table_name): - if not table_name.startswith("`"): - return _AliasNode - else: - try: - return next( - tier - for tier in user_table_classes - if re.fullmatch(tier.tier_regexp, table_name.split("`")[-2]) - ) - except StopIteration: - return None if not diagram_active: @@ -70,19 +48,22 @@ def __init__(self, *args, **kwargs): class Diagram(nx.DiGraph): """ - Entity relationship diagram. + Schema diagram showing tables and foreign keys between in the form of a directed + acyclic graph (DAG). The diagram is derived from the connection.dependencies object. Usage: >>> diag = Diagram(source) - source can be a base table object, a base table class, a schema, or a module that has a schema. + source can be a table object, a table class, a schema, or a module that has a schema. >>> diag.draw() draws the diagram using pyplot diag1 + diag2 - combines the two diagrams. + diag1 - diag2 - differente between diagrams + diag1 * diag2 - intersction of diagrams diag + n - expands n levels of successors diag - n - expands n levels of predecessors Thus dj.Diagram(schema.Table)+1-1 defines the diagram of immediate ancestors and descendants of schema.Table @@ -91,7 +72,8 @@ class Diagram(nx.DiGraph): Only those tables that are loaded in the connection object are displayed """ - def __init__(self, source, context=None): + def __init__(self, source=None, context=None): + if isinstance(source, Diagram): # copy constructor self.nodes_to_show = set(source.nodes_to_show) @@ -152,7 +134,7 @@ def from_sequence(cls, sequence): def add_parts(self): """ - Adds to the diagram the part tables of tables already included in the diagram + Adds to the diagram the part tables of all master tables already in the diagram :return: """ @@ -177,14 +159,6 @@ def is_part(part, master): ) return self - def topological_sort(self): - """:return: list of nodes in topological order""" - return list( - nx.algorithms.dag.topological_sort( - nx.DiGraph(self).subgraph(self.nodes_to_show) - ) - ) - def __add__(self, arg): """ :param arg: either another Diagram or a positive integer. @@ -252,6 +226,10 @@ def __mul__(self, arg): self.nodes_to_show.intersection_update(arg.nodes_to_show) return self + def topo_sort(self): + """return nodes in lexicographical topological order""" + return topo_sort(self) + def _make_graph(self): """ Make the self.graph - a graph object ready for drawing diff --git a/datajoint/schemas.py b/datajoint/schemas.py index 25c3f4b4..7545f828 100644 --- a/datajoint/schemas.py +++ b/datajoint/schemas.py @@ -2,16 +2,17 @@ import logging import inspect import re +import collections +import itertools from .connection import conn -from .diagram import Diagram from .settings import config from .errors import DataJointError, AccessError from .jobs import JobTable from .external import ExternalMapping from .heading import Heading from .utils import user_choice, to_camel_case -from .user_tables import Part, Computed, Imported, Manual, Lookup -from .table import lookup_class_name, Log +from .user_tables import Part, Computed, Imported, Manual, Lookup, _get_tier +from .table import lookup_class_name, Log, FreeTable import types logger = logging.getLogger(__name__.split(".")[0]) @@ -399,6 +400,76 @@ def jobs(self): self._jobs = JobTable(self.connection, self.database) return self._jobs + @property + def code(self): + self._assert_exists() + return self.save() + + def save(self, python_filename=None): + """ + Generate the code for a module that recreates the schema. + This method is in preparation for a future release and is not officially supported. + + :return: a string containing the body of a complete Python module defining this schema. + """ + self._assert_exists() + module_count = itertools.count() + # add virtual modules for referenced modules with names vmod0, vmod1, ... + module_lookup = collections.defaultdict( + lambda: "vmod" + str(next(module_count)) + ) + db = self.database + + def make_class_definition(table): + tier = _get_tier(table).__name__ + class_name = table.split(".")[1].strip("`") + indent = "" + if tier == "Part": + class_name = class_name.split("__")[-1] + indent += " " + class_name = to_camel_case(class_name) + + def replace(s): + d, tabs = s.group(1), s.group(2) + return ("" if d == db else (module_lookup[d] + ".")) + ".".join( + to_camel_case(tab) for tab in tabs.lstrip("__").split("__") + ) + + return ("" if tier == "Part" else "\n@schema\n") + ( + "{indent}class {class_name}(dj.{tier}):\n" + '{indent} definition = """\n' + '{indent} {defi}"""' + ).format( + class_name=class_name, + indent=indent, + tier=tier, + defi=re.sub( + r"`([^`]+)`.`([^`]+)`", + replace, + FreeTable(self.connection, table).describe(), + ).replace("\n", "\n " + indent), + ) + + tables = self.connection.dependencies.topo_sort() + body = "\n\n".join(make_class_definition(table) for table in tables) + python_code = "\n\n".join( + ( + '"""This module was auto-generated by datajoint from an existing schema"""', + "import datajoint as dj\n\nschema = dj.Schema('{db}')".format(db=db), + "\n".join( + "{module} = dj.VirtualModule('{module}', '{schema_name}')".format( + module=v, schema_name=k + ) + for k, v in module_lookup.items() + ), + body, + ) + ) + if python_filename is None: + return python_code + with open(python_filename, "wt") as f: + f.write(python_code) + def list_tables(self): """ Return a list of all tables in the schema except tables with ~ in first character such @@ -410,7 +481,7 @@ def list_tables(self): t for d, t in ( full_t.replace("`", "").split(".") - for full_t in Diagram(self).topological_sort() + for full_t in Diagram(self).topo_sort() ) if d == self.database ] diff --git a/datajoint/table.py b/datajoint/table.py index a597956e..db9eaffa 100644 --- a/datajoint/table.py +++ b/datajoint/table.py @@ -217,7 +217,6 @@ def children(self, primary=None, as_objects=False, foreign_key_info=False): def descendants(self, as_objects=False): """ - :param as_objects: False - a list of table names; True - a list of table objects. :return: list of tables descendants in topological order. """ diff --git a/datajoint/user_tables.py b/datajoint/user_tables.py index bcb6a027..0a784560 100644 --- a/datajoint/user_tables.py +++ b/datajoint/user_tables.py @@ -2,6 +2,7 @@ Hosts the table tiers, user tables should be derived from. """ +import re from .table import Table from .autopopulate import AutoPopulate from .utils import from_camel_case, ClassProperty @@ -242,3 +243,29 @@ def drop(self, force=False): def alter(self, prompt=True, context=None): # without context, use declaration context which maps master keyword to master table super().alter(prompt=prompt, context=context or self.declaration_context) + + +user_table_classes = (Manual, Lookup, Computed, Imported, Part) + + +class _AliasNode: + """ + special class to indicate aliased foreign keys + """ + + pass + + +def _get_tier(table_name): + """given the table name, return""" + if not table_name.startswith("`"): + return _AliasNode + else: + try: + return next( + tier + for tier in user_table_classes + if re.fullmatch(tier.tier_regexp, table_name.split("`")[-2]) + ) + except StopIteration: + return None diff --git a/tests/test_cli.py b/tests/test_cli.py index 3f0fd00c..29fedf22 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -3,7 +3,6 @@ """ import json -import ast import subprocess import pytest import datajoint as dj diff --git a/tests/test_schema.py b/tests/test_schema.py index 857c1474..257de221 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -218,6 +218,14 @@ def test_list_tables(schema_simp): assert actual == expected, f"Missing from list_tables(): {expected - actual}" +def test_schema_save_any(schema_any): + assert "class Experiment(dj.Imported)" in schema_any.code + + +def test_schema_save_empty(schema_empty): + assert "class Experiment(dj.Imported)" in schema_empty.code + + def test_uppercase_schema(db_creds_root): """ https://github.com/datajoint/datajoint-python/issues/564