Skip to content

Commit

Permalink
Merge branch 'master' into key_populate
Browse files Browse the repository at this point in the history
  • Loading branch information
dimitri-yatsenko committed Sep 15, 2024
2 parents fbeaec9 + 24c090d commit f73bb59
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 50 deletions.
81 changes: 73 additions & 8 deletions datajoint/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,74 @@
import networkx as nx
import itertools
import re
from collections import defaultdict
from .errors import DataJointError


def extract_master(part_table):
    """
    Given a full table name, return the full name of its master table.

    A part table is recognized purely by its name: `schema`.`master__part`,
    where the master name may start with `#`.

    :param part_table: full table name in the form `schema`.`table_name`
    :return: the master's full table name (`schema`.`master`),
        or None if part_table is not named like a part table
    """
    # The separator between schema and table is a literal dot, so it is escaped.
    match = re.match(r"(?P<master>`\w+`\.`#?\w+)__\w+`", part_table)
    # The captured group stops before the closing backtick; put it back.
    return match["master"] + "`" if match else None


def topo_sort(graph):
    """
    Topologically sort a dependency graph, keeping part tables together with their masters.

    :param graph: directed acyclic graph whose nodes are strings: full table names
        (`schema`.`table_name`) or all-digit names for alias nodes.
    :return: list of table names in topological order, with each part table moved up
        to follow its master; alias nodes are not included.
    """
    graph = nx.DiGraph(graph)  # make a copy so the caller's graph is left untouched

    # collapse alias nodes: bridge each alias with a direct parent -> child edge
    alias_nodes = [node for node in graph if node.isdigit()]
    for node in alias_nodes:
        try:
            direct_edge = (
                next(iter(graph.in_edges(node)))[0],
                next(iter(graph.out_edges(node)))[1],
            )
        except StopIteration:
            pass  # a disconnected alias node
        else:
            graph.add_edge(*direct_edge)
    graph.remove_nodes_from(alias_nodes)

    # Add parts' dependencies to their masters' dependencies
    # to ensure correct topological ordering of the masters.
    for part in graph:
        if (master := extract_master(part)) in graph:
            # snapshot the incoming edges because edges are added inside the loop
            for parent, _ in list(graph.in_edges(part)):
                # skip parents that are the master itself or another part of it
                if master not in (parent, extract_master(parent)):
                    graph.add_edge(parent, master)
    sorted_nodes = list(nx.topological_sort(graph))

    # bring parts up to their masters, scanning from the end of the list
    pos = len(sorted_nodes) - 1
    placed = set()
    while pos > 1:
        part = sorted_nodes[pos]
        # extract_master has to be *called* with the node name: testing the bare
        # function object is always truthy and never yields a master to look up.
        if (master := extract_master(part)) not in graph or part in placed:
            pos -= 1  # not a part of a known master, or already handled
            continue
        placed.add(part)
        target = sorted_nodes.index(master) + 1
        if pos > target:
            # Move the part to the slot right after its master. `pos` is not
            # decremented so the element that slid into it is examined next.
            del sorted_nodes[pos]
            sorted_nodes.insert(target, part)

    return sorted_nodes


class Dependencies(nx.DiGraph):
"""
The graph of dependencies (foreign keys) between loaded tables.
Expand Down Expand Up @@ -106,6 +171,10 @@ def load(self, force=True):
raise DataJointError("DataJoint can only work with acyclic dependencies")
self._loaded = True

def topo_sort(self):
    """
    :return: list of table names in topological order, as computed by the
        module-level topo_sort function applied to this graph
    """
    ordered = topo_sort(self)
    return ordered

def parents(self, table_name, primary=None):
"""
:param table_name: `schema`.`table`
Expand Down Expand Up @@ -142,18 +211,14 @@ def descendants(self, full_table_name):
:return: all dependent tables sorted in topological order. Self is included.
"""
self.load(force=False)
nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name))
return [full_table_name] + list(nx.algorithms.dag.topological_sort(nodes))
nodes = self.subgraph(nx.descendants(self, full_table_name))
return [full_table_name] + nodes.topo_sort()

def ancestors(self, full_table_name):
    """
    List the tables that the given table depends on, directly or indirectly.

    :param full_table_name: In form `schema`.`table_name`
    :return: list starting with full_table_name followed by all of its ancestors,
        in reverse topological order (each table precedes the tables it depends on).
    """
    self.load(force=False)
    nodes = self.subgraph(nx.ancestors(self, full_table_name))
    # Materialize the result as a list: a bare reversed() iterator can be consumed
    # only once and supports neither len() nor indexing.
    return list(reversed(nodes.topo_sort() + [full_table_name]))
50 changes: 14 additions & 36 deletions datajoint/diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import logging
import inspect
from .table import Table
from .user_tables import Manual, Imported, Computed, Lookup, Part
from .dependencies import topo_sort
from .user_tables import Manual, Imported, Computed, Lookup, Part, _get_tier, _AliasNode
from .errors import DataJointError
from .table import lookup_class_name

Expand All @@ -26,29 +27,6 @@


logger = logging.getLogger(__name__.split(".")[0])
user_table_classes = (Manual, Lookup, Computed, Imported, Part)


class _AliasNode:
"""
special class to indicate aliased foreign keys
"""

pass


def _get_tier(table_name):
if not table_name.startswith("`"):
return _AliasNode
else:
try:
return next(
tier
for tier in user_table_classes
if re.fullmatch(tier.tier_regexp, table_name.split("`")[-2])
)
except StopIteration:
return None


if not diagram_active:
Expand All @@ -70,19 +48,22 @@ def __init__(self, *args, **kwargs):

class Diagram(nx.DiGraph):
"""
Entity relationship diagram.
Schema diagram showing tables and the foreign keys between them in the form of a directed
acyclic graph (DAG). The diagram is derived from the connection.dependencies object.
Usage:
>>> diag = Diagram(source)
source can be a base table object, a base table class, a schema, or a module that has a schema.
source can be a table object, a table class, a schema, or a module that has a schema.
>>> diag.draw()
draws the diagram using pyplot
diag1 + diag2 - combines the two diagrams.
diag1 - diag2 - difference between diagrams
diag1 * diag2 - intersection of diagrams
diag + n - expands n levels of successors
diag - n - expands n levels of predecessors
Thus dj.Diagram(schema.Table)+1-1 defines the diagram of immediate ancestors and descendants of schema.Table
Expand All @@ -91,7 +72,8 @@ class Diagram(nx.DiGraph):
Only those tables that are loaded in the connection object are displayed
"""

def __init__(self, source, context=None):
def __init__(self, source=None, context=None):

if isinstance(source, Diagram):
# copy constructor
self.nodes_to_show = set(source.nodes_to_show)
Expand Down Expand Up @@ -152,7 +134,7 @@ def from_sequence(cls, sequence):

def add_parts(self):
"""
Adds to the diagram the part tables of tables already included in the diagram
Adds to the diagram the part tables of all master tables already in the diagram
:return:
"""

Expand All @@ -177,14 +159,6 @@ def is_part(part, master):
)
return self

def topological_sort(self):
""":return: list of nodes in topological order"""
return list(
nx.algorithms.dag.topological_sort(
nx.DiGraph(self).subgraph(self.nodes_to_show)
)
)

def __add__(self, arg):
"""
:param arg: either another Diagram or a positive integer.
Expand Down Expand Up @@ -252,6 +226,10 @@ def __mul__(self, arg):
self.nodes_to_show.intersection_update(arg.nodes_to_show)
return self

def topo_sort(self):
    """
    :return: list of the diagram's nodes in topological order, as produced by the
        topo_sort function imported from the dependencies module
    """
    ordered = topo_sort(self)
    return ordered

def _make_graph(self):
"""
Make the self.graph - a graph object ready for drawing
Expand Down
79 changes: 75 additions & 4 deletions datajoint/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
import logging
import inspect
import re
import collections
import itertools
from .connection import conn
from .diagram import Diagram
from .settings import config
from .errors import DataJointError, AccessError
from .jobs import JobTable
from .external import ExternalMapping
from .heading import Heading
from .utils import user_choice, to_camel_case
from .user_tables import Part, Computed, Imported, Manual, Lookup
from .table import lookup_class_name, Log
from .user_tables import Part, Computed, Imported, Manual, Lookup, _get_tier
from .table import lookup_class_name, Log, FreeTable
import types

logger = logging.getLogger(__name__.split(".")[0])
Expand Down Expand Up @@ -399,6 +400,76 @@ def jobs(self):
self._jobs = JobTable(self.connection, self.database)
return self._jobs

@property
def code(self):
    """
    :return: the value returned by self.save() called without a file name,
        after first checking that the schema exists
    """
    self._assert_exists()
    generated = self.save()
    return generated

def save(self, python_filename=None):
"""
Generate the code for a module that recreates the schema.
This method is in preparation for a future release and is not officially supported.
:param python_filename: if None (default), the generated code is returned;
otherwise the code is written to this file and nothing is returned.
:return: a string containing the body of a complete Python module defining this schema,
or None when python_filename is given.
"""
self._assert_exists()
module_count = itertools.count()
# add virtual modules for referenced modules with names vmod0, vmod1, ...
# (the defaultdict assigns the next vmodN name the first time a schema name is looked up)
module_lookup = collections.defaultdict(
lambda: "vmod" + str(next(module_count))
)
db = self.database

def make_class_definition(table):
# Build the source text of one table class from the table's full name.
# NOTE(review): if _get_tier returns None for a name that matches no tier,
# .__name__ raises AttributeError - confirm such names cannot reach here.
tier = _get_tier(table).__name__
class_name = table.split(".")[1].strip("`")
indent = ""
if tier == "Part":
# a part is emitted under its own (last) name component, with extra indentation
class_name = class_name.split("__")[-1]
indent += " "
class_name = to_camel_case(class_name)

def replace(s):
# Regex callback: turn a `schema`.`table` reference into a class reference,
# prefixed with a virtual module name when the schema is not this one.
d, tabs = s.group(1), s.group(2)
# str.lstrip treats its argument as a set of characters,
# so lstrip("__") removes every leading underscore.
return ("" if d == db else (module_lookup[d] + ".")) + ".".join(
to_camel_case(tab) for tab in tabs.lstrip("__").split("__")
)

# only non-part classes get the @schema decorator line
return ("" if tier == "Part" else "\n@schema\n") + (
"{indent}class {class_name}(dj.{tier}):\n"
'{indent} definition = """\n'
'{indent} {defi}"""'
).format(
class_name=class_name,
indent=indent,
tier=tier,
defi=re.sub(
r"`([^`]+)`.`([^`]+)`",
replace,
FreeTable(self.connection, table).describe(),
).replace("\n", "\n " + indent),
)

# NOTE(review): this iterates every table in the connection's dependency graph;
# presumably only tables of this schema are intended - confirm.
tables = self.connection.dependencies.topo_sort()
body = "\n\n".join(make_class_definition(table) for table in tables)
# module_lookup is read after body is built, so it already holds every
# schema name that replace() looked up while generating the class definitions
python_code = "\n\n".join(
(
'"""This module was auto-generated by datajoint from an existing schema"""',
"import datajoint as dj\n\nschema = dj.Schema('{db}')".format(db=db),
"\n".join(
"{module} = dj.VirtualModule('{module}', '{schema_name}')".format(
module=v, schema_name=k
)
for k, v in module_lookup.items()
),
body,
)
)
if python_filename is None:
return python_code
# a file name was given: write the code out instead of returning it
with open(python_filename, "wt") as f:
f.write(python_code)

def list_tables(self):
"""
Return a list of all tables in the schema except tables with ~ in first character such
Expand All @@ -410,7 +481,7 @@ def list_tables(self):
t
for d, t in (
full_t.replace("`", "").split(".")
for full_t in Diagram(self).topological_sort()
for full_t in Diagram(self).topo_sort()
)
if d == self.database
]
Expand Down
1 change: 0 additions & 1 deletion datajoint/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ def children(self, primary=None, as_objects=False, foreign_key_info=False):

def descendants(self, as_objects=False):
"""
:param as_objects: False - a list of table names; True - a list of table objects.
:return: list of tables descendants in topological order.
"""
Expand Down
27 changes: 27 additions & 0 deletions datajoint/user_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Hosts the table tiers, user tables should be derived from.
"""

import re
from .table import Table
from .autopopulate import AutoPopulate
from .utils import from_camel_case, ClassProperty
Expand Down Expand Up @@ -242,3 +243,29 @@ def drop(self, force=False):
def alter(self, prompt=True, context=None):
# without context, use declaration context which maps master keyword to master table
super().alter(prompt=prompt, context=context or self.declaration_context)


# Tier classes that _get_tier tries, in this order, when matching a table name.
user_table_classes = (Manual, Lookup, Computed, Imported, Part)


class _AliasNode:
"""
special class to indicate aliased foreign keys
"""

pass


def _get_tier(table_name):
    """
    Identify the tier class of a table from its name.

    :param table_name: node name; table names start with a backtick
    :return: _AliasNode if the name does not start with a backtick; otherwise the
        first class in user_table_classes whose tier_regexp fully matches the text
        between the last pair of backticks, or None if no tier matches.
    """
    if not table_name.startswith("`"):
        return _AliasNode
    bare_name = table_name.split("`")[-2]
    matching_tiers = (
        tier
        for tier in user_table_classes
        if re.fullmatch(tier.tier_regexp, bare_name)
    )
    return next(matching_tiers, None)
1 change: 0 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

import json
import ast
import subprocess
import pytest
import datajoint as dj
Expand Down
8 changes: 8 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,14 @@ def test_list_tables(schema_simp):
assert actual == expected, f"Missing from list_tables(): {expected - actual}"


def test_schema_save_any(schema_any):
    """The code generated for schema_any must declare Experiment as an Imported table."""
    generated = schema_any.code
    assert "class Experiment(dj.Imported)" in generated


def test_schema_save_empty(schema_empty):
    """The code generated for schema_empty must declare Experiment as an Imported table."""
    generated = schema_empty.code
    assert "class Experiment(dj.Imported)" in generated


def test_uppercase_schema(db_creds_root):
"""
https://github.com/datajoint/datajoint-python/issues/564
Expand Down

0 comments on commit f73bb59

Please sign in to comment.