Skip to content

Commit

Permalink
Closes #3904: function to return list of all compiled dimensions avai…
Browse files Browse the repository at this point in the history
…lable
  • Loading branch information
ajpotts committed Nov 20, 2024
1 parent b456d3e commit 765d392
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 84 deletions.
47 changes: 46 additions & 1 deletion arkouda/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
username = security.get_username()
connected = False
serverConfig = None
registrationConfig = None
# verbose flag for arkouda module
verboseDefVal = False
verbose = verboseDefVal
Expand Down Expand Up @@ -706,7 +707,7 @@ def connect(
On success, prints the connected address, as seen by the server. If called
with an existing connection, the socket will be re-initialized.
"""
global connected, serverConfig, verbose, regexMaxCaptures, channel
global connected, serverConfig, verbose, regexMaxCaptures, channel, registrationConfig

# send the connect message
cmd = "connect"
Expand Down Expand Up @@ -741,6 +742,7 @@ def connect(
RuntimeWarning,
)
regexMaxCaptures = serverConfig["regexMaxCaptures"] # type: ignore
registrationConfig = _get_registration_config_msg()
clientLogger.info(return_message)


Expand Down Expand Up @@ -1063,6 +1065,29 @@ def get_max_array_rank() -> int:
return int(serverConfig["maxArrayDims"])


def get_array_ranks() -> list[int]:
"""
Get the list of pdarray ranks the server was compiled to support
This value corresponds to
parameter_classes -> array -> nd in the `registration-config.json`
file when the server was compiled.
Returns
-------
list of int
The pdarray ranks supported by the server
"""

if registrationConfig is None:
raise RuntimeError(
"There was a problem loading registrationConfig."
"Make sure the client is connected to a server."
)

return registrationConfig["parameter_classes"]["array"]["nd"]


def _get_config_msg() -> Mapping[str, Union[str, int, float]]:
"""
Get runtime information about the server.
Expand All @@ -1083,6 +1108,26 @@ def _get_config_msg() -> Mapping[str, Union[str, int, float]]:
raise RuntimeError(f"{e} in retrieving Arkouda server config")


def _get_registration_config_msg() -> dict:
"""
Get runtime information about the command registration configuration.
Raises
------
RuntimeError
Raised if there is a server-side error in getting memory used
ValueError
Raised if there's an error in parsing the JSON-formatted server config
"""
try:
raw_message = cast(str, generic_msg(cmd="getRegistrationConfig"))
return json.loads(raw_message)
except json.decoder.JSONDecodeError:
raise ValueError(f"Returned config is not valid JSON: {raw_message}")
except Exception as e:
raise RuntimeError(f"{e} in retrieving Arkouda server config")


def get_mem_used(unit: str = "b", as_percent: bool = False) -> int:
"""
Compute the amount of memory used by objects in the server's symbol table.
Expand Down
124 changes: 41 additions & 83 deletions src/registry/register_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,7 @@ def get_formal_type_spec(formal):
else:
# TODO: `x: []` and `x: [?d]` are currently treated as invalid formal type expressions
raise ValueError("invalid formal type expression")
elif isinstance(
te, chapel.FnCall
): # Composite types (list and borrowed-class)
elif isinstance(te, chapel.FnCall): # Composite types (list and borrowed-class)
if ce := te.called_expression():
if isinstance(ce, chapel.Identifier):
call_name = ce.name()
Expand All @@ -240,9 +238,7 @@ def get_formal_type_spec(formal):
f"unsupported formal type for registration {list_type}; list element type must be a scalar",
)
else:
return FormalTypeSpec(
formalKind.LIST, name, sk, info=list_type
)
return FormalTypeSpec(formalKind.LIST, name, sk, info=list_type)
elif call_name == "borrowed":
actuals = list(te.actuals())
if isinstance(actuals[0], chapel.FnCall):
Expand All @@ -251,9 +247,7 @@ def get_formal_type_spec(formal):
else:
# concrete class formal (e.g., 'borrowed MySymEntry')
class_name = list(te.actuals())[0].name()
return FormalTypeSpec(
formalKind.BORROWED_CLASS, name, sk, class_name
)
return FormalTypeSpec(formalKind.BORROWED_CLASS, name, sk, class_name)
else:
error_message(
f"registering '{fn.name()}'",
Expand Down Expand Up @@ -349,9 +343,7 @@ def clean_enum_name(name):
return name


def stamp_generic_command(
generic_proc_name, prefix, module_name, formals, line_num, iar_annotation
):
def stamp_generic_command(generic_proc_name, prefix, module_name, formals, line_num, iar_annotation):
"""
Create code to stamp out and register a generic command using a generic
procedure, and a set values for its generic formals.
Expand All @@ -367,22 +359,15 @@ def stamp_generic_command(
+ ",".join(
[
# if the generic formal is a 'type' convert it to its numpy dtype name
(
chapel_scalar_types[v]
if v in chapel_scalar_types
else clean_enum_name(str(v))
)
(chapel_scalar_types[v] if v in chapel_scalar_types else clean_enum_name(str(v)))
for _, v in formals.items()
]
)
+ ">"
)

stamp_name = f"ark_{clean_stamp_name(prefix)}_" + "_".join(
[
clean_enum_name(str(v)).replace("(", "").replace(")", "")
for _, v in formals.items()
]
[clean_enum_name(str(v)).replace("(", "").replace(")", "") for _, v in formals.items()]
)

stamp_formal_args = ", ".join([f"{k}={v}" for k, v in formals.items()])
Expand Down Expand Up @@ -434,9 +419,7 @@ def parse_param_class_value(value):
if isinstance(value, list):
for v in value:
if not isinstance(v, (int, float, str)):
raise ValueError(
f"Invalid parameter value type ({type(v)}) in list '{value}'"
)
raise ValueError(f"Invalid parameter value type ({type(v)}) in list '{value}'")
return value
elif isinstance(value, int):
return [
Expand All @@ -452,9 +435,7 @@ def parse_param_class_value(value):
if isinstance(vals, list):
return vals
else:
raise ValueError(
f"Could not create a list of parameter values from '{value}'"
)
raise ValueError(f"Could not create a list of parameter values from '{value}'")
elif isinstance(value, dict) and "__enum__" in value and "__variants__" in value:
enum_name = value["__enum__"].split(".")[-1]
return [f"{enum_name}.{var}" for var in value["__variants__"]]
Expand Down Expand Up @@ -566,17 +547,11 @@ def unpack_array_arg(arg_name, array_count, finfo, domain_queries, dtype_queries
nd_generic_formal_info = None
else:
nd_arg_name = "array_nd_" + str(array_count)
nd_generic_formal_info = FormalTypeSpec(
formalKind.SCALAR, nd_arg_name, "param", "int"
)
nd_generic_formal_info = FormalTypeSpec(formalKind.SCALAR, nd_arg_name, "param", "int")

# check if the array formal has a static type or a type-query
# if not, generate a unique name and formal info for the dtype argument
if (
finfo is not None
and finfo[1] is not None
and isinstance(finfo[1], StaticTypeInfo)
):
if finfo is not None and finfo[1] is not None and isinstance(finfo[1], StaticTypeInfo):
dtype_arg_name = finfo[1].value
dtype_generic_formal_info = None
elif (
Expand All @@ -601,9 +576,7 @@ def unpack_array_arg(arg_name, array_count, finfo, domain_queries, dtype_queries
)


def unpack_generic_symbol_arg(
arg_name, symbol_class_name, symbol_count, symbol_param_class
):
def unpack_generic_symbol_arg(arg_name, symbol_class_name, symbol_count, symbol_param_class):
"""
Generate the code to unpack a non-array symbol-table entry class (a class that
inherits from 'AbstractSymEntry').
Expand Down Expand Up @@ -632,13 +605,8 @@ def unpack_generic_symbol_arg(
type_str = None
else:
storage_kind = "param"
type_str = (
"int" # TODO: also support strings and other param-able types here
)
elif (
isinstance(symbol_param_class[k], dict)
and "__enum__" in symbol_param_class[k].keys()
):
type_str = "int" # TODO: also support strings and other param-able types here
elif isinstance(symbol_param_class[k], dict) and "__enum__" in symbol_param_class[k].keys():
storage_kind = "param"
type_str = symbol_param_class[k]["__enum__"].split(".")[-1]
else:
Expand Down Expand Up @@ -770,9 +738,7 @@ def gen_arg_unpacking(formals, config):

for formal_spec in formals:
if formal_spec.is_chapel_scalar_type():
unpack_lines.append(
unpack_scalar_arg(formal_spec.name, formal_spec.type_str)
)
unpack_lines.append(unpack_scalar_arg(formal_spec.name, formal_spec.type_str))
elif formal_spec.kind == formalKind.ARRAY:
# finfo[0] is the domain query info, finfo[1] is the dtype query info
finfo = formal_spec.info
Expand All @@ -796,11 +762,7 @@ def gen_arg_unpacking(formals, config):
# this allows homogeneous-tuple formal types to use the array's rank as a size argument
# Do the same for dtype queries
if finfo is not None:
if (
finfo[0] is not None
and isinstance(finfo[0], FormalQuery)
and gen_nd_arg is not None
):
if finfo[0] is not None and isinstance(finfo[0], FormalQuery) and gen_nd_arg is not None:
array_domain_queries[finfo[0].name] = gen_nd_arg.name
if (
finfo[1] is not None
Expand Down Expand Up @@ -854,9 +816,7 @@ def gen_arg_unpacking(formals, config):
# a scalar formal with a generic type
if formal_spec.type_str is not None:
if queried_type := array_dtype_queries[formal_spec.type_str]:
unpack_lines.append(
unpack_scalar_arg(formal_spec.name, queried_type)
)
unpack_lines.append(unpack_scalar_arg(formal_spec.name, queried_type))
else:
# TODO: fully handle generic user-defined types
code, scalar_args = unpack_scalar_arg_with_generic(
Expand Down Expand Up @@ -959,19 +919,13 @@ def gen_command_proc(name, return_type, formals, mod_name, config):
arg_unpack, command_formals, query_table = gen_arg_unpacking(formals, config)
is_generic_command = len(command_formals) > 0
signature, cmd_name = gen_signature(name, command_formals)
fn_call, result_name = gen_user_function_call(
name, [f.name for f in formals], mod_name, return_type
)
fn_call, result_name = gen_user_function_call(name, [f.name for f in formals], mod_name, return_type)

# get the names of the array-elt-type queries in the formals
array_etype_queries = [
f.info[1].name
for f in formals
if (
f.kind == formalKind.ARRAY
and len(f.info) > 0
and isinstance(f.info[1], FormalQuery)
)
if (f.kind == formalKind.ARRAY and len(f.info) > 0 and isinstance(f.info[1], FormalQuery))
]

def return_type_fn_name():
Expand All @@ -993,27 +947,25 @@ def return_type_fn_name():
or (
# TODO: do resolution to ensure that this is a class type that inherits from 'AbstractSymEntry'
return_type_fn_name() is not None
and return_type_fn_name() in ["SymEntry",] + list(config["parameter_classes"].keys())
and return_type_fn_name()
in [
"SymEntry",
]
+ list(config["parameter_classes"].keys())
)
)
returns_array = (
return_type
and isinstance(return_type, chapel.BracketLoop)
and return_type.is_maybe_array_type()
return_type and isinstance(return_type, chapel.BracketLoop) and return_type.is_maybe_array_type()
)

if returns_array:
symbol_creation, result_name = gen_symbol_creation(
ARRAY_ENTRY_CLASS_NAME, result_name
)
symbol_creation, result_name = gen_symbol_creation(ARRAY_ENTRY_CLASS_NAME, result_name)
else:
symbol_creation = ""

response = gen_response(result_name, returns_symbol or returns_array)

command_proc = "\n".join(
[signature, arg_unpack, fn_call, symbol_creation, response, "}"]
)
command_proc = "\n".join([signature, arg_unpack, fn_call, symbol_creation, response, "}"])

return (command_proc, cmd_name, is_generic_command, command_formals, query_table)

Expand Down Expand Up @@ -1236,9 +1188,7 @@ def stamp_out_command(
if wcn := wc_node:
if not wcn.eval(fp, query_table):
continue
stamp = stamp_generic_command(
name, cmd_prefix, mod_name, fp, line_num, iar_annotation
)
stamp = stamp_generic_command(name, cmd_prefix, mod_name, fp, line_num, iar_annotation)
yield stamp


Expand All @@ -1248,9 +1198,7 @@ def extract_enum_imports(config):
if isinstance(config[k], dict):
if "__enum__" in config[k].keys():
if "__variants__" not in config[k].keys():
raise ValueError(
f"enum '{k}' is missing '__variants__' field in configuration file"
)
raise ValueError(f"enum '{k}' is missing '__variants__' field in configuration file")
imports.append(f"import {config[k]['__enum__']};")
else:
imports += extract_enum_imports(config[k])
Expand All @@ -1264,13 +1212,23 @@ def register_commands(config, source_files):
"""
stamps = [
"module Commands {",
"use CommandMap, Message, MultiTypeSymbolTable, MultiTypeSymEntry;",
"use CommandMap, IOUtils, Message, MultiTypeSymbolTable, MultiTypeSymEntry;",
"use BigInteger;",
watermarkConfig(config),
]

stamps += extract_enum_imports(config)

stamps.append("""proc getRegistrationConfig(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
return new MsgTuple(getRegConfig(), MsgType.NORMAL);
}""")


stamps.append(
"proc getRegConfig(): string throws do return try! regConfig;"
"\nregisterFunction('getRegistrationConfig', getRegistrationConfig, 'Commands', 68);"
)

count = 0

for filename, ctx in chapel.files_with_contexts(source_files):
Expand Down Expand Up @@ -1313,8 +1271,8 @@ def register_commands(config, source_files):
)
continue

(cmd_proc, cmd_name, is_generic_cmd, cmd_gen_formals, query_table) = (
gen_command_proc(name, fn.return_type(), con_formals, mod_name, config)
(cmd_proc, cmd_name, is_generic_cmd, cmd_gen_formals, query_table) = gen_command_proc(
name, fn.return_type(), con_formals, mod_name, config
)

file_stamps.append(cmd_proc)
Expand Down
8 changes: 8 additions & 0 deletions tests/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,11 @@ def test_client_get_server_commands(self):
cmds = ak.client.get_server_commands()
for cmd in ["connect", "info", "str"]:
assert cmd in cmds

def test_get_array_ranks(self):
availableRanks = ak.client.get_array_ranks()
assert isinstance(availableRanks, list)
assert len(availableRanks) >= 1
assert 1 in availableRanks
assert ak.client.get_max_array_rank() in availableRanks
assert ak.client.get_max_array_rank() + 1 not in availableRanks

0 comments on commit 765d392

Please sign in to comment.