From 89d46c61dd2d7179a56878d8ad544bd5b3c73ce4 Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Wed, 8 Mar 2023 14:33:40 -0700 Subject: [PATCH] Support the use of PY callable's attr in query --- .../table/impl/lang/QueryLanguageParser.java | 4 ++-- .../engine/util/PyCallableWrapper.java | 12 +++++++++++ py/server/tests/test_table.py | 21 ++++++++++++++++++- 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java index e519a5b54c3..d32fdc3131f 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java @@ -376,7 +376,7 @@ private Class[] printArguments(Expression[] arguments, VisitArgs printer) { printer.append('('); for (int i = 0; i < arguments.length; i++) { - types.add(arguments[i].accept(this, printer)); + types.add(arguments[i].accept(this, printer.cloneWithCastingContext(null))); if (i != arguments.length - 1) { printer.append(", "); @@ -1650,7 +1650,7 @@ public Class visit(FieldAccessExpr n, VisitArgs printer) { try { // For Python object, the type of the field is PyObject by default, the actual data type if // primitive will only be known at runtime - if (scopeType == PyObject.class) { + if (scopeType == PyObject.class || scopeType == PyCallableWrapper.class) { ret = PyObject.class; } else { ret = scopeType.getField(fieldName).getType(); diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java index 4b5cb079070..648822da562 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java @@ -3,6 +3,7 @@ import io.deephaven.engine.table.impl.select.python.ArgumentsChunked; import io.deephaven.internal.log.LoggerFactory; import io.deephaven.io.logger.Logger; +import org.jpy.PyLib; import org.jpy.PyModule; import org.jpy.PyObject; @@ -11,6 +12,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; + +import static org.jpy.PyLib.assertPythonRuns; /** * When given a pyObject that is a callable, we stick it inside the callable wrapper, which implements a call() varargs @@ -52,6 +56,14 @@ public PyCallableWrapper(PyObject pyCallable) { this.pyCallable = pyCallable; } + public PyObject getAttribute(String name) { + return this.pyCallable.getAttribute(name); + } + + public T getAttribute(String name, Class valueType) { + return this.pyCallable.getAttribute(name, valueType); + } + public ArgumentsChunked buildArgumentsChunked(List columnNames) { for (ChunkArgument arg : chunkArguments) { if (arg instanceof ColumnChunkArgument) { diff --git a/py/server/tests/test_table.py b/py/server/tests/test_table.py index 5300f6f908f..b04f2f086cf 100644 --- a/py/server/tests/test_table.py +++ b/py/server/tests/test_table.py @@ -5,9 +5,10 @@ from types import SimpleNamespace from typing import List, Any -from deephaven import DHError, read_csv, empty_table, SortDirection, AsOfMatchRule, time_table, ugp +from deephaven import DHError, read_csv, empty_table, SortDirection, AsOfMatchRule, time_table, ugp, new_table, dtypes from deephaven.agg import sum_, weighted_avg, avg, pct, group, count_, first, last, max_, median, min_, std, abs_sum, \ var, formula, partition +from deephaven.column import datetime_col from deephaven.execution_context import make_user_exec_ctx from deephaven.html import to_html from deephaven.jcompat import j_hashmap @@ -902,6 +903,24 @@ def make_pairs_3(tid, a, b): self.assertEqual(x2.size, 10) self.assertEqual(x3.size, 10) + def test_class_attrs_in_query(self): + input_cols = [ + datetime_col(name="DTCol", data=[dtypes.DateTime(1), dtypes.DateTime(10000000)]), + ] + test_table = new_table(cols=input_cols) + from deephaven.time import year, TimeZone + rt = test_table.update("Year = (int)year(DTCol, TimeZone.NY)") + self.assertEqual(rt.size, test_table.size) + + class Foo: + ATTR = 1 + + rt = empty_table(1).update("Col = Foo.ATTR") + self.assertTrue(rt.columns[0].data_type == dtypes.PyObject) + + rt = empty_table(1).update("Col = (int)Foo.ATTR") + self.assertTrue(rt.columns[0].data_type == dtypes.int32) + if __name__ == "__main__": unittest.main()