Skip to content

Commit

Permalink
Re-enable support for extra args in EP config models
Browse files Browse the repository at this point in the history
A recent PR removed the `extra = "allow"` attribute from the
`BaseConfigModel` Pydantic class. Many subclass models rely on this
setting to permit required arguments without rasing a `ValidationError`.
  • Loading branch information
rjmello committed Dec 5, 2024
1 parent d685c17 commit 0176092
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 23 deletions.
5 changes: 5 additions & 0 deletions changelog.d/20241205_100931_30907815+rjmello_v2_32_1.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Bug Fixes
^^^^^^^^^

- Fixed an issue where valid endpoint configuration variables were ignored,
causing spurious validation errors.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
)
from .model import ( # noqa: F401
BaseConfigModel,
BaseEndpointConfigModel,
ManagerEndpointConfigModel,
UserEndpointConfigModel,
)
41 changes: 20 additions & 21 deletions compute_endpoint/globus_compute_endpoint/endpoint/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,8 @@ def inner(cls, model: t.Optional[BaseModel]):


class BaseConfigModel(BaseModel):
multi_user: t.Optional[bool]
display_name: t.Optional[str]
allowed_functions: t.Optional[t.List[uuid.UUID]]
authentication_policy: t.Optional[uuid.UUID]
subscription_id: t.Optional[uuid.UUID]
amqp_port: t.Optional[int]
heartbeat_period: t.Optional[int]
environment: t.Optional[str]
local_compute_services: t.Optional[bool]
debug: t.Optional[bool]
class Config:
extra = "allow"


class AddressModel(BaseConfigModel):
Expand Down Expand Up @@ -80,9 +72,6 @@ class ChannelModel(BaseConfigModel):


class ProviderModel(BaseConfigModel):
class Config:
extra = "allow"

type: str
channel: t.Optional[ChannelModel]
launcher: t.Optional[LauncherModel]
Expand Down Expand Up @@ -153,7 +142,23 @@ def _validate_provider_container_compatibility(cls, values: dict):
return values


class UserEndpointConfigModel(BaseConfigModel):
class BaseEndpointConfigModel(BaseModel):
multi_user: t.Optional[bool]
display_name: t.Optional[str]
allowed_functions: t.Optional[t.List[uuid.UUID]]
authentication_policy: t.Optional[uuid.UUID]
subscription_id: t.Optional[uuid.UUID]
amqp_port: t.Optional[int]
heartbeat_period: t.Optional[int]
environment: t.Optional[str]
local_compute_services: t.Optional[bool]
debug: t.Optional[bool]

class Config:
extra = "forbid"


class UserEndpointConfigModel(BaseEndpointConfigModel):
engine: EngineModel
heartbeat_threshold: t.Optional[int]
idle_heartbeats_soft: t.Optional[int]
Expand All @@ -167,9 +172,6 @@ class UserEndpointConfigModel(BaseConfigModel):

_validate_engine = _validate_params("engine")

class Config:
extra = "forbid"

def dict(self, *args, **kwargs):
# Slight modification is needed here since we still
# store the engine/executor in a list named executors
Expand All @@ -180,11 +182,8 @@ def dict(self, *args, **kwargs):
return ret


class ManagerEndpointConfigModel(BaseConfigModel):
class ManagerEndpointConfigModel(BaseEndpointConfigModel):
public: t.Optional[bool]
identity_mapping_config_path: t.Optional[FilePath]
force_mu_allow_same_user: t.Optional[bool]
mu_child_ep_grace_period_s: t.Optional[float]

class Config:
extra = "forbid"
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,12 @@ def load_config_yaml(config_str: str) -> UserEndpointConfig | ManagerEndpointCon
try:
ConfigClass: type[UserEndpointConfig | ManagerEndpointConfig]
if is_templatable:
from . import BaseConfigModel, ManagerEndpointConfigModel
from . import BaseEndpointConfigModel, ManagerEndpointConfigModel

ConfigClass = ManagerEndpointConfig
config_schema: BaseConfigModel = ManagerEndpointConfigModel(**config_dict)
config_schema: BaseEndpointConfigModel = ManagerEndpointConfigModel(
**config_dict
)
else:
from . import UserEndpointConfigModel

Expand Down
17 changes: 17 additions & 0 deletions compute_endpoint/tests/unit/test_endpoint_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,20 @@ def test_managerconfig_repr_nondefault_kwargs(
repr_c = repr(ManagerEndpointConfig(**{kw: val}))

assert f"{kw}={repr(val)}" in repr_c


def test_engine_model_objects_allow_extra():
config_dict = {
"engine": {
"type": "GlobusComputeEngine",
"address": {
"type": "address_by_interface",
"ifname": "lo", # Not specified in model
},
"provider": {
"type": "LocalProvider",
"max_blocks": 2, # Not specified in model
},
}
}
UserEndpointConfigModel(**config_dict)

0 comments on commit 0176092

Please sign in to comment.