Skip to content

Commit

Permalink
feat!: Support Multijoin in the Python client (#6020)
Browse files Browse the repository at this point in the history
Fixes #5884

BREAKING CHANGE: the `table()` method on `MultiJoinTable` has been
changed to a `@property`. To access the underlying table of a multi-join
result, users can no longer call the `table()` method; they must instead
use the `table` attribute of the `MultiJoinTable` instance.

---------

Co-authored-by: Chip Kent <[email protected]>
  • Loading branch information
jmao-denver and chipkent authored Sep 30, 2024
1 parent 09a7849 commit 057c0a6
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 28 deletions.
23 changes: 23 additions & 0 deletions py/client/pydeephaven/_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,3 +687,26 @@ def make_grpc_request(self, result_id, source_id) -> Any:
def make_grpc_request_for_batch(self, result_id, source_id) -> Any:
    """Wraps this op's single gRPC request in a batch operation.

    Args:
        result_id: forwarded unchanged to make_grpc_request
        source_id: forwarded unchanged to make_grpc_request

    Returns:
        a BatchTableRequest.Operation whose meta_table field carries the request
    """
    request = self.make_grpc_request(result_id=result_id, source_id=source_id)
    return table_pb2.BatchTableRequest.Operation(meta_table=request)


class MultijoinTablesOp(TableOp):
    """A table operation that builds the gRPC request for a multi-table join.

    Unlike single-source ops, each input table is referenced through its own MultiJoinInput, so the
    ``source_id`` argument of the request builders is not used to build the request.
    """

    def __init__(self, multi_join_inputs: List["MultiJoinInput"]):
        """Initializes the op.

        Args:
            multi_join_inputs (List[MultiJoinInput]): the inputs describing each table, its match columns (``on``)
                and the columns to add (``joins``)
        """
        self.multi_join_inputs = multi_join_inputs

    @classmethod
    def get_stub_func(cls, table_service_stub: table_pb2_grpc.TableServiceStub) -> Any:
        """Returns the MultiJoinTables attribute of the given table service stub."""
        return table_service_stub.MultiJoinTables

    def make_grpc_request(self, result_id, source_id) -> Any:
        """Builds the MultiJoinTablesRequest for this op.

        Args:
            result_id: the id to set as the result_id of the request
            source_id: ignored; every input carries its own table reference. Kept for signature compatibility
                with the other table ops.

        Returns:
            a table_pb2.MultiJoinTablesRequest
        """
        # Build one protobuf input per MultiJoinInput. The per-input table reference gets its own expression
        # rather than being assigned to (and clobbering) the unused `source_id` parameter.
        pb_inputs = [
            table_pb2.MultiJoinInput(
                source_id=table_pb2.TableReference(ticket=mji.table.ticket.pb_ticket),
                columns_to_match=mji.on,
                columns_to_add=mji.joins)
            for mji in self.multi_join_inputs
        ]
        return table_pb2.MultiJoinTablesRequest(result_id=result_id, multi_join_inputs=pb_inputs)

    def make_grpc_request_for_batch(self, result_id, source_id) -> Any:
        """Wraps the request from make_grpc_request in a BatchTableRequest.Operation (multi_join field)."""
        return table_pb2.BatchTableRequest.Operation(
            multi_join=self.make_grpc_request(result_id=result_id, source_id=source_id))
4 changes: 2 additions & 2 deletions py/client/pydeephaven/_table_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#
# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
#
from typing import Union, List
from typing import Union, List, Optional

from pydeephaven._batch_assembler import BatchOpAssembler
from pydeephaven._table_ops import TableOp
Expand Down Expand Up @@ -38,7 +38,7 @@ def batch(self, ops: List[TableOp]) -> Table:
except Exception as e:
raise DHError("failed to finish the table batch operation.") from e

def grpc_table_op(self, table: Table, op: TableOp, table_class: type = Table) -> Union[Table, InputTable]:
def grpc_table_op(self, table: Optional[Table], op: TableOp, table_class: type = Table) -> Union[Table, InputTable]:
"""Makes a single gRPC Table operation call and returns a new Table."""
try:
export_ticket = self.session.make_export_ticket()
Expand Down
2 changes: 1 addition & 1 deletion py/client/pydeephaven/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def _connect(self):
# started together don't align retries.
skew = random()
# Backoff schedule for retries after consecutive failures to refresh auth token
self._refresh_backoff = [ skew + 0.1, skew + 1, skew + 10 ]
self._refresh_backoff = [skew + 0.1, skew + 1, skew + 10]

if self._refresh_backoff[0] > self._timeout_seconds:
raise DHError(f'server configuration http.session.durationMs={session_duration} is too small.')
Expand Down
82 changes: 80 additions & 2 deletions py/client/pydeephaven/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

from __future__ import annotations

from typing import List, Union
from typing import List, Union, Sequence

import pyarrow as pa

from pydeephaven._utils import to_list

from pydeephaven._table_ops import MetaTableOp, SortDirection
from pydeephaven._table_ops import MetaTableOp, SortDirection, MultijoinTablesOp
from pydeephaven.agg import Aggregation
from pydeephaven.dherror import DHError
from pydeephaven._table_interface import TableInterface
Expand Down Expand Up @@ -804,3 +804,81 @@ def delete(self, table: Table) -> None:
self.session.input_table_service.delete(self, table)
except Exception as e:
raise DHError("delete data in the InputTable failed.") from e


class MultiJoinTable:
    """Holds the result of a multi-table natural join.

    The joined result is exposed through the read-only :attr:`.table` property.
    """

    def __init__(self, table: Table):
        # Keep the result table on a private attribute; it is surfaced via the `table` property below.
        self._table = table

    @property
    def table(self) -> Table:
        """The Table containing the output of the multi-table natural join. """
        return self._table


class MultiJoinInput:
    """A MultiJoinInput represents one input table of a multi-table natural join, together with the key columns to
    match on and the additional columns to bring into the result.
    """
    # NOTE(review): `on` and `joins` are stored after passing through to_list(), so instances presumably always
    # hold lists here; confirm against pydeephaven._utils.to_list before narrowing these annotations.
    table: Table
    on: Union[str, Sequence[str]]
    joins: Union[str, Sequence[str], None] = None

    def __init__(self, table: Table, on: Union[str, Sequence[str]],
                 joins: Union[str, Sequence[str], None] = None):
        """Initializes a MultiJoinInput object.

        Args:
            table (Table): the right table to include in the join
            on (Union[str, Sequence[str]]): the column(s) to match, can be a common name or an equality expression
                that matches every input table, e.g. "col_a = col_b" to rename output column names.
            joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result
                table, can be renaming expressions, e.g. "new_col = col"; default is None
        """
        self.table = table
        # Normalize both column specs through to_list so a single string and a sequence are handled uniformly.
        self.on = to_list(on)
        self.joins = to_list(joins)


def _common_session(tables):
    """Returns the session shared by all the given tables; raises DHError if they span more than one session."""
    session = tables[0].session
    if not all(t.session == session for t in tables):
        raise DHError(message="all tables must be from the same session.")
    return session


def multi_join(input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]],
               on: Union[str, Sequence[str], None] = None) -> MultiJoinTable:
    """ The multi_join method creates a new table by performing a multi-table natural join on the input tables. The
    result consists of the set of distinct keys from the input tables natural joined to each input table. Input
    tables need not have a matching row for each key, but they may not have multiple matching rows for a given key.

    Args:
        input (Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]]): the input objects specifying
            the tables and columns to include in the join. Must not be empty.
        on (Union[str, Sequence[str]], optional): the column(s) to match, can be a common name or an equality
            expression that matches every input table, e.g. "col_a = col_b" to rename output column names. Note: When
            MultiJoinInput objects are supplied, this parameter must be omitted.

    Returns:
        MultiJoinTable: the result of the multi-table natural join operation. To access the underlying Table, use the
        :attr:`~MultiJoinTable.table` property.

    Raises:
        DHError: if input is empty or of an unsupported type, if `on` is given together with MultiJoinInput
            objects, or if the tables do not all belong to the same session.
    """
    # An empty sequence would satisfy the all(...) type checks below vacuously and then fail with an IndexError
    # when the first element is read, so reject it up front with the documented error type.
    if isinstance(input, Sequence) and len(input) == 0:
        raise DHError(message="input must contain at least one Table or MultiJoinInput.")

    if isinstance(input, Table) or (isinstance(input, Sequence) and all(isinstance(t, Table) for t in input)):
        tables = to_list(input)
        session = _common_session(tables)
        # Plain tables all share the same match columns given by `on`.
        multi_join_inputs = [MultiJoinInput(table=t, on=on) for t in tables]
    elif isinstance(input, MultiJoinInput) or (
            isinstance(input, Sequence) and all(isinstance(ji, MultiJoinInput) for ji in input)):
        # Each MultiJoinInput carries its own match columns, so a separate `on` would be ambiguous.
        if on is not None:
            raise DHError(message="on parameter is not permitted when MultiJoinInput objects are provided.")
        multi_join_inputs = to_list(input)
        session = _common_session([mji.table for mji in multi_join_inputs])
    else:
        raise DHError(
            message="input must be a Table, a sequence of Tables, a MultiJoinInput, or a sequence of MultiJoinInputs.")

    table_op = MultijoinTablesOp(multi_join_inputs=multi_join_inputs)
    # There is no single source table for this op, hence None is passed as the table argument.
    return MultiJoinTable(table=session.table_service.grpc_table_op(None, table_op, table_class=Table))
127 changes: 127 additions & 0 deletions py/client/tests/test_multijoin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#
# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
#
import unittest

from pyarrow import csv

from pydeephaven import DHError, Session
from pydeephaven.table import MultiJoinInput, multi_join
from tests.testbase import BaseTestCase


class MultiJoinTestCase(BaseTestCase):
    """Tests for the client-side multi_join function with both plain Tables and MultiJoinInput objects."""

    def setUp(self):
        super().setUp()
        pa_table = csv.read_csv(self.csv_file)
        # Two static tables sharing key columns a, b but with distinct value columns (c1/d1/e1 vs c2/d2/e2).
        self.static_tableA = self.session.import_table(pa_table).select(["a", "b", "c1=c", "d1=d", "e1=e"])
        self.static_tableB = self.static_tableA.update(["c2=c1+1", "d2=d1+2", "e2=e1+3"]).drop_columns(
            ["c1", "d1", "e1"])
        # The same layout built on a time table.
        self.ticking_tableA = self.session.time_table("PT00:00:00.001").update(
            ["a = i", "b = i*i % 13", "c1 = i * 13 % 23", "d1 = a + b", "e1 = a - b"]).drop_columns(["Timestamp"])
        self.ticking_tableB = self.ticking_tableA.update(["c2=c1+1", "d2=d1+2", "e2=e1+3"]).drop_columns(
            ["c1", "d1", "e1"])

    def tearDown(self) -> None:
        self.static_tableA = None
        self.static_tableB = None
        self.ticking_tableA = None
        self.ticking_tableB = None
        super().tearDown()

    def test_static_simple(self):
        # Test with multiple input tables
        mj_table = multi_join(input=[self.static_tableA, self.static_tableB], on=["a", "b"])

        # Output table is static
        self.assertFalse(mj_table.table.is_refreshing)
        # Output table has same # rows as sources
        self.assertEqual(mj_table.table.size, self.static_tableA.size)
        self.assertEqual(mj_table.table.size, self.static_tableB.size)

        # Test with a single input table
        mj_table = multi_join(self.static_tableA, ["a", "b"])

        # Output table is static
        self.assertFalse(mj_table.table.is_refreshing)
        # Output table has same # rows as sources
        self.assertEqual(mj_table.table.size, self.static_tableA.size)

    def test_ticking_simple(self):
        # Test with multiple input tables
        mj_table = multi_join(input=[self.ticking_tableA, self.ticking_tableB], on=["a", "b"])

        # Output table is refreshing
        self.assertTrue(mj_table.table.is_refreshing)

        # Test with a single input table
        mj_table = multi_join(input=self.ticking_tableA, on=["a", "b"])

        # Output table is refreshing
        self.assertTrue(mj_table.table.is_refreshing)

    def test_static(self):
        # Test with multiple input
        mj_input = [
            MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
            MultiJoinInput(table=self.static_tableB, on=["key1=a", "key2=b"], joins=["d2"])
        ]
        mj_table = multi_join(mj_input)

        # Output table is static
        self.assertFalse(mj_table.table.is_refreshing)
        # Output table has same # rows as sources
        self.assertEqual(mj_table.table.size, self.static_tableA.size)
        self.assertEqual(mj_table.table.size, self.static_tableB.size)

        # Test with a single input
        mj_table = multi_join(MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins="c1"))

        # Output table is static
        self.assertFalse(mj_table.table.is_refreshing)
        # Output table has same # rows as sources
        self.assertEqual(mj_table.table.size, self.static_tableA.size)

    def test_ticking(self):
        # Test with multiple input
        mj_input = [
            MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
            MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"])
        ]
        mj_table = multi_join(mj_input)

        # Output table is refreshing
        self.assertTrue(mj_table.table.is_refreshing)

        # Test with a single input
        mj_table = multi_join(input=MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins="c1"))

        # Output table is refreshing
        self.assertTrue(mj_table.table.is_refreshing)

    def test_errors(self):
        # Assert the exception is raised when providing MultiJoinInput and the on parameter is not None (omitted).
        mj_input = [
            MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
            MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"])
        ]
        with self.assertRaises(DHError) as cm:
            multi_join(mj_input, on=["key1=a", "key2=b"])
        self.assertIn("on parameter is not permitted", str(cm.exception))

        # A second, independent session is needed for the cross-session check; close it in `finally` so the
        # connection is released even when an assertion fails.
        session = Session()
        try:
            t = session.time_table("PT00:00:00.001").update(
                ["a = i", "b = i*i % 13", "c1 = i * 13 % 23", "d1 = a + b", "e1 = a - b"]).drop_columns(
                ["Timestamp"])

            # Assert the exception is raised when to-be-joined tables are not from the same session.
            mj_input = [
                MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
                MultiJoinInput(table=t, on=["key1=a", "key2=b"], joins=["d2"])
            ]
            with self.assertRaises(DHError) as cm:
                multi_join(mj_input)
            self.assertIn("all tables must be from the same session", str(cm.exception))
        finally:
            session.close()


# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
7 changes: 4 additions & 3 deletions py/server/deephaven/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3786,7 +3786,7 @@ def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str
table (Table): the right table to include in the join
on (Union[str, Sequence[str]]): the column(s) to match, can be a common name or an equal expression,
i.e. "col_a = col_b" for different column names
joins (Union[str, Sequence[str]], optional): the column(s) to be added from the this table to the result
joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result
table, can be renaming expressions, i.e. "new_col = col"; default is None
Raises:
Expand All @@ -3803,13 +3803,14 @@ def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str

class MultiJoinTable(JObjectWrapper):
"""A MultiJoinTable is an object that contains the result of a multi-table natural join. To retrieve the underlying
result Table, use the table() method. """
result Table, use the :attr:`.table` property. """
j_object_type = _JMultiJoinTable

@property
def j_object(self) -> jpy.JType:
return self.j_multijointable

@property
def table(self) -> Table:
"""Returns the Table containing the multi-table natural join output. """
return Table(j_table=self.j_multijointable.table())
Expand Down Expand Up @@ -3866,7 +3867,7 @@ def multi_join(input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[Mul
Returns:
MultiJoinTable: the result of the multi-table natural join operation. To access the underlying Table, use the
table() method.
:attr:`~MultiJoinTable.table` property.
"""
return MultiJoinTable(input, on)

Expand Down
Loading

0 comments on commit 057c0a6

Please sign in to comment.