Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closes #3904: function to return list of all compiled dimensions avai… #3909

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion arkouda/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
username = security.get_username()
connected = False
serverConfig = None
registrationConfig = None
# verbose flag for arkouda module
verboseDefVal = False
verbose = verboseDefVal
Expand Down Expand Up @@ -706,7 +707,7 @@ def connect(
On success, prints the connected address, as seen by the server. If called
with an existing connection, the socket will be re-initialized.
"""
global connected, serverConfig, verbose, regexMaxCaptures, channel
global connected, serverConfig, verbose, regexMaxCaptures, channel, registrationConfig

# send the connect message
cmd = "connect"
Expand Down Expand Up @@ -741,6 +742,7 @@ def connect(
RuntimeWarning,
)
regexMaxCaptures = serverConfig["regexMaxCaptures"] # type: ignore
registrationConfig = _get_registration_config_msg()
clientLogger.info(return_message)


Expand Down Expand Up @@ -1063,6 +1065,29 @@ def get_max_array_rank() -> int:
return int(serverConfig["maxArrayDims"])


def get_array_ranks() -> list[int]:
"""
Get the list of pdarray ranks the server was compiled to support

This value corresponds to
parameter_classes -> array -> nd in the `registration-config.json`
file when the server was compiled.

Returns
-------
list of int
The pdarray ranks supported by the server
"""

if registrationConfig is None:
raise RuntimeError(
"There was a problem loading registrationConfig."
"Make sure the client is connected to a server."
)

return registrationConfig["parameter_classes"]["array"]["nd"]


def _get_config_msg() -> Mapping[str, Union[str, int, float]]:
"""
Get runtime information about the server.
Expand All @@ -1083,6 +1108,26 @@ def _get_config_msg() -> Mapping[str, Union[str, int, float]]:
raise RuntimeError(f"{e} in retrieving Arkouda server config")


def _get_registration_config_msg() -> dict:
"""
Get runtime information about the command registration configuration.

Raises
------
RuntimeError
Raised if there is a server-side error in getting memory used
ValueError
Raised if there's an error in parsing the JSON-formatted server config
"""
try:
raw_message = cast(str, generic_msg(cmd="getRegistrationConfig"))
return json.loads(raw_message)
except json.decoder.JSONDecodeError:
raise ValueError(f"Returned config is not valid JSON: {raw_message}")
except Exception as e:
raise RuntimeError(f"{e} in retrieving Arkouda server config")


def get_mem_used(unit: str = "b", as_percent: bool = False) -> int:
"""
Compute the amount of memory used by objects in the server's symbol table.
Expand Down
124 changes: 41 additions & 83 deletions src/registry/register_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,7 @@ def get_formal_type_spec(formal):
else:
# TODO: `x: []` and `x: [?d]` are currently treated as invalid formal type expressions
raise ValueError("invalid formal type expression")
elif isinstance(
te, chapel.FnCall
): # Composite types (list and borrowed-class)
elif isinstance(te, chapel.FnCall): # Composite types (list and borrowed-class)
if ce := te.called_expression():
if isinstance(ce, chapel.Identifier):
call_name = ce.name()
Expand All @@ -240,9 +238,7 @@ def get_formal_type_spec(formal):
f"unsupported formal type for registration {list_type}; list element type must be a scalar",
)
else:
return FormalTypeSpec(
formalKind.LIST, name, sk, info=list_type
)
return FormalTypeSpec(formalKind.LIST, name, sk, info=list_type)
elif call_name == "borrowed":
actuals = list(te.actuals())
if isinstance(actuals[0], chapel.FnCall):
Expand All @@ -251,9 +247,7 @@ def get_formal_type_spec(formal):
else:
# concrete class formal (e.g., 'borrowed MySymEntry')
class_name = list(te.actuals())[0].name()
return FormalTypeSpec(
formalKind.BORROWED_CLASS, name, sk, class_name
)
return FormalTypeSpec(formalKind.BORROWED_CLASS, name, sk, class_name)
else:
error_message(
f"registering '{fn.name()}'",
Expand Down Expand Up @@ -349,9 +343,7 @@ def clean_enum_name(name):
return name


def stamp_generic_command(
generic_proc_name, prefix, module_name, formals, line_num, iar_annotation
):
def stamp_generic_command(generic_proc_name, prefix, module_name, formals, line_num, iar_annotation):
"""
Create code to stamp out and register a generic command using a generic
procedure, and a set values for its generic formals.
Expand All @@ -367,22 +359,15 @@ def stamp_generic_command(
+ ",".join(
[
# if the generic formal is a 'type' convert it to its numpy dtype name
(
chapel_scalar_types[v]
if v in chapel_scalar_types
else clean_enum_name(str(v))
)
(chapel_scalar_types[v] if v in chapel_scalar_types else clean_enum_name(str(v)))
for _, v in formals.items()
]
)
+ ">"
)

stamp_name = f"ark_{clean_stamp_name(prefix)}_" + "_".join(
[
clean_enum_name(str(v)).replace("(", "").replace(")", "")
for _, v in formals.items()
]
[clean_enum_name(str(v)).replace("(", "").replace(")", "") for _, v in formals.items()]
)

stamp_formal_args = ", ".join([f"{k}={v}" for k, v in formals.items()])
Expand Down Expand Up @@ -434,9 +419,7 @@ def parse_param_class_value(value):
if isinstance(value, list):
for v in value:
if not isinstance(v, (int, float, str)):
raise ValueError(
f"Invalid parameter value type ({type(v)}) in list '{value}'"
)
raise ValueError(f"Invalid parameter value type ({type(v)}) in list '{value}'")
return value
elif isinstance(value, int):
return [
Expand All @@ -452,9 +435,7 @@ def parse_param_class_value(value):
if isinstance(vals, list):
return vals
else:
raise ValueError(
f"Could not create a list of parameter values from '{value}'"
)
raise ValueError(f"Could not create a list of parameter values from '{value}'")
elif isinstance(value, dict) and "__enum__" in value and "__variants__" in value:
enum_name = value["__enum__"].split(".")[-1]
return [f"{enum_name}.{var}" for var in value["__variants__"]]
Expand Down Expand Up @@ -566,17 +547,11 @@ def unpack_array_arg(arg_name, array_count, finfo, domain_queries, dtype_queries
nd_generic_formal_info = None
else:
nd_arg_name = "array_nd_" + str(array_count)
nd_generic_formal_info = FormalTypeSpec(
formalKind.SCALAR, nd_arg_name, "param", "int"
)
nd_generic_formal_info = FormalTypeSpec(formalKind.SCALAR, nd_arg_name, "param", "int")

# check if the array formal has a static type or a type-query
# if not, generate a unique name and formal info for the dtype argument
if (
finfo is not None
and finfo[1] is not None
and isinstance(finfo[1], StaticTypeInfo)
):
if finfo is not None and finfo[1] is not None and isinstance(finfo[1], StaticTypeInfo):
dtype_arg_name = finfo[1].value
dtype_generic_formal_info = None
elif (
Expand All @@ -601,9 +576,7 @@ def unpack_array_arg(arg_name, array_count, finfo, domain_queries, dtype_queries
)


def unpack_generic_symbol_arg(
arg_name, symbol_class_name, symbol_count, symbol_param_class
):
def unpack_generic_symbol_arg(arg_name, symbol_class_name, symbol_count, symbol_param_class):
"""
Generate the code to unpack a non-array symbol-table entry class (a class that
inherits from 'AbstractSymEntry').
Expand Down Expand Up @@ -632,13 +605,8 @@ def unpack_generic_symbol_arg(
type_str = None
else:
storage_kind = "param"
type_str = (
"int" # TODO: also support strings and other param-able types here
)
elif (
isinstance(symbol_param_class[k], dict)
and "__enum__" in symbol_param_class[k].keys()
):
type_str = "int" # TODO: also support strings and other param-able types here
elif isinstance(symbol_param_class[k], dict) and "__enum__" in symbol_param_class[k].keys():
storage_kind = "param"
type_str = symbol_param_class[k]["__enum__"].split(".")[-1]
else:
Expand Down Expand Up @@ -770,9 +738,7 @@ def gen_arg_unpacking(formals, config):

for formal_spec in formals:
if formal_spec.is_chapel_scalar_type():
unpack_lines.append(
unpack_scalar_arg(formal_spec.name, formal_spec.type_str)
)
unpack_lines.append(unpack_scalar_arg(formal_spec.name, formal_spec.type_str))
elif formal_spec.kind == formalKind.ARRAY:
# finfo[0] is the domain query info, finfo[1] is the dtype query info
finfo = formal_spec.info
Expand All @@ -796,11 +762,7 @@ def gen_arg_unpacking(formals, config):
# this allows homogeneous-tuple formal types to use the array's rank as a size argument
# Do the same for dtype queries
if finfo is not None:
if (
finfo[0] is not None
and isinstance(finfo[0], FormalQuery)
and gen_nd_arg is not None
):
if finfo[0] is not None and isinstance(finfo[0], FormalQuery) and gen_nd_arg is not None:
array_domain_queries[finfo[0].name] = gen_nd_arg.name
if (
finfo[1] is not None
Expand Down Expand Up @@ -854,9 +816,7 @@ def gen_arg_unpacking(formals, config):
# a scalar formal with a generic type
if formal_spec.type_str is not None:
if queried_type := array_dtype_queries[formal_spec.type_str]:
unpack_lines.append(
unpack_scalar_arg(formal_spec.name, queried_type)
)
unpack_lines.append(unpack_scalar_arg(formal_spec.name, queried_type))
else:
# TODO: fully handle generic user-defined types
code, scalar_args = unpack_scalar_arg_with_generic(
Expand Down Expand Up @@ -959,19 +919,13 @@ def gen_command_proc(name, return_type, formals, mod_name, config):
arg_unpack, command_formals, query_table = gen_arg_unpacking(formals, config)
is_generic_command = len(command_formals) > 0
signature, cmd_name = gen_signature(name, command_formals)
fn_call, result_name = gen_user_function_call(
name, [f.name for f in formals], mod_name, return_type
)
fn_call, result_name = gen_user_function_call(name, [f.name for f in formals], mod_name, return_type)

# get the names of the array-elt-type queries in the formals
array_etype_queries = [
f.info[1].name
for f in formals
if (
f.kind == formalKind.ARRAY
and len(f.info) > 0
and isinstance(f.info[1], FormalQuery)
)
if (f.kind == formalKind.ARRAY and len(f.info) > 0 and isinstance(f.info[1], FormalQuery))
]

def return_type_fn_name():
Expand All @@ -993,27 +947,25 @@ def return_type_fn_name():
or (
# TODO: do resolution to ensure that this is a class type that inherits from 'AbstractSymEntry'
return_type_fn_name() is not None
and return_type_fn_name() in ["SymEntry",] + list(config["parameter_classes"].keys())
and return_type_fn_name()
in [
"SymEntry",
]
+ list(config["parameter_classes"].keys())
)
)
returns_array = (
return_type
and isinstance(return_type, chapel.BracketLoop)
and return_type.is_maybe_array_type()
return_type and isinstance(return_type, chapel.BracketLoop) and return_type.is_maybe_array_type()
)

if returns_array:
symbol_creation, result_name = gen_symbol_creation(
ARRAY_ENTRY_CLASS_NAME, result_name
)
symbol_creation, result_name = gen_symbol_creation(ARRAY_ENTRY_CLASS_NAME, result_name)
else:
symbol_creation = ""

response = gen_response(result_name, returns_symbol or returns_array)

command_proc = "\n".join(
[signature, arg_unpack, fn_call, symbol_creation, response, "}"]
)
command_proc = "\n".join([signature, arg_unpack, fn_call, symbol_creation, response, "}"])

return (command_proc, cmd_name, is_generic_command, command_formals, query_table)

Expand Down Expand Up @@ -1236,9 +1188,7 @@ def stamp_out_command(
if wcn := wc_node:
if not wcn.eval(fp, query_table):
continue
stamp = stamp_generic_command(
name, cmd_prefix, mod_name, fp, line_num, iar_annotation
)
stamp = stamp_generic_command(name, cmd_prefix, mod_name, fp, line_num, iar_annotation)
yield stamp


Expand All @@ -1248,9 +1198,7 @@ def extract_enum_imports(config):
if isinstance(config[k], dict):
if "__enum__" in config[k].keys():
if "__variants__" not in config[k].keys():
raise ValueError(
f"enum '{k}' is missing '__variants__' field in configuration file"
)
raise ValueError(f"enum '{k}' is missing '__variants__' field in configuration file")
imports.append(f"import {config[k]['__enum__']};")
else:
imports += extract_enum_imports(config[k])
Expand All @@ -1264,13 +1212,23 @@ def register_commands(config, source_files):
"""
stamps = [
"module Commands {",
"use CommandMap, Message, MultiTypeSymbolTable, MultiTypeSymEntry;",
"use CommandMap, IOUtils, Message, MultiTypeSymbolTable, MultiTypeSymEntry;",
"use BigInteger;",
watermarkConfig(config),
]

stamps += extract_enum_imports(config)

stamps.append("""proc getRegistrationConfig(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
return new MsgTuple(getRegConfig(), MsgType.NORMAL);
}""")


stamps.append(
"proc getRegConfig(): string throws do return try! regConfig;"
"\nregisterFunction('getRegistrationConfig', getRegistrationConfig, 'Commands', 68);"
)

count = 0

for filename, ctx in chapel.files_with_contexts(source_files):
Expand Down Expand Up @@ -1313,8 +1271,8 @@ def register_commands(config, source_files):
)
continue

(cmd_proc, cmd_name, is_generic_cmd, cmd_gen_formals, query_table) = (
gen_command_proc(name, fn.return_type(), con_formals, mod_name, config)
(cmd_proc, cmd_name, is_generic_cmd, cmd_gen_formals, query_table) = gen_command_proc(
name, fn.return_type(), con_formals, mod_name, config
)

file_stamps.append(cmd_proc)
Expand Down
8 changes: 8 additions & 0 deletions tests/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,11 @@ def test_client_get_server_commands(self):
cmds = ak.client.get_server_commands()
for cmd in ["connect", "info", "str"]:
assert cmd in cmds

def test_get_array_ranks(self):
availableRanks = ak.client.get_array_ranks()
assert isinstance(availableRanks, list)
assert len(availableRanks) >= 1
assert 1 in availableRanks
assert ak.client.get_max_array_rank() in availableRanks
assert ak.client.get_max_array_rank() + 1 not in availableRanks
Loading