Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support the use of Python callable's attributes in query strings #3509

Merged
merged 5 commits into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ private Class<?>[] printArguments(Expression[] arguments, VisitArgs printer) {

printer.append('(');
for (int i = 0; i < arguments.length; i++) {
types.add(arguments[i].accept(this, printer));
types.add(arguments[i].accept(this, printer.cloneWithCastingContext(null)));

if (i != arguments.length - 1) {
printer.append(", ");
Expand Down Expand Up @@ -483,7 +483,7 @@ private Method getMethod(final Class<?> scope, final String methodName, final Cl
}
}
} else {
if (scope == org.jpy.PyObject.class) {
if (scope == org.jpy.PyObject.class || scope == PyCallableWrapper.class) {
// This is a Python method call, assume it exists and wrap in PythonScopeJpyImpl.CallableWrapper
for (Method method : PyCallableWrapper.class.getDeclaredMethods()) {
possiblyAddExecutable(acceptableMethods, method, "call", paramTypes, parameterizedTypes);
Expand Down Expand Up @@ -1650,7 +1650,7 @@ public Class<?> visit(FieldAccessExpr n, VisitArgs printer) {
try {
// For Python object, the type of the field is PyObject by default, the actual data type if
// primitive will only be known at runtime
if (scopeType == PyObject.class) {
if (scopeType == PyObject.class || scopeType == PyCallableWrapper.class) {
ret = PyObject.class;
} else {
ret = scopeType.getField(fieldName).getType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.deephaven.engine.table.impl.select.python.ArgumentsChunked;
import io.deephaven.internal.log.LoggerFactory;
import io.deephaven.io.logger.Logger;
import org.jpy.PyLib;
import org.jpy.PyModule;
import org.jpy.PyObject;

Expand All @@ -11,6 +12,9 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.jpy.PyLib.assertPythonRuns;

/**
* When given a pyObject that is a callable, we stick it inside the callable wrapper, which implements a call() varargs
Expand Down Expand Up @@ -52,6 +56,14 @@ public PyCallableWrapper(PyObject pyCallable) {
this.pyCallable = pyCallable;
}

public PyObject getAttribute(String name) {
return this.pyCallable.getAttribute(name);
}

public <T> T getAttribute(String name, Class<? extends T> valueType) {
return this.pyCallable.getAttribute(name, valueType);
}

public ArgumentsChunked buildArgumentsChunked(List<String> columnNames) {
for (ChunkArgument arg : chunkArguments) {
if (arg instanceof ColumnChunkArgument) {
Expand Down
39 changes: 38 additions & 1 deletion py/server/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from types import SimpleNamespace
from typing import List, Any

from deephaven import DHError, read_csv, empty_table, SortDirection, AsOfMatchRule, time_table, ugp
from deephaven import DHError, read_csv, empty_table, SortDirection, AsOfMatchRule, time_table, ugp, new_table, dtypes
from deephaven.agg import sum_, weighted_avg, avg, pct, group, count_, first, last, max_, median, min_, std, abs_sum, \
var, formula, partition
from deephaven.column import datetime_col
from deephaven.execution_context import make_user_exec_ctx
from deephaven.html import to_html
from deephaven.jcompat import j_hashmap
Expand Down Expand Up @@ -902,6 +903,42 @@ def make_pairs_3(tid, a, b):
self.assertEqual(x2.size, 10)
self.assertEqual(x3.size, 10)

def test_callable_attrs_in_query(self):
input_cols = [
datetime_col(name="DTCol", data=[dtypes.DateTime(1), dtypes.DateTime(10000000)]),
]
test_table = new_table(cols=input_cols)
from deephaven.time import year, TimeZone
rt = test_table.update("Year = (int)year(DTCol, TimeZone.NY)")
jmao-denver marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(rt.size, test_table.size)

class Foo:
jmao-denver marked this conversation as resolved.
Show resolved Hide resolved
ATTR = 256

def __call__(self):
...

def do_something(self, p=None):
return p if p else 1

def do_something(p=None):
return p if p else 1

rt = empty_table(1).update("Col = Foo.ATTR")
self.assertTrue(rt.columns[0].data_type == dtypes.PyObject)

rt = empty_table(1).update("Col = (int)Foo.ATTR")
self.assertTrue(rt.columns[0].data_type == dtypes.int32)

foo = Foo()
rt = empty_table(1).update("Col = (int)foo.do_something()")
self.assertTrue(rt.columns[0].data_type == dtypes.int32)

rt = empty_table(1).update("Col = (int)do_something((byte)Foo.ATTR)")
df = to_pandas(rt)
self.assertEqual(df.loc[0]['Col'], 1)
self.assertTrue(rt.columns[0].data_type == dtypes.int32)


if __name__ == "__main__":
unittest.main()