From 1669077b2f49b6572018cd7b7d4a8033e629b3da Mon Sep 17 00:00:00 2001 From: Amanda Potts Date: Wed, 20 Nov 2024 15:58:50 -0500 Subject: [PATCH] Closes #3904: function to return list of all compiled dimensions available --- arkouda/client.py | 47 ++++++++++- src/registry/register_commands.py | 124 ++++++++++-------------------- tests/client_test.py | 8 ++ 3 files changed, 95 insertions(+), 84 deletions(-) diff --git a/arkouda/client.py b/arkouda/client.py index d9fd062cf1..7818f75954 100644 --- a/arkouda/client.py +++ b/arkouda/client.py @@ -32,6 +32,7 @@ username = security.get_username() connected = False serverConfig = None +registrationConfig = None # verbose flag for arkouda module verboseDefVal = False verbose = verboseDefVal @@ -706,7 +707,7 @@ def connect( On success, prints the connected address, as seen by the server. If called with an existing connection, the socket will be re-initialized. """ - global connected, serverConfig, verbose, regexMaxCaptures, channel + global connected, serverConfig, verbose, regexMaxCaptures, channel, registrationConfig # send the connect message cmd = "connect" @@ -741,6 +742,7 @@ def connect( RuntimeWarning, ) regexMaxCaptures = serverConfig["regexMaxCaptures"] # type: ignore + registrationConfig = _get_registration_config_msg() clientLogger.info(return_message) @@ -1063,6 +1065,29 @@ def get_max_array_rank() -> int: return int(serverConfig["maxArrayDims"]) +def get_array_ranks() -> list[int]: + """ + Get the list of pdarray ranks the server was compiled to support + + This value corresponds to + parameter_classes -> array -> nd in the `registration-config.json` + file when the server was compiled. + + Returns + ------- + list of int + The pdarray ranks supported by the server + """ + + if registrationConfig is None: + raise RuntimeError( + "There was a problem loading registrationConfig." + "Make sure the client is connected to a server." + ) + + return registrationConfig["parameter_classes"]["array"]["nd"] + + def _get_config_msg() -> Mapping[str, Union[str, int, float]]: """ Get runtime information about the server. @@ -1083,6 +1108,26 @@ def _get_config_msg() -> Mapping[str, Union[str, int, float]]: raise RuntimeError(f"{e} in retrieving Arkouda server config") +def _get_registration_config_msg() -> dict: + """ + Get runtime information about the command registration configuration. + + Raises + ------ + RuntimeError + Raised if there is a server-side error in getting memory used + ValueError + Raised if there's an error in parsing the JSON-formatted server config + """ + try: + raw_message = cast(str, generic_msg(cmd="getRegistrationConfig")) + return json.loads(raw_message) + except json.decoder.JSONDecodeError: + raise ValueError(f"Returned config is not valid JSON: {raw_message}") + except Exception as e: + raise RuntimeError(f"{e} in retrieving Arkouda server config") + + def get_mem_used(unit: str = "b", as_percent: bool = False) -> int: """ Compute the amount of memory used by objects in the server's symbol table. diff --git a/src/registry/register_commands.py b/src/registry/register_commands.py index 6bbd819344..98b28df6d6 100644 --- a/src/registry/register_commands.py +++ b/src/registry/register_commands.py @@ -226,9 +226,7 @@ def get_formal_type_spec(formal): else: # TODO: `x: []` and `x: [?d]` are currently treated as invalid formal type expressions raise ValueError("invalid formal type expression") - elif isinstance( - te, chapel.FnCall - ): # Composite types (list and borrowed-class) + elif isinstance(te, chapel.FnCall): # Composite types (list and borrowed-class) if ce := te.called_expression(): if isinstance(ce, chapel.Identifier): call_name = ce.name() @@ -240,9 +238,7 @@ def get_formal_type_spec(formal): f"unsupported formal type for registration {list_type}; list element type must be a scalar", ) else: - return FormalTypeSpec( - formalKind.LIST, name, sk, info=list_type - ) + return FormalTypeSpec(formalKind.LIST, name, sk, info=list_type) elif call_name == "borrowed": actuals = list(te.actuals()) if isinstance(actuals[0], chapel.FnCall): @@ -251,9 +247,7 @@ def get_formal_type_spec(formal): else: # concrete class formal (e.g., 'borrowed MySymEntry') class_name = list(te.actuals())[0].name() - return FormalTypeSpec( - formalKind.BORROWED_CLASS, name, sk, class_name - ) + return FormalTypeSpec(formalKind.BORROWED_CLASS, name, sk, class_name) else: error_message( f"registering '{fn.name()}'", @@ -349,9 +343,7 @@ def clean_enum_name(name): return name -def stamp_generic_command( - generic_proc_name, prefix, module_name, formals, line_num, iar_annotation -): +def stamp_generic_command(generic_proc_name, prefix, module_name, formals, line_num, iar_annotation): """ Create code to stamp out and register a generic command using a generic procedure, and a set values for its generic formals. @@ -367,11 +359,7 @@ def stamp_generic_command( + ",".join( [ # if the generic formal is a 'type' convert it to its numpy dtype name - ( - chapel_scalar_types[v] - if v in chapel_scalar_types - else clean_enum_name(str(v)) - ) + (chapel_scalar_types[v] if v in chapel_scalar_types else clean_enum_name(str(v))) for _, v in formals.items() ] ) @@ -379,10 +367,7 @@ def stamp_generic_command( ) stamp_name = f"ark_{clean_stamp_name(prefix)}_" + "_".join( - [ - clean_enum_name(str(v)).replace("(", "").replace(")", "") - for _, v in formals.items() - ] + [clean_enum_name(str(v)).replace("(", "").replace(")", "") for _, v in formals.items()] ) stamp_formal_args = ", ".join([f"{k}={v}" for k, v in formals.items()]) @@ -434,9 +419,7 @@ def parse_param_class_value(value): if isinstance(value, list): for v in value: if not isinstance(v, (int, float, str)): - raise ValueError( - f"Invalid parameter value type ({type(v)}) in list '{value}'" - ) + raise ValueError(f"Invalid parameter value type ({type(v)}) in list '{value}'") return value elif isinstance(value, int): return [ @@ -452,9 +435,7 @@ def parse_param_class_value(value): if isinstance(vals, list): return vals else: - raise ValueError( - f"Could not create a list of parameter values from '{value}'" - ) + raise ValueError(f"Could not create a list of parameter values from '{value}'") elif isinstance(value, dict) and "__enum__" in value and "__variants__" in value: enum_name = value["__enum__"].split(".")[-1] return [f"{enum_name}.{var}" for var in value["__variants__"]] @@ -566,17 +547,11 @@ def unpack_array_arg(arg_name, array_count, finfo, domain_queries, dtype_queries nd_generic_formal_info = None else: nd_arg_name = "array_nd_" + str(array_count) - nd_generic_formal_info = FormalTypeSpec( - formalKind.SCALAR, nd_arg_name, "param", "int" - ) + nd_generic_formal_info = FormalTypeSpec(formalKind.SCALAR, nd_arg_name, "param", "int") # check if the array formal has a static type or a type-query # if not, generate a unique name and formal info for the dtype argument - if ( - finfo is not None - and finfo[1] is not None - and isinstance(finfo[1], StaticTypeInfo) - ): + if finfo is not None and finfo[1] is not None and isinstance(finfo[1], StaticTypeInfo): dtype_arg_name = finfo[1].value dtype_generic_formal_info = None elif ( @@ -601,9 +576,7 @@ def unpack_array_arg(arg_name, array_count, finfo, domain_queries, dtype_queries ) -def unpack_generic_symbol_arg( - arg_name, symbol_class_name, symbol_count, symbol_param_class -): +def unpack_generic_symbol_arg(arg_name, symbol_class_name, symbol_count, symbol_param_class): """ Generate the code to unpack a non-array symbol-table entry class (a class that inherits from 'AbstractSymEntry'). @@ -632,13 +605,8 @@ def unpack_generic_symbol_arg( type_str = None else: storage_kind = "param" - type_str = ( - "int" # TODO: also support strings and other param-able types here - ) - elif ( - isinstance(symbol_param_class[k], dict) - and "__enum__" in symbol_param_class[k].keys() - ): + type_str = "int" # TODO: also support strings and other param-able types here + elif isinstance(symbol_param_class[k], dict) and "__enum__" in symbol_param_class[k].keys(): storage_kind = "param" type_str = symbol_param_class[k]["__enum__"].split(".")[-1] else: @@ -770,9 +738,7 @@ def gen_arg_unpacking(formals, config): for formal_spec in formals: if formal_spec.is_chapel_scalar_type(): - unpack_lines.append( - unpack_scalar_arg(formal_spec.name, formal_spec.type_str) - ) + unpack_lines.append(unpack_scalar_arg(formal_spec.name, formal_spec.type_str)) elif formal_spec.kind == formalKind.ARRAY: # finfo[0] is the domain query info, finfo[1] is the dtype query info finfo = formal_spec.info @@ -796,11 +762,7 @@ def gen_arg_unpacking(formals, config): # this allows homogeneous-tuple formal types to use the array's rank as a size argument # Do the same for dtype queries if finfo is not None: - if ( - finfo[0] is not None - and isinstance(finfo[0], FormalQuery) - and gen_nd_arg is not None - ): + if finfo[0] is not None and isinstance(finfo[0], FormalQuery) and gen_nd_arg is not None: array_domain_queries[finfo[0].name] = gen_nd_arg.name if ( finfo[1] is not None @@ -854,9 +816,7 @@ def gen_arg_unpacking(formals, config): # a scalar formal with a generic type if formal_spec.type_str is not None: if queried_type := array_dtype_queries[formal_spec.type_str]: - unpack_lines.append( - unpack_scalar_arg(formal_spec.name, queried_type) - ) + unpack_lines.append(unpack_scalar_arg(formal_spec.name, queried_type)) else: # TODO: fully handle generic user-defined types code, scalar_args = unpack_scalar_arg_with_generic( @@ -959,19 +919,13 @@ def gen_command_proc(name, return_type, formals, mod_name, config): arg_unpack, command_formals, query_table = gen_arg_unpacking(formals, config) is_generic_command = len(command_formals) > 0 signature, cmd_name = gen_signature(name, command_formals) - fn_call, result_name = gen_user_function_call( - name, [f.name for f in formals], mod_name, return_type - ) + fn_call, result_name = gen_user_function_call(name, [f.name for f in formals], mod_name, return_type) # get the names of the array-elt-type queries in the formals array_etype_queries = [ f.info[1].name for f in formals - if ( - f.kind == formalKind.ARRAY - and len(f.info) > 0 - and isinstance(f.info[1], FormalQuery) - ) + if (f.kind == formalKind.ARRAY and len(f.info) > 0 and isinstance(f.info[1], FormalQuery)) ] def return_type_fn_name(): @@ -993,27 +947,25 @@ def return_type_fn_name(): or ( # TODO: do resolution to ensure that this is a class type that inherits from 'AbstractSymEntry' return_type_fn_name() is not None - and return_type_fn_name() in ["SymEntry",] + list(config["parameter_classes"].keys()) + and return_type_fn_name() + in [ + "SymEntry", + ] + + list(config["parameter_classes"].keys()) ) ) returns_array = ( - return_type - and isinstance(return_type, chapel.BracketLoop) - and return_type.is_maybe_array_type() + return_type and isinstance(return_type, chapel.BracketLoop) and return_type.is_maybe_array_type() ) if returns_array: - symbol_creation, result_name = gen_symbol_creation( - ARRAY_ENTRY_CLASS_NAME, result_name - ) + symbol_creation, result_name = gen_symbol_creation(ARRAY_ENTRY_CLASS_NAME, result_name) else: symbol_creation = "" response = gen_response(result_name, returns_symbol or returns_array) - command_proc = "\n".join( - [signature, arg_unpack, fn_call, symbol_creation, response, "}"] - ) + command_proc = "\n".join([signature, arg_unpack, fn_call, symbol_creation, response, "}"]) return (command_proc, cmd_name, is_generic_command, command_formals, query_table) @@ -1236,9 +1188,7 @@ def stamp_out_command( if wcn := wc_node: if not wcn.eval(fp, query_table): continue - stamp = stamp_generic_command( - name, cmd_prefix, mod_name, fp, line_num, iar_annotation - ) + stamp = stamp_generic_command(name, cmd_prefix, mod_name, fp, line_num, iar_annotation) yield stamp @@ -1248,9 +1198,7 @@ def extract_enum_imports(config): if isinstance(config[k], dict): if "__enum__" in config[k].keys(): if "__variants__" not in config[k].keys(): - raise ValueError( - f"enum '{k}' is missing '__variants__' field in configuration file" - ) + raise ValueError(f"enum '{k}' is missing '__variants__' field in configuration file") imports.append(f"import {config[k]['__enum__']};") else: imports += extract_enum_imports(config[k]) @@ -1264,13 +1212,23 @@ def register_commands(config, source_files): """ stamps = [ "module Commands {", - "use CommandMap, Message, MultiTypeSymbolTable, MultiTypeSymEntry;", + "use CommandMap, IOUtils, Message, MultiTypeSymbolTable, MultiTypeSymEntry;", "use BigInteger;", watermarkConfig(config), ] stamps += extract_enum_imports(config) + stamps.append("""proc getRegistrationConfig(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws { + return new MsgTuple(getRegConfig(), MsgType.NORMAL); + }""") + + + stamps.append( + "proc getRegConfig(): string throws do return try! regConfig;" + "\nregisterFunction('getRegistrationConfig', getRegistrationConfig, 'Commands', 68);" + ) + count = 0 for filename, ctx in chapel.files_with_contexts(source_files): @@ -1313,8 +1271,8 @@ def register_commands(config, source_files): ) continue - (cmd_proc, cmd_name, is_generic_cmd, cmd_gen_formals, query_table) = ( - gen_command_proc(name, fn.return_type(), con_formals, mod_name, config) + (cmd_proc, cmd_name, is_generic_cmd, cmd_gen_formals, query_table) = gen_command_proc( + name, fn.return_type(), con_formals, mod_name, config ) file_stamps.append(cmd_proc) diff --git a/tests/client_test.py b/tests/client_test.py index d92247a20c..06c17362a5 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -156,3 +156,11 @@ def test_client_get_server_commands(self): cmds = ak.client.get_server_commands() for cmd in ["connect", "info", "str"]: assert cmd in cmds + + def test_get_array_ranks(self): + availableRanks = ak.client.get_array_ranks() + assert isinstance(availableRanks, list) + assert len(availableRanks) >= 1 + assert 1 in availableRanks + assert ak.client.get_max_array_rank() in availableRanks + assert ak.client.get_max_array_rank() + 1 not in availableRanks