From 0f1a2d2003aa05ae189c37879af7200f48694568 Mon Sep 17 00:00:00 2001 From: "pixeebot[bot]" <104101892+pixeebot[bot]@users.noreply.github.com> Date: Tue, 20 Feb 2024 22:16:46 +0000 Subject: [PATCH 1/2] Use Assignment Expression (Walrus) In Conditional --- src/python/zigopt/api/auth.py | 9 +++---- src/python/zigopt/api/paging.py | 3 +-- src/python/zigopt/api/ratelimit.py | 3 +-- src/python/zigopt/api/request.py | 15 ++++------- src/python/zigopt/assignments/model.py | 3 +-- src/python/zigopt/authentication/login.py | 6 ++--- src/python/zigopt/authentication/token.py | 3 +-- src/python/zigopt/authorization/owner.py | 3 +-- src/python/zigopt/authorization/user.py | 6 ++--- src/python/zigopt/common/lists.py | 3 +-- src/python/zigopt/db/service.py | 3 +-- src/python/zigopt/experiment/model.py | 6 ++--- src/python/zigopt/experiment/segmenter.py | 4 +-- src/python/zigopt/experiment/service.py | 10 +++---- src/python/zigopt/file/model.py | 3 +-- .../zigopt/handlers/aiexperiments/create.py | 3 +-- .../aiexperiments/training_runs/create.py | 9 +++---- src/python/zigopt/handlers/clients/delete.py | 5 ++-- src/python/zigopt/handlers/clients/invite.py | 9 +++---- src/python/zigopt/handlers/clients/tokens.py | 19 +++++--------- .../experiments/checkpoints/create.py | 3 +-- .../zigopt/handlers/experiments/create.py | 19 +++++--------- .../zigopt/handlers/experiments/list_base.py | 12 ++++----- .../experiments/metric_importances/update.py | 5 ++-- .../handlers/experiments/observations/base.py | 3 +-- .../experiments/observations/create.py | 8 +++--- .../experiments/queued_suggestions/base.py | 3 +-- .../handlers/experiments/suggestions/base.py | 3 +-- .../experiments/suggestions/create.py | 3 +-- .../experiments/suggestions/update.py | 5 ++-- .../zigopt/handlers/experiments/update.py | 13 ++++------ src/python/zigopt/handlers/files/base.py | 3 +-- .../zigopt/handlers/organizations/update.py | 3 +-- src/python/zigopt/handlers/projects/base.py | 5 ++-- .../zigopt/handlers/training_runs/base.py | 9 +++---- .../zigopt/handlers/training_runs/create.py | 6 ++--- .../zigopt/handlers/training_runs/files.py | 3 +-- .../zigopt/handlers/training_runs/list.py | 3 +-- .../zigopt/handlers/training_runs/tags.py | 15 +++++------ .../zigopt/handlers/training_runs/update.py | 9 +++---- src/python/zigopt/handlers/users/create.py | 6 ++--- src/python/zigopt/handlers/users/delete.py | 15 ++++------- src/python/zigopt/handlers/users/password.py | 3 +-- src/python/zigopt/handlers/users/sessions.py | 3 +-- src/python/zigopt/handlers/users/update.py | 12 +++------ .../zigopt/handlers/validate/training_run.py | 3 +-- src/python/zigopt/handlers/validate/user.py | 3 +-- .../zigopt/handlers/validate/validate_dict.py | 6 ++--- .../zigopt/handlers/validate/web_data/base.py | 3 +-- src/python/zigopt/handlers/web_data/base.py | 3 +-- src/python/zigopt/handlers/web_data/delete.py | 5 ++-- src/python/zigopt/handlers/web_data/update.py | 7 +++-- src/python/zigopt/json/assignments.py | 3 +-- src/python/zigopt/json/builder/experiment.py | 3 +-- .../zigopt/json/builder/json_builder.py | 3 +-- src/python/zigopt/json/render.py | 3 +-- src/python/zigopt/log/base.py | 9 +++---- src/python/zigopt/membership/service.py | 3 +-- src/python/zigopt/net/errors.py | 3 +-- src/python/zigopt/observation/from_json.py | 3 +-- src/python/zigopt/observation/model.py | 3 +-- src/python/zigopt/optimization_aux/service.py | 3 +-- src/python/zigopt/optimize/queue.py | 3 +-- src/python/zigopt/optimize/sources/base.py | 12 +++------ .../zigopt/optimize/sources/categorical.py | 3 +-- src/python/zigopt/organization/service.py | 6 ++--- src/python/zigopt/pagination/lib.py | 3 +-- src/python/zigopt/parameters/from_json.py | 13 ++++------ src/python/zigopt/permission/service.py | 5 ++-- src/python/zigopt/protobuf/json.py | 26 +++++++------------ src/python/zigopt/queue/monitor.py | 3 +-- src/python/zigopt/queue/redis/message.py | 3 +-- src/python/zigopt/queue/router.py | 6 ++--- src/python/zigopt/queue/workers.py | 8 +++--- .../zigopt/queued_suggestion/service.py | 3 +-- src/python/zigopt/redis/service.py | 3 +-- src/python/zigopt/sigoptcompute/adapter.py | 3 +-- src/python/zigopt/suggestion/broker/base.py | 5 ++-- src/python/zigopt/suggestion/service.py | 3 +-- .../zigopt/suggestion/unprocessed/service.py | 9 +++---- src/python/zigopt/token/service.py | 21 ++++++--------- src/python/zigopt/training_run/service.py | 8 +++--- src/python/zigopt/user/service.py | 3 +-- 83 files changed, 182 insertions(+), 330 deletions(-) diff --git a/src/python/zigopt/api/auth.py b/src/python/zigopt/api/auth.py index d6a14675..3ebfeec2 100644 --- a/src/python/zigopt/api/auth.py +++ b/src/python/zigopt/api/auth.py @@ -124,12 +124,10 @@ def _password_reset_authentication(services, request): to the method as the first argument. """ email = napply(request.optional_param("email"), validate_email) - optional_api_token = request.optional_api_token() - if optional_api_token: + if optional_api_token := request.optional_api_token(): token = _validate_api_token(request.optional_user_token()) token_authorization = _do_api_token_authentication(services, request, token, mandatory=True) - auth_email = token_authorization.current_user and token_authorization.current_user.email - if auth_email: + if auth_email := token_authorization.current_user and token_authorization.current_user.email: if email and auth_email != email: raise BadParamError("Invalid email parameter when authenticating with API token") email = auth_email @@ -258,8 +256,7 @@ def _login_authentication(services, request): rate_limit_identifier = login_rate_limit.increment_and_check_rate_limit(services, request) authentication = authenticate_login(services, email, password, code) login_rate_limit.reset_rate_limit(services, rate_limit_identifier) - user = authentication.user - if user is not None: + if (user := authentication.user) is not None: authenticated_from_email_link = authentication.authenticated_from_email_link user_token = services.token_service.create_temporary_user_token(user.id) return UserLoginAuthorization( diff --git a/src/python/zigopt/api/paging.py b/src/python/zigopt/api/paging.py index cc4cf6f1..e466b89a 100644 --- a/src/python/zigopt/api/paging.py +++ b/src/python/zigopt/api/paging.py @@ -13,8 +13,7 @@ def decode_base64_marker_without_padding(serialized_marker): # NOTE: "=" padding needs to be added to the marker in order for it to be correctly decoded - tail_len = len(serialized_marker) % 4 - if tail_len: + if tail_len := len(serialized_marker) % 4: serialized_marker += "=" * (4 - tail_len) return base64.urlsafe_b64decode(serialized_marker) diff --git a/src/python/zigopt/api/ratelimit.py b/src/python/zigopt/api/ratelimit.py index b019a257..0aa22f42 100644 --- a/src/python/zigopt/api/ratelimit.py +++ b/src/python/zigopt/api/ratelimit.py @@ -85,8 +85,7 @@ class _ApiTokenRateLimit(_BaseRateLimit): NON_MUTATING_METHODS = {"GET", "HEAD", "OPTIONS"} def _get_identifier(self, services, identifying_object): - token = identifying_object.optional_client_token() - if token is not None: + if (token := identifying_object.optional_client_token()) is not None: return token, TOKEN_IDENTIFIER return _UNUSED, _UNUSED diff --git a/src/python/zigopt/api/request.py b/src/python/zigopt/api/request.py index 3bfb0f44..7b221a53 100644 --- a/src/python/zigopt/api/request.py +++ b/src/python/zigopt/api/request.py @@ -221,8 +221,7 @@ def on_json_loading_failed(self, e): raise SigoptValidationError("Malformed json in request body") from e def params(self): - ret = self._params - if ret is self._MISSING: + if (ret := self._params) is self._MISSING: if self.method in ("GET", "DELETE"): self._params = {validate_api_input(k): validate_api_input(v) for k, v in self.args.items()} elif self.method in ("PUT", "POST", "MERGE"): @@ -259,8 +258,7 @@ def optional_param(self, name): def required_param(self, name): """Retrieve ``name`` from flask HTTP request and *fail* if ``name`` is not found.""" - value = self.optional_param(name) - if value is not None: + if (value := self.optional_param(name)) is not None: return value raise MissingParamError(name) @@ -283,8 +281,7 @@ def optional_list_param(self, name): """ assert self.method in ("GET", "POST", "DELETE"), "Must use get_with_validation for non-query arguments" param_string = self.params().get(name) - param_value = param_string.split(",") if param_string else None - if param_value is None: + if (param_value := param_string.split(",") if param_string else None) is None: return None if not isinstance(param_value, list): raise SigoptValidationError("Invalid list value: " + param_value) @@ -296,14 +293,12 @@ def get_sort(self, default_field, default_ascending=False): return SortRequest(field, ascending) def _parse_marker(self, name): - serialized_marker = self.optional_param(name) - if serialized_marker is None: + if (serialized_marker := self.optional_param(name)) is None: return None return deserialize_paging_marker(serialized_marker) def get_paging(self, max_limit=DEFAULT_PAGING_MAX_LIMIT): - limit = self.optional_int_param("limit") - if limit is None: + if (limit := self.optional_int_param("limit")) is None: limit = max_limit if limit < 0: raise InvalidValueError(f"Invalid limit: {limit}") diff --git a/src/python/zigopt/assignments/model.py b/src/python/zigopt/assignments/model.py index c974d298..9083c8f9 100644 --- a/src/python/zigopt/assignments/model.py +++ b/src/python/zigopt/assignments/model.py @@ -45,8 +45,7 @@ def assignments_map(self): raise Exception("Do not access .assignments_map directly - prefer .get_assignments(experiment)") def _get_required_parameter(self, array, parameter, index): - ret = array[index] - if ret is None: + if (ret := array[index]) is None: raise ValueError(f"Parameter has no replacement value: {parameter.name}") return ret diff --git a/src/python/zigopt/authentication/login.py b/src/python/zigopt/authentication/login.py index 06aa7736..4029efeb 100644 --- a/src/python/zigopt/authentication/login.py +++ b/src/python/zigopt/authentication/login.py @@ -88,14 +88,12 @@ def authenticate_login(services, email: str, password: str, code: str | None) -> delay_seconds_on_failure = min_delay_seconds + jitter_seconds * non_crypto_random.random() failure_time = start_time + delay_seconds_on_failure try: - user = services.user_service.find_by_email(email) - if user: + if user := services.user_service.find_by_email(email): # NOTE: If the user cannot auth with a password, we can't reveal the error they have made. # This would leak information about the account they are attempting to log in to, notably that # the account exists and someone has logged into it successfully. # The user must go through the "Forgot Password" flow - can_auth_with_password = bool(user.hashed_password) - if can_auth_with_password: + if can_auth_with_password := bool(user.hashed_password): if password is not None: return authenticate_password(services, user, password) if code is not None: diff --git a/src/python/zigopt/authentication/token.py b/src/python/zigopt/authentication/token.py index 2daf1d22..a18f4418 100644 --- a/src/python/zigopt/authentication/token.py +++ b/src/python/zigopt/authentication/token.py @@ -15,8 +15,7 @@ @deal.raises(ForbiddenError, TypeError) def authenticate_token(services, token: str) -> AuthenticationResult: - token_obj = services.token_service.find_by_token(token, include_expired=True) - if token_obj is None: + if (token_obj := services.token_service.find_by_token(token, include_expired=True)) is None: return AuthenticationResult() if token_obj.expired: raise ForbiddenError("Your API token has expired.", token_status=TokenStatus.EXPIRED) diff --git a/src/python/zigopt/authorization/owner.py b/src/python/zigopt/authorization/owner.py index 3be427be..8e39021f 100644 --- a/src/python/zigopt/authorization/owner.py +++ b/src/python/zigopt/authorization/owner.py @@ -50,8 +50,7 @@ def can_act_on_organization(self, services, requested_permission, organization): ) def _can_act_on_client_artifacts(self, services, requested_permission, client_id, owner_id_for_artifacts): - client = services.client_service.find_by_id(client_id) - if client is None: + if (client := services.client_service.find_by_id(client_id)) is None: return False return ( self.can_act_on_client(services, requested_permission, client) diff --git a/src/python/zigopt/authorization/user.py b/src/python/zigopt/authorization/user.py index 749b606e..ed2c935f 100644 --- a/src/python/zigopt/authorization/user.py +++ b/src/python/zigopt/authorization/user.py @@ -86,11 +86,9 @@ def can_act_on_organization(self, services, requested_permission, organization): return False def _can_act_on_client_artifacts(self, services, requested_permission, client_id, owner_id_for_artifacts): - organization_id = self._infer_organization_id_from_client_id(services, client_id) - if organization_id: + if organization_id := self._infer_organization_id_from_client_id(services, client_id): organization = self._infer_organization_from_organization_id(services, organization_id) - membership = self._membership(services, organization=organization) - if membership: + if membership := self._membership(services, organization=organization): if membership.is_owner: return True diff --git a/src/python/zigopt/common/lists.py b/src/python/zigopt/common/lists.py index 857a33a1..03f47055 100644 --- a/src/python/zigopt/common/lists.py +++ b/src/python/zigopt/common/lists.py @@ -83,8 +83,7 @@ def partition(lis: Sequence[T], predicate: Callable[[T], bool]) -> tuple[list[T] true_list = [] false_list = [] for l in as_list: - pred_value = predicate(l) - if pred_value: + if pred_value := predicate(l): true_list.append(l) else: false_list.append(l) diff --git a/src/python/zigopt/db/service.py b/src/python/zigopt/db/service.py index cfb37af3..6eb78e1e 100644 --- a/src/python/zigopt/db/service.py +++ b/src/python/zigopt/db/service.py @@ -408,8 +408,7 @@ def one(self, q: Query) -> Any: @sanitize_errors @retry_on_error def one_or_none(self, q: Query) -> Any | None: - ret = q.one_or_none() - if ret is not None: + if (ret := q.one_or_none()) is not None: self._expunge_one(ret) self._rollback() return ret diff --git a/src/python/zigopt/experiment/model.py b/src/python/zigopt/experiment/model.py index f91f5c4a..18babd45 100644 --- a/src/python/zigopt/experiment/model.py +++ b/src/python/zigopt/experiment/model.py @@ -374,14 +374,12 @@ def costliest_task(self): return max_option(self.tasks, key=lambda x: x.cost) def get_task_by_name(self, task_name): - matching_task = find(self.tasks, lambda t: t.name == task_name) - if matching_task is None: + if (matching_task := find(self.tasks, lambda t: t.name == task_name)) is None: raise ValueError(f"No task named {task_name} for experiment {self.id}.") return matching_task def get_task_by_cost(self, task_cost): - matching_task = find(self.tasks, lambda t: t.cost == task_cost) - if matching_task is None: + if (matching_task := find(self.tasks, lambda t: t.cost == task_cost)) is None: raise ValueError(f"No task with cost {task_cost} for experiment {self.id}.") return matching_task diff --git a/src/python/zigopt/experiment/segmenter.py b/src/python/zigopt/experiment/segmenter.py index 44ae4a63..a97badf9 100644 --- a/src/python/zigopt/experiment/segmenter.py +++ b/src/python/zigopt/experiment/segmenter.py @@ -68,8 +68,8 @@ def prune_intervals(self, experiment, with_assignments_maps, intervals): for has_assignments_map in with_assignments_maps: for name, assignment in has_assignments_map.get_assignments(experiment).items(): remaining_intervals = intervals.get(name, []) - to_remove = find(remaining_intervals, lambda i: assignment in i) # pylint: disable=cell-var-from-loop - if to_remove: +# pylint: disable=cell-var-from-loop + if to_remove := find(remaining_intervals, lambda i: assignment in i): remaining_intervals.remove(to_remove) def pick_value(self, parameter, intervals): diff --git a/src/python/zigopt/experiment/service.py b/src/python/zigopt/experiment/service.py index c1405162..be5918b2 100644 --- a/src/python/zigopt/experiment/service.py +++ b/src/python/zigopt/experiment/service.py @@ -228,14 +228,13 @@ def delete(self, experiment: Experiment) -> None: Not a true DB delete, just sets the deleted flag. """ timestamp = current_datetime() - update = self.services.database_service.update_one_or_none( + if update := self.services.database_service.update_one_or_none( self.services.database_service.query(Experiment).filter(Experiment.id == experiment.id), { Experiment.deleted: True, Experiment.date_updated: timestamp, }, - ) - if update: + ): experiment.date_updated = timestamp experiment.deleted = True self.services.project_service.mark_as_updated_by_experiment(experiment) @@ -271,15 +270,14 @@ def _include_deleted_clause(self, include_deleted: bool, q: Query) -> Query: def mark_as_updated(self, experiment: Experiment, timestamp: datetime.datetime | None = None) -> None: if timestamp is None: timestamp = current_datetime() - did_update = self.services.database_service.update_one_or_none( + if did_update := self.services.database_service.update_one_or_none( self.services.database_service.query(Experiment) .filter(Experiment.id == experiment.id) .filter(Experiment.date_updated < timestamp.replace(microsecond=0)), { Experiment.date_updated: timestamp, }, - ) - if did_update: + ): experiment.date_updated = timestamp self.services.project_service.mark_as_updated_by_experiment(experiment=experiment) diff --git a/src/python/zigopt/file/model.py b/src/python/zigopt/file/model.py index 6ce2707d..6d0a2d2e 100644 --- a/src/python/zigopt/file/model.py +++ b/src/python/zigopt/file/model.py @@ -90,8 +90,7 @@ def get_download_filename(self): "text/plain": ".txt", } mime_type = self.data.content_type - extension = overrides.get(mime_type) - if extension is None: + if (extension := overrides.get(mime_type)) is None: extension = mimetypes.guess_extension(self.data.content_type) if extension is None: extension = "" diff --git a/src/python/zigopt/handlers/aiexperiments/create.py b/src/python/zigopt/handlers/aiexperiments/create.py index 68dbd09c..932213d5 100644 --- a/src/python/zigopt/handlers/aiexperiments/create.py +++ b/src/python/zigopt/handlers/aiexperiments/create.py @@ -340,8 +340,7 @@ def get_metric_list_from_json(cls, json_dict): @classmethod def get_metric_name(cls, metric, seen_names, num_metrics): - name = super().get_metric_name(metric, seen_names, num_metrics) - if name is None: + if (name := super().get_metric_name(metric, seen_names, num_metrics)) is None: raise MissingJsonKeyError("name", "All metrics require a name") return name diff --git a/src/python/zigopt/handlers/aiexperiments/training_runs/create.py b/src/python/zigopt/handlers/aiexperiments/training_runs/create.py index 933b1338..f76974fe 100644 --- a/src/python/zigopt/handlers/aiexperiments/training_runs/create.py +++ b/src/python/zigopt/handlers/aiexperiments/training_runs/create.py @@ -88,8 +88,7 @@ def maybe_create_observation_from_params(self, training_run_params): assignments = self.get_assignments_relevant_to_experiment(training_run_params) values, failed = self.get_observation_values(training_run_params) - observation = self.maybe_create_observation_from_data(assignments, values, failed) - if observation: + if observation := self.maybe_create_observation_from_data(assignments, values, failed): observation_assignments = observation.get_assignments(self.experiment) return observation, assignments_json(self.experiment, observation_assignments) return None, {} @@ -119,8 +118,7 @@ def maybe_create_explicit_suggestion_from_assignments(self, assignments): def maybe_create_suggestion(self, training_run_params): assert self.experiment is not None - provided_experiment_assignments = self.get_assignments_relevant_to_experiment(training_run_params) - if provided_experiment_assignments: + if provided_experiment_assignments := self.get_assignments_relevant_to_experiment(training_run_params): suggestion = self.maybe_create_explicit_suggestion_from_assignments(provided_experiment_assignments) else: suggestion = self.serve_suggestion() @@ -156,8 +154,7 @@ def handle(self, params): params.training_run_params.training_run_data.assignments_struct.update(assignments) training_run_data = params.training_run_params.training_run_data - assignments_meta = training_run_data.assignments_meta - if assignments_meta is not None: + if (assignments_meta := training_run_data.assignments_meta) is not None: validate_assignments_meta(training_run_data.assignments_struct, assignments_meta, None) if suggestion: for assignment in assignments.keys(): diff --git a/src/python/zigopt/handlers/clients/delete.py b/src/python/zigopt/handlers/clients/delete.py index 76888771..c7452c5b 100644 --- a/src/python/zigopt/handlers/clients/delete.py +++ b/src/python/zigopt/handlers/clients/delete.py @@ -14,11 +14,10 @@ def handle(self): assert self.auth is not None assert self.client is not None - user_can_delete = self.services.membership_service.user_is_owner_for_organization( + if user_can_delete := self.services.membership_service.user_is_owner_for_organization( user_id=self.auth.current_user.id, organization_id=self.client.organization_id, - ) - if user_can_delete: + ): self.do_delete() return {} raise ForbiddenError("You cannot delete this client.") diff --git a/src/python/zigopt/handlers/clients/invite.py b/src/python/zigopt/handlers/clients/invite.py index 25e8c3d5..c8800a1a 100644 --- a/src/python/zigopt/handlers/clients/invite.py +++ b/src/python/zigopt/handlers/clients/invite.py @@ -73,9 +73,8 @@ def handle(self, request): raise ForbiddenError("You do not have permission to invite this email address.") client_invites = [dict(id=client.id, role=role, old_role=old_role)] - existing_invite = self.services.invite_service.find_by_email_and_organization(email, client.organization_id) - if existing_invite: + if existing_invite := self.services.invite_service.find_by_email_and_organization(email, client.organization_id): if existing_invite.membership_type == MembershipType.owner: raise InvalidValueError("This email was already invited as an owner") @@ -153,9 +152,7 @@ def handle(self, email): ) self.services.pending_permission_service.delete_by_email_and_client(email, self.client) - invite = self.services.invite_service.find_by_email_and_organization(email, self.client.organization_id) - if invite: - num_pending_permissions = self.services.pending_permission_service.count_by_invite_id(invite.id) - if num_pending_permissions == 0: + if invite := self.services.invite_service.find_by_email_and_organization(email, self.client.organization_id): + if (num_pending_permissions := self.services.pending_permission_service.count_by_invite_id(invite.id)) == 0: self.services.invite_service.delete_by_email_and_organization(email, self.client.organization_id) return {} diff --git a/src/python/zigopt/handlers/clients/tokens.py b/src/python/zigopt/handlers/clients/tokens.py index 89f96258..446ab66f 100644 --- a/src/python/zigopt/handlers/clients/tokens.py +++ b/src/python/zigopt/handlers/clients/tokens.py @@ -57,8 +57,7 @@ def handle(self): assert self.token is not None if not self.token.expiration_timestamp and self.token.all_experiments: raise ForbiddenError("Cannot delete root token") - success = self.services.token_service.delete_token(self.token) - if success: + if success := self.services.token_service.delete_token(self.token): return {} raise NotFoundError("Token not found") @@ -81,8 +80,7 @@ def handle(self, params): assert self.auth is not None assert self.token is not None - new_token_value = params.token - if new_token_value is not None: + if (new_token_value := params.token) is not None: if new_token_value != "rotate": raise SigoptValidationError('Token must equal "rotate"') self.services.token_service.rotate_token(self.token) @@ -90,8 +88,7 @@ def handle(self, params): meta = copy_protobuf(self.token.meta) meta.lasts_forever = params.lasts_forever self.services.token_service.update_meta(self.token, meta) - new_expires_value = params.expires - if new_expires_value is not None: + if (new_expires_value := params.expires) is not None: if new_expires_value != "renew": raise SigoptValidationError('Expires must equal "renew"') if self.token.meta.can_renew: @@ -190,11 +187,10 @@ def ensure_includes_role_token(self, tokens): auth = self.auth assert self.client is not None client = self.client - role_token = find( + if (role_token := find( tokens, lambda t: t.user_id == auth.current_user.id and t.client_id == client.id and t.development is False, - ) - if role_token is None: + )) is None: role_token = self.services.token_service.get_or_create_role_token( self.client.id, self.auth.current_user.id, @@ -208,11 +204,10 @@ def ensure_includes_development_role_token(self, tokens): assert current_user is not None assert self.client is not None client = self.client - development_token = find( + if (development_token := find( tokens, lambda t: t.user_id == current_user.id and t.client_id == client.id and t.development is True, - ) - if development_token is None: + )) is None: development_token = self.services.token_service.get_or_create_development_role_token( self.client.id, self.auth.current_user.id, diff --git a/src/python/zigopt/handlers/experiments/checkpoints/create.py b/src/python/zigopt/handlers/experiments/checkpoints/create.py index 29ffd6bc..2bf70f40 100644 --- a/src/python/zigopt/handlers/experiments/checkpoints/create.py +++ b/src/python/zigopt/handlers/experiments/checkpoints/create.py @@ -43,8 +43,7 @@ def _parse_values(self, data): observation_value.name = get_with_validation(value_dict, "name", ValidationType.string) observation_value.value = get_with_validation(value_dict, "value", ValidationType.number) value_stddev: float | None = get_opt_with_validation(value_dict, "value_stddev", ValidationType.number) - value_var = napply(value_stddev, lambda stddev: stddev * stddev) - if value_var is not None: + if (value_var := napply(value_stddev, lambda stddev: stddev * stddev)) is not None: observation_value.value_var = value_var yield observation_value diff --git a/src/python/zigopt/handlers/experiments/create.py b/src/python/zigopt/handlers/experiments/create.py index afc33bed..f8605054 100644 --- a/src/python/zigopt/handlers/experiments/create.py +++ b/src/python/zigopt/handlers/experiments/create.py @@ -280,15 +280,13 @@ def make_experiment_meta_from_json( has_constraint_metrics = any(m.strategy == ExperimentMetric.CONSTRAINT for m in experiment_meta.metrics) has_optimization_metrics = len(optimized_metrics) > 0 - num_solutions = cls.get_num_solutions_from_json( + if (num_solutions := cls.get_num_solutions_from_json( json_dict, experiment_meta.all_parameters_unsorted, - ) - if num_solutions is not None: + )) is not None: experiment_meta.num_solutions = num_solutions - parallel_bandwidth = cls.get_parallel_bandwidth_from_json(json_dict) - if parallel_bandwidth is not None: + if (parallel_bandwidth := cls.get_parallel_bandwidth_from_json(json_dict)) is not None: experiment_meta.parallel_bandwidth = parallel_bandwidth # Set observation budget if present and check to see which features require a budget @@ -305,8 +303,7 @@ def make_experiment_meta_from_json( cls._check_multisolution_viability(experiment_meta, num_solutions, optimized_metrics) - client_provided_data = cls.get_client_provided_data(json_dict) - if client_provided_data is not None: + if (client_provided_data := cls.get_client_provided_data(json_dict)) is not None: experiment_meta.client_provided_data = client_provided_data if not (has_optimization_metrics or has_constraint_metrics): @@ -382,12 +379,11 @@ def get_metric_strategy(cls, metric): @classmethod def get_metric_list_from_json(cls, json_dict): - metrics = get_opt_with_validation( + if (metrics := get_opt_with_validation( json_dict, "metrics", ValidationType.arrayOf(ValidationType.oneOf([ValidationType.string, ValidationType.object])), - ) - if metrics is None: + )) is None: assert MAX_METRICS_ANY_STRATEGY >= 1 assert MAX_OPTIMIZED_METRICS >= 1 return [ExperimentMetric()] @@ -479,8 +475,7 @@ def _add_constraint_term( constrained_integer_variables, constraint_var_set, ): - coeff = get_opt_with_validation(term, "weight", ValidationType.number) - if coeff == 0: + if (coeff := get_opt_with_validation(term, "weight", ValidationType.number)) == 0: return name = get_opt_with_validation(term, "name", ValidationType.string) name = validate_variable_name(name) diff --git a/src/python/zigopt/handlers/experiments/list_base.py b/src/python/zigopt/handlers/experiments/list_base.py index f1999950..f1185874 100644 --- a/src/python/zigopt/handlers/experiments/list_base.py +++ b/src/python/zigopt/handlers/experiments/list_base.py @@ -71,8 +71,7 @@ def parse_state_param(self, request): @classmethod def get_include_ai_param(cls, request): - include_ai = request.optional_bool_param("include_ai") - if include_ai is None: + if (include_ai := request.optional_bool_param("include_ai")) is None: include_ai = False return include_ai @@ -132,7 +131,9 @@ def _maybe_search(self, query_args, params, client_ids): if params.search is None: return keyword = params.search.lower().strip() - matching_user_ids = set( + + # TODO: Consider optimization of search ordering + if matching_user_ids := set( u_id for (u_id,) in flatten( [ @@ -152,10 +153,7 @@ def _maybe_search(self, query_args, params, client_ids): ), ] ) - ) - - # TODO: Consider optimization of search ordering - if matching_user_ids: + ): search_filter = or_(Experiment.created_by.in_(matching_user_ids), Experiment.name.ilike(f"%{keyword}%")) else: search_filter = Experiment.name.ilike(f"%{keyword}%") diff --git a/src/python/zigopt/handlers/experiments/metric_importances/update.py b/src/python/zigopt/handlers/experiments/metric_importances/update.py index d490eec8..313ff27b 100644 --- a/src/python/zigopt/handlers/experiments/metric_importances/update.py +++ b/src/python/zigopt/handlers/experiments/metric_importances/update.py @@ -16,11 +16,10 @@ def handle(self): assert self.experiment is not None num_observations = self.services.observation_service.count_by_experiment(self.experiment) - q_msg = self.services.optimize_queue_service.always_enqueue_importances( + if (q_msg := self.services.optimize_queue_service.always_enqueue_importances( experiment=self.experiment, num_observations=num_observations, - ) - if q_msg is None: + )) is None: raise UnprocessableEntityError( "Parameter importances update failed. (This experiment may not support importances.)" ) diff --git a/src/python/zigopt/handlers/experiments/observations/base.py b/src/python/zigopt/handlers/experiments/observations/base.py index 7885a42e..dbc6bc9d 100644 --- a/src/python/zigopt/handlers/experiments/observations/base.py +++ b/src/python/zigopt/handlers/experiments/observations/base.py @@ -23,8 +23,7 @@ def find_objects(self): ) def _find_observation(self, observation_id): - observation = self.services.observation_service.find_by_id(observation_id) - if observation is not None: + if (observation := self.services.observation_service.find_by_id(observation_id)) is not None: if observation.experiment_id == self.experiment_id: return observation raise NotFoundError(f"No observation {observation_id} for experiment {self.experiment_id}") diff --git a/src/python/zigopt/handlers/experiments/observations/create.py b/src/python/zigopt/handlers/experiments/observations/create.py index 2010c8f3..c97d0bc8 100644 --- a/src/python/zigopt/handlers/experiments/observations/create.py +++ b/src/python/zigopt/handlers/experiments/observations/create.py @@ -70,10 +70,9 @@ def observation_from_json( elif observation.timestamp: observation_data.timestamp = observation.timestamp - client_provided_data = BaseExperimentsCreateHandler.get_client_provided_data( + if (client_provided_data := BaseExperimentsCreateHandler.get_client_provided_data( json_dict, default=observation.client_provided_data - ) - if client_provided_data is not None: + )) is not None: observation_data.client_provided_data = client_provided_data else: if observation_data.HasField("client_provided_data"): @@ -171,8 +170,7 @@ def handle(self, request): no_optimize=self.get_no_optimize_for_observation_json(params), ) - bad_keys = params.keys() - observation_json.keys() - if bad_keys: + if bad_keys := params.keys() - observation_json.keys(): raise InvalidKeyError(f"Unknown keys were provided for observation create: {bad_keys}") observation_json = remove_nones_mapping(observation_json) diff --git a/src/python/zigopt/handlers/experiments/queued_suggestions/base.py b/src/python/zigopt/handlers/experiments/queued_suggestions/base.py index 7a109911..8059d3dd 100644 --- a/src/python/zigopt/handlers/experiments/queued_suggestions/base.py +++ b/src/python/zigopt/handlers/experiments/queued_suggestions/base.py @@ -30,8 +30,7 @@ def can_act_on_objects(self, requested_permission, objects): ) def _find_queued_suggestion(self, queued_suggestion_id): - queued_suggestion = self.services.queued_suggestion_service.find_by_id(self.experiment_id, queued_suggestion_id) - if queued_suggestion is not None: + if (queued_suggestion := self.services.queued_suggestion_service.find_by_id(self.experiment_id, queued_suggestion_id)) is not None: if queued_suggestion.experiment_id == self.experiment_id: return queued_suggestion raise NotFoundError(f"No QueuedSuggestion {queued_suggestion_id} for experiment {self.experiment_id}") diff --git a/src/python/zigopt/handlers/experiments/suggestions/base.py b/src/python/zigopt/handlers/experiments/suggestions/base.py index 623c9f5f..078b5d48 100644 --- a/src/python/zigopt/handlers/experiments/suggestions/base.py +++ b/src/python/zigopt/handlers/experiments/suggestions/base.py @@ -23,8 +23,7 @@ def find_objects(self): ) def _find_suggestion(self, suggestion_id): - suggestion = self.services.suggestion_service.find_by_id(suggestion_id) - if suggestion is not None: + if (suggestion := self.services.suggestion_service.find_by_id(suggestion_id)) is not None: if suggestion.experiment_id == self.experiment_id: return suggestion raise NotFoundError(f"No suggestion {suggestion_id} for experiment {self.experiment_id}") diff --git a/src/python/zigopt/handlers/experiments/suggestions/create.py b/src/python/zigopt/handlers/experiments/suggestions/create.py index aa7aeca7..443c8db8 100644 --- a/src/python/zigopt/handlers/experiments/suggestions/create.py +++ b/src/python/zigopt/handlers/experiments/suggestions/create.py @@ -60,8 +60,7 @@ def make_suggestion_meta_from_json(self, json_dict): def make_processed_suggestion_meta_from_json(json_dict): suggestion_meta = ProcessedSuggestionMeta() - client_provided_data = BaseExperimentsCreateHandler.get_client_provided_data(json_dict) - if client_provided_data is not None: + if (client_provided_data := BaseExperimentsCreateHandler.get_client_provided_data(json_dict)) is not None: suggestion_meta.client_provided_data = client_provided_data return suggestion_meta diff --git a/src/python/zigopt/handlers/experiments/suggestions/update.py b/src/python/zigopt/handlers/experiments/suggestions/update.py index 68759e66..f7a047f9 100644 --- a/src/python/zigopt/handlers/experiments/suggestions/update.py +++ b/src/python/zigopt/handlers/experiments/suggestions/update.py @@ -22,10 +22,9 @@ def handle(self, json_dict): assert self.suggestion is not None suggestion_meta = ProcessedSuggestionMeta() - client_provided_data = BaseExperimentsCreateHandler.get_client_provided_data( + if (client_provided_data := BaseExperimentsCreateHandler.get_client_provided_data( json_dict, default=self.suggestion.client_provided_data - ) - if client_provided_data is not None: + )) is not None: suggestion_meta.client_provided_data = client_provided_data processed = self.suggestion.processed diff --git a/src/python/zigopt/handlers/experiments/update.py b/src/python/zigopt/handlers/experiments/update.py index a1b7da2b..67426717 100644 --- a/src/python/zigopt/handlers/experiments/update.py +++ b/src/python/zigopt/handlers/experiments/update.py @@ -287,8 +287,7 @@ def _maybe_set_metadata(self, json_dict, new_meta, update_meta_fields): def _maybe_set_parallel_bandwidth(self, json_dict, new_meta, update_meta_fields): if "parallel_bandwidth" not in json_dict: return - parallel_bandwidth = BaseExperimentsCreateHandler.get_parallel_bandwidth_from_json(json_dict) - if parallel_bandwidth is None: + if (parallel_bandwidth := BaseExperimentsCreateHandler.get_parallel_bandwidth_from_json(json_dict)) is None: new_meta.ClearField("parallel_bandwidth") else: new_meta.parallel_bandwidth = parallel_bandwidth @@ -355,11 +354,10 @@ def handle(self, json_dict): update_experiment_fields["date_updated"] = current_datetime() self.experiment.date_updated = update_experiment_fields["date_updated"] - update_count = self.services.database_service.update_one( + if (update_count := self.services.database_service.update_one( self.services.database_service.query(Experiment).filter(Experiment.id == self.experiment.id), update_experiment_fields, - ) - if update_count == 0: + )) == 0: raise NotFoundError(f"No experiment {self.experiment.id}") if original_project_id is not None: @@ -528,12 +526,11 @@ def _maybe_set_parameter_grid_values(self, parameter, parameter_json): set_grid_values_from_json(parameter, parameter_json) def _maybe_set_parameter_categorical_values(self, parameter, parameter_json): - categorical_values_json = get_opt_with_validation( + if (categorical_values_json := get_opt_with_validation( parameter_json, "categorical_values", ValidationType.arrayOf(ValidationType.oneOf([ValidationType.object, ValidationType.string])), - ) - if categorical_values_json is None: + )) is None: return categorical_values_map = dict((c.name, c) for c in parameter.all_categorical_values) diff --git a/src/python/zigopt/handlers/files/base.py b/src/python/zigopt/handlers/files/base.py index 9841c36e..fcbaf7e1 100644 --- a/src/python/zigopt/handlers/files/base.py +++ b/src/python/zigopt/handlers/files/base.py @@ -28,8 +28,7 @@ def get_file_not_found_error(self): def find_objects(self): objects = super().find_objects() - file_obj = self.services.file_service.find_by_id(self.file_id) - if file_obj is None: + if (file_obj := self.services.file_service.find_by_id(self.file_id)) is None: raise self.get_file_not_found_error() objects["file"] = file_obj client = self.services.client_service.find_by_id(file_obj.client_id) diff --git a/src/python/zigopt/handlers/organizations/update.py b/src/python/zigopt/handlers/organizations/update.py index 061a5284..56c1078b 100644 --- a/src/python/zigopt/handlers/organizations/update.py +++ b/src/python/zigopt/handlers/organizations/update.py @@ -44,8 +44,7 @@ def parse_params(self, request): "allow_signup_from_email_domains", ValidationType.boolean, ) - email_domains = get_opt_with_validation(data, "email_domains", ValidationType.arrayOf(ValidationType.string)) - if email_domains: + if email_domains := get_opt_with_validation(data, "email_domains", ValidationType.arrayOf(ValidationType.string)): email_domains = [validate_email_domain(domain) for domain in email_domains] did_enable_allow_signup = ( diff --git a/src/python/zigopt/handlers/projects/base.py b/src/python/zigopt/handlers/projects/base.py index 73c636dc..5197cce1 100644 --- a/src/python/zigopt/handlers/projects/base.py +++ b/src/python/zigopt/handlers/projects/base.py @@ -24,11 +24,10 @@ def find_objects(self): ) def _find_project(self): - project = self.services.project_service.find_by_client_and_reference_id( + if (project := self.services.project_service.find_by_client_and_reference_id( client_id=self.client_id, reference_id=self.project_reference_id, - ) - if project is None: + )) is None: raise NotFoundError(f"No project {self.project_reference_id} in client {self.client_id}") return project diff --git a/src/python/zigopt/handlers/training_runs/base.py b/src/python/zigopt/handlers/training_runs/base.py index 8a664062..62955aae 100644 --- a/src/python/zigopt/handlers/training_runs/base.py +++ b/src/python/zigopt/handlers/training_runs/base.py @@ -66,22 +66,19 @@ def _maybe_find_experiment(self, experiment_id): return None def _find_training_run(self, training_run_id): - training_run = self.services.training_run_service.find_by_id(training_run_id) - if training_run: + if training_run := self.services.training_run_service.find_by_id(training_run_id): return training_run raise NotFoundError(f"Training run {training_run_id} not found") def _find_client(self, client_id): - client = self.services.client_service.find_by_id(client_id) - if client: + if client := self.services.client_service.find_by_id(client_id): return client raise NotFoundError(f"Client {client_id} not found") def _find_project(self, project_id, client_id, training_run_id): if project_id is None: raise NotFoundError(f"Training run {training_run_id} not found") - project = self.services.project_service.find_by_client_and_id(client_id, project_id) - if project is None: + if (project := self.services.project_service.find_by_client_and_id(client_id, project_id)) is None: raise NotFoundError(f"Training run {training_run_id} not found") return project diff --git a/src/python/zigopt/handlers/training_runs/create.py b/src/python/zigopt/handlers/training_runs/create.py index 146137ca..7289dd7e 100644 --- a/src/python/zigopt/handlers/training_runs/create.py +++ b/src/python/zigopt/handlers/training_runs/create.py @@ -36,8 +36,7 @@ def parse_params(self, request): return self.parse_request(request) def parse_request(self, request): - unaccepted_params = request.params().keys() - TrainingRunRequestParams.valid_fields - if unaccepted_params: + if unaccepted_params := request.params().keys() - TrainingRunRequestParams.valid_fields: raise SigoptValidationError(f"Unknown parameters: {unaccepted_params}") training_run_params = TrainingRunRequestParser().parse_params(request) @@ -48,8 +47,7 @@ def parse_request(self, request): if training_run_params.project is not None or training_run_params.field_is_explicitly_null("project"): raise SigoptValidationError("`project` is not a valid JSON key for this endpoint.") - assignments_meta = training_run_params.training_run_data.assignments_meta - if assignments_meta is not None: + if (assignments_meta := training_run_params.training_run_data.assignments_meta) is not None: validate_assignments_meta(training_run_params.training_run_data.assignments_struct, assignments_meta, None) return self.Params(training_run_params=training_run_params) diff --git a/src/python/zigopt/handlers/training_runs/files.py b/src/python/zigopt/handlers/training_runs/files.py index 7578f927..b754a9e1 100644 --- a/src/python/zigopt/handlers/training_runs/files.py +++ b/src/python/zigopt/handlers/training_runs/files.py @@ -52,8 +52,7 @@ def can_act_on_objects(self, requested_permission, objects): def parse_params(self, request): provided_params = request.params() acceptable_params = [key for key, _ in self.REQUIRED_INPUT_PARAMS + self.OPTIONAL_INPUT_PARAMS] - unaccepted_params = provided_params.keys() - acceptable_params - if unaccepted_params: + if unaccepted_params := provided_params.keys() - acceptable_params: raise SigoptValidationError( f"Unknown parameters: {unaccepted_params}. Only the following parameters are accepted: {acceptable_params}" ) diff --git a/src/python/zigopt/handlers/training_runs/list.py b/src/python/zigopt/handlers/training_runs/list.py index cbdb9178..691020c5 100644 --- a/src/python/zigopt/handlers/training_runs/list.py +++ b/src/python/zigopt/handlers/training_runs/list.py @@ -187,8 +187,7 @@ def _parse_filters(self, filters_param_value): if field.clause is None: raise InvalidKeyError("clause", f"Invalid field: {field.name}") input_operator = get_with_validation(f, "operator", ValidationType.string) - resolved_operator = field.interpret_operator(input_operator) - if resolved_operator is STRING_TO_OPERATOR_DICT["isnull"]: + if (resolved_operator := field.interpret_operator(input_operator)) is STRING_TO_OPERATOR_DICT["isnull"]: cast, value = None, None else: input_value = get_unvalidated(f, "value") diff --git a/src/python/zigopt/handlers/training_runs/tags.py b/src/python/zigopt/handlers/training_runs/tags.py index 7f7542b0..86b3819a 100644 --- a/src/python/zigopt/handlers/training_runs/tags.py +++ b/src/python/zigopt/handlers/training_runs/tags.py @@ -39,8 +39,7 @@ class TrainingRunsAddTagHandler(BaseTrainingRunsTagHandler): def parse_params(self, request): provided_params = request.params() acceptable_params = [key for key, _ in self.INPUT_PARAMS] - unaccepted_params = provided_params.keys() - acceptable_params - if unaccepted_params: + if unaccepted_params := provided_params.keys() - acceptable_params: raise SigoptValidationError( f"Unknown parameters: {unaccepted_params}. Only the following parameters are accepted: {acceptable_params}" ) @@ -53,12 +52,11 @@ def handle(self, params): assert self.training_run is not None tag_id = params[self.ID_PARAM] - tag = self.services.tag_service.find_by_client_and_id( + + if (tag := self.services.tag_service.find_by_client_and_id( client_id=self.training_run.client_id, tag_id=tag_id, - ) - - if tag is None: + )) is None: raise UnprocessableEntityError( f"The tag with id {tag_id} cannot be added to this training run because it does not exist." ) @@ -90,11 +88,10 @@ def __init__(self, *args, tag_id, **kwargs): def find_objects(self): objs = super().find_objects() - tag = self.services.tag_service.find_by_client_and_id( + if (tag := self.services.tag_service.find_by_client_and_id( client_id=objs["training_run"].client_id, tag_id=self.tag_id, - ) - if tag is None: + )) is None: raise NotFoundError("Tag not found") objs["tag"] = tag return objs diff --git a/src/python/zigopt/handlers/training_runs/update.py b/src/python/zigopt/handlers/training_runs/update.py index e56a0a0d..207e07ba 100644 --- a/src/python/zigopt/handlers/training_runs/update.py +++ b/src/python/zigopt/handlers/training_runs/update.py @@ -97,8 +97,7 @@ class TrainingRunsUpdateHandler(CreatesObservationsMixin, TrainingRunHandler): training_run: TrainingRun | None def parse_params(self, request): - unaccepted_params = request.params().keys() - TrainingRunRequestParams.valid_fields - if unaccepted_params: + if unaccepted_params := request.params().keys() - TrainingRunRequestParams.valid_fields: raise SigoptValidationError(f"Unknown parameters: {unaccepted_params}") method = request.method.lower() @@ -167,8 +166,7 @@ def handle(self, params): now = current_datetime() training_run_data = params.training_run_params.training_run_data - assignments_meta = training_run_data.assignments_meta - if assignments_meta is not None: + if (assignments_meta := training_run_data.assignments_meta) is not None: validate_assignments_meta(training_run_data.assignments_struct, assignments_meta, self.training_run) if params.training_run_params.deleted is not None: @@ -195,9 +193,8 @@ def handle(self, params): update_clause[TrainingRun.completed] = sql_coalesce(TrainingRun.completed, now) previous_deleted = self.training_run.deleted - new_deleted = update_clause.pop(TrainingRun.deleted, previous_deleted) - if new_deleted != previous_deleted: + if (new_deleted := update_clause.pop(TrainingRun.deleted, previous_deleted)) != previous_deleted: self.services.training_run_service.set_deleted(self.training_run.id, deleted=new_deleted) if update_clause: diff --git a/src/python/zigopt/handlers/users/create.py b/src/python/zigopt/handlers/users/create.py index 40ef9d8a..93e1d049 100644 --- a/src/python/zigopt/handlers/users/create.py +++ b/src/python/zigopt/handlers/users/create.py @@ -174,8 +174,7 @@ def get_verified_invite(self, email, invite_code): if invite.invite_code == invite_code or (not self.services.email_verification_service.enabled) ] # TODO: This doesn't handle multiple invite codes. We probably don't need to - claimable_invite = list_get(claimable_invites, 0) - if claimable_invite: + if claimable_invite := list_get(claimable_invites, 0): if self.services.invite_service.invite_is_valid(claimable_invite): return claimable_invite raise ForbiddenError("This invite is no longer valid.") @@ -247,8 +246,7 @@ def parse_params(self, request): ) def handle(self, params): - verified_invite = self.get_verified_invite(params.user_attributes.email, params.invite_code) - if verified_invite: + if verified_invite := self.get_verified_invite(params.user_attributes.email, params.invite_code): has_verified_email = params.has_verified_email or verified_invite.invite_code == params.invite_code user = self.create_user_by_invite(params.user_attributes, verified_invite, has_verified_email) self.create_clients_and_permissions(user, verified_invite) diff --git a/src/python/zigopt/handlers/users/delete.py b/src/python/zigopt/handlers/users/delete.py index b02c5043..63de4751 100644 --- a/src/python/zigopt/handlers/users/delete.py +++ b/src/python/zigopt/handlers/users/delete.py @@ -24,8 +24,7 @@ def handle(self, password): self.ensure_no_orphaned_organizations() self.ensure_no_orphaned_clients() - validated_request = password and password_matches(password, self.user.hashed_password) - if validated_request: + if validated_request := password and password_matches(password, self.user.hashed_password): do_password_hash_work_factor_update(self.services, self.user, password) if validated_request: self.do_delete() @@ -54,14 +53,12 @@ def ensure_no_orphaned_organizations(self): ) still_owned_organization_ids = set(m.organization_id for m in other_owner_memberships) - unowned_organization_ids = organization_ids - still_owned_organization_ids - if unowned_organization_ids: + if unowned_organization_ids := organization_ids - still_owned_organization_ids: other_non_owner_memberships = self.services.membership_service.organizations_with_other_non_owners( list(unowned_organization_ids), self.user.id, ) - orphaned_organization_ids = set(m.organization_id for m in other_non_owner_memberships) - if orphaned_organization_ids: + if orphaned_organization_ids := set(m.organization_id for m in other_non_owner_memberships): raise ForbiddenError( "This user cannot be deleted without assigning another owner or removing " f"all other users from the following organizations: {orphaned_organization_ids}" @@ -80,18 +77,16 @@ def ensure_no_orphaned_clients(self): self.user.id, ) other_owned_organizations = set(m.organization_id for m in organizations_with_other_owners) - client_ids = [p.client_id for p in permissions if p.organization_id not in other_owned_organizations] - if client_ids: + if client_ids := [p.client_id for p in permissions if p.organization_id not in other_owned_organizations]: outstanding_permissions = self.services.database_service.all( self.services.database_service.query(Permission) .filter(Permission.client_id.in_(client_ids)) .filter(Permission.user_id != self.user.id) ) outstanding_client_ids = [p.client_id for p in outstanding_permissions] - orphaned_clients = list(set(client_ids) - set(outstanding_client_ids)) - if orphaned_clients: + if orphaned_clients := list(set(client_ids) - set(outstanding_client_ids)): raise ForbiddenError(f"This user cannot be deleted without deleting the following clients: {orphaned_clients}") def do_delete(self): diff --git a/src/python/zigopt/handlers/users/password.py b/src/python/zigopt/handlers/users/password.py index 9cf30d2a..967617e7 100644 --- a/src/python/zigopt/handlers/users/password.py +++ b/src/python/zigopt/handlers/users/password.py @@ -88,8 +88,7 @@ def parse_params(self, request): return validate_email(request.required_param("email")) def handle(self, email): - user = self.services.user_service.find_by_email(email) - if user: + if user := self.services.user_service.find_by_email(email): if user.hashed_password: code = self.services.user_service.set_password_reset_code(user) self.services.email_router.send( diff --git a/src/python/zigopt/handlers/users/sessions.py b/src/python/zigopt/handlers/users/sessions.py index f4649cb1..c767cb55 100644 --- a/src/python/zigopt/handlers/users/sessions.py +++ b/src/python/zigopt/handlers/users/sessions.py @@ -24,8 +24,7 @@ def get_client_for_user(self, user, preferred_client_id=None): user, memberships, ) - client = find(clients, lambda c: c.id == preferred_client_id) - if client is None: + if (client := find(clients, lambda c: c.id == preferred_client_id)) is None: client = min(clients, key=lambda c: (c.id if c else 0), default=None) if client: diff --git a/src/python/zigopt/handlers/users/update.py b/src/python/zigopt/handlers/users/update.py index 1f9b2efb..c92dc37d 100644 --- a/src/python/zigopt/handlers/users/update.py +++ b/src/python/zigopt/handlers/users/update.py @@ -34,8 +34,7 @@ class UsersUpdateHandler(UserHandler): def parse_params(self, request): data = request.params() - name = get_opt_with_validation(data, "name", ValidationType.string) - if name is not None: + if (name := get_opt_with_validation(data, "name", ValidationType.string)) is not None: name = validate_user_name(name) educational_user = get_opt_with_validation(data, "educational_user", ValidationType.boolean) email = get_opt_with_validation(data, "email", ValidationType.string) @@ -98,8 +97,7 @@ def do_update(self, user_id, uploaded_user): user.user_meta = user_meta old_email = user.email - new_email = uploaded_user.email - if new_email is not None: + if (new_email := uploaded_user.email) is not None: if not user.hashed_password: raise SigoptValidationError("You cannot change your email because your account is externally administered.") if uploaded_user.password is None: @@ -110,14 +108,12 @@ def do_update(self, user_id, uploaded_user): email_verification_code = self.services.user_service.change_user_email_without_save(user, new_email) - show_welcome = uploaded_user.show_welcome - if show_welcome is not None: + if (show_welcome := uploaded_user.show_welcome) is not None: user_meta = copy_protobuf(user.user_meta) user_meta.show_welcome = show_welcome user.user_meta = user_meta - planned_usage = uploaded_user.planned_usage - if planned_usage is None: + if (planned_usage := uploaded_user.planned_usage) is None: user.user_meta.ClearField("planned_usage") elif planned_usage is self.NOT_PROVIDED: pass diff --git a/src/python/zigopt/handlers/validate/training_run.py b/src/python/zigopt/handlers/validate/training_run.py index 2f7fc181..588fb01e 100644 --- a/src/python/zigopt/handlers/validate/training_run.py +++ b/src/python/zigopt/handlers/validate/training_run.py @@ -53,8 +53,7 @@ def validate_assignments_meta( if OPTIMIZED_ASSIGNMENT_SOURCE in sources: raise SigoptValidationError(f" {OPTIMIZED_ASSIGNMENT_SOURCE} source is reserved and set automatically.") - metas_without_param = set(meta_keys) - set(params) - if metas_without_param: + if metas_without_param := set(meta_keys) - set(params): raise SigoptValidationError( f" Parameter meta exist for {metas_without_param} but there is no corresponding parameter." ) diff --git a/src/python/zigopt/handlers/validate/user.py b/src/python/zigopt/handlers/validate/user.py index d5adc38b..22030e7e 100644 --- a/src/python/zigopt/handlers/validate/user.py +++ b/src/python/zigopt/handlers/validate/user.py @@ -10,8 +10,7 @@ def validate_user_name(name: Optional[str]) -> str: - name = validate_name(name) - if name: + if name := validate_name(name): if len(name) >= User.NAME_MAX_LENGTH: raise InvalidValueError(f"Name must be fewer than {User.NAME_MAX_LENGTH} characters") return name diff --git a/src/python/zigopt/handlers/validate/validate_dict.py b/src/python/zigopt/handlers/validate/validate_dict.py index 069035ce..7a46ee37 100644 --- a/src/python/zigopt/handlers/validate/validate_dict.py +++ b/src/python/zigopt/handlers/validate/validate_dict.py @@ -526,15 +526,13 @@ def get_opt_with_validation(json_obj: dict[str, Any], key: str, value_type: IOVa :rtype: ``value_type`` or None """ - value = json_obj.get(key) - if value is None: + if (value := json_obj.get(key)) is None: return None return validate_type(value, value_type.get_input_validator(), key=key) def get_unvalidated(json_obj: dict[str, Any], key: str) -> Any: - value = json_obj.get(key) - if value is None: + if (value := json_obj.get(key)) is None: raise MissingJsonKeyError(key, json_obj) return value diff --git a/src/python/zigopt/handlers/validate/web_data/base.py b/src/python/zigopt/handlers/validate/web_data/base.py index f8f0c172..b95f0862 100644 --- a/src/python/zigopt/handlers/validate/web_data/base.py +++ b/src/python/zigopt/handlers/validate/web_data/base.py @@ -68,8 +68,7 @@ def rest_parent_id_validator(parent_resource_id): def validate_resource_exists(parent_resource_type, web_data_type): - test = schema_by_resource.get(parent_resource_type, {}).get(web_data_type, None) - if test is None: + if (test := schema_by_resource.get(parent_resource_type, {}).get(web_data_type, None)) is None: raise SigoptValidationError( f"Could not find web data type: `{web_data_type}` for resource: `{parent_resource_type}`" ) diff --git a/src/python/zigopt/handlers/web_data/base.py b/src/python/zigopt/handlers/web_data/base.py index 6e367136..36866cfa 100644 --- a/src/python/zigopt/handlers/web_data/base.py +++ b/src/python/zigopt/handlers/web_data/base.py @@ -23,8 +23,7 @@ def __init__(self, *args, **kwargs): def can_act_on_objects(self, requested_permission, objects): params = self._request.request.params() - parent_resource_id = params.get("parent_resource_id", None) - if parent_resource_id is None: + if (parent_resource_id := params.get("parent_resource_id", None)) is None: raise SigoptValidationError("No parent_resource_id set.") # For list/delete parent_resource_id is stringified to fit inside query params diff --git a/src/python/zigopt/handlers/web_data/delete.py b/src/python/zigopt/handlers/web_data/delete.py index 31fc0b7b..ed5e1809 100644 --- a/src/python/zigopt/handlers/web_data/delete.py +++ b/src/python/zigopt/handlers/web_data/delete.py @@ -26,10 +26,9 @@ def handle(self, params): web_data_id = params["id"] # Ensure the id is for the right resource - web_data = self.services.web_data_service.find_by_parent_resource_id_and_id( + if (web_data := self.services.web_data_service.find_by_parent_resource_id_and_id( parent_resource, web_data_type, parent_resource_id, web_data_id - ) - if web_data is None: + )) is None: raise NotFoundError( f"Cannot find web data of type: {web_data_type}, with parent resource: {parent_resource} and id: {web_data_id}." ) diff --git a/src/python/zigopt/handlers/web_data/update.py b/src/python/zigopt/handlers/web_data/update.py index fe7cd684..565b8019 100644 --- a/src/python/zigopt/handlers/web_data/update.py +++ b/src/python/zigopt/handlers/web_data/update.py @@ -28,12 +28,11 @@ def handle(self, params): payload = params["payload"] web_data_id = params["id"] - old_web_data = self.services.web_data_service.find_by_parent_resource_id_and_id( - parent_resource_type, web_data_type, parent_resource_id, web_data_id - ) # Web Data cannot change parent resoruce - if old_web_data is None: + if (old_web_data := self.services.web_data_service.find_by_parent_resource_id_and_id( + parent_resource_type, web_data_type, parent_resource_id, web_data_id + )) is None: raise ForbiddenError(f"{web_data_type} cannot be moved between {parent_resource_type}.") update_dict = { diff --git a/src/python/zigopt/json/assignments.py b/src/python/zigopt/json/assignments.py index 8e6a3ceb..5f197a41 100644 --- a/src/python/zigopt/json/assignments.py +++ b/src/python/zigopt/json/assignments.py @@ -22,8 +22,7 @@ def assignments_json(experiment: Experiment, assignments_dict: dict[str, float]) }, } - unknown_keys = set(assignments_dict.keys()) - set(json_dict.keys()) - if unknown_keys: + if unknown_keys := set(assignments_dict.keys()) - set(json_dict.keys()): raise Exception(f"Attempting to render assignments, unknown keys: {', '.join(unknown_keys)}") return json_dict diff --git a/src/python/zigopt/json/builder/experiment.py b/src/python/zigopt/json/builder/experiment.py index 0687947a..f05cbb84 100644 --- a/src/python/zigopt/json/builder/experiment.py +++ b/src/python/zigopt/json/builder/experiment.py @@ -142,8 +142,7 @@ def _pattr(self, attr: str, default: Any = None) -> Any: return coalesce(value, default) def _observation(self, attr: str) -> Optional[ObservationJsonBuilder]: - obs = self._pattr(attr) - if obs is not None: + if (obs := self._pattr(attr)) is not None: return ObservationJsonBuilder(self._experiment, obs) return None diff --git a/src/python/zigopt/json/builder/json_builder.py b/src/python/zigopt/json/builder/json_builder.py index 6f28ad9b..dee1180f 100644 --- a/src/python/zigopt/json/builder/json_builder.py +++ b/src/python/zigopt/json/builder/json_builder.py @@ -172,8 +172,7 @@ def __new__(cls, *args, **kwargs): ) for attr in dir(cls): func = getattr(cls, attr) - field_exposer = getattr(func, "__field_exposer", None) - if field_exposer is None: + if (field_exposer := getattr(func, "__field_exposer", None)) is None: continue field_exposer.update_field_dict(field_dict) setattr(instance, "builder", BuilderDetails(cls.object_name, field_dict)) diff --git a/src/python/zigopt/json/render.py b/src/python/zigopt/json/render.py index ae6cec91..c8302685 100644 --- a/src/python/zigopt/json/render.py +++ b/src/python/zigopt/json/render.py @@ -48,8 +48,7 @@ def render_param_value(parameter: ExperimentParameterProxy, assignment: float) - def render_conditional_value(conditional: ExperimentConditional, assignment: float) -> str: - value = find(conditional.values, lambda c: c.enum_index == assignment) - if value: + if value := find(conditional.values, lambda c: c.enum_index == assignment): return value.name raise ConditionalParamRenderException( f"Conditional {conditional.name} attempting to render unknown value {assignment}", conditional, assignment diff --git a/src/python/zigopt/log/base.py b/src/python/zigopt/log/base.py index deaa3faa..2d1bac31 100644 --- a/src/python/zigopt/log/base.py +++ b/src/python/zigopt/log/base.py @@ -91,8 +91,7 @@ def sensitive_filter(record): def set_default_formatter(formatter): root_logger = logging.getLogger() - default_handler = list_get(root_logger.handlers, 0) - if default_handler: + if default_handler := list_get(root_logger.handlers, 0): default_handler.setFormatter(formatter) @@ -131,8 +130,7 @@ def configure_warnings(): def configure_loggers(config_broker): syslog_handler = syslog_logger_setup(config_broker) - force_level = config_broker.get("logging.force") - if force_level: + if force_level := config_broker.get("logging.force"): LOG_LEVELS = {"": force_level} else: LOG_LEVELS = { @@ -168,8 +166,7 @@ def configure_loggers(config_broker): if config_broker.get("logging.warnings", "ignore") == "error": warnings.simplefilter("error", append=True) - log_format = config_broker.get("logging.format", "verbose") - if log_format == "compact": + if (log_format := config_broker.get("logging.format", "verbose")) == "compact": set_default_formatter(COMPACT_FORMATTER) elif log_format == "json": set_default_formatter( diff --git a/src/python/zigopt/membership/service.py b/src/python/zigopt/membership/service.py index 8aa7fd44..968ba973 100644 --- a/src/python/zigopt/membership/service.py +++ b/src/python/zigopt/membership/service.py @@ -171,8 +171,7 @@ def insert(self, user_id: int, organization_id: int, **kwargs) -> Membership: def create_if_not_exists( self, user_id: int, organization_id: int, membership_type: MembershipType | None = None, **kwargs ) -> Membership: - existing = self.find_by_user_and_organization(user_id, organization_id) - if existing: + if existing := self.find_by_user_and_organization(user_id, organization_id): return existing return self.insert(user_id=user_id, organization_id=organization_id, membership_type=membership_type, **kwargs) diff --git a/src/python/zigopt/net/errors.py b/src/python/zigopt/net/errors.py index 85328413..f4ca28a5 100644 --- a/src/python/zigopt/net/errors.py +++ b/src/python/zigopt/net/errors.py @@ -112,8 +112,7 @@ def __init__(self, msg=None, token_status=None): class EndpointNotFoundError(NotFoundError): def __init__(self, path, msg=None): - parts = [p for p in path.lstrip("/").split("/") if p] - if parts: + if parts := [p for p in path.lstrip("/").split("/") if p]: if parts[0] == "v1": msg = f"Endpoint not found: {path}" else: diff --git a/src/python/zigopt/observation/from_json.py b/src/python/zigopt/observation/from_json.py index aead3db5..e1a34c5d 100644 --- a/src/python/zigopt/observation/from_json.py +++ b/src/python/zigopt/observation/from_json.py @@ -68,8 +68,7 @@ def set_observation_data_assignments_map_from_json(observation_data, json_dict, def set_observation_data_task_from_json(observation_data, json_dict, experiment): - task = extract_task_from_json(experiment, json_dict) - if task: + if task := extract_task_from_json(experiment, json_dict): observation_data.task.CopyFrom(task) diff --git a/src/python/zigopt/observation/model.py b/src/python/zigopt/observation/model.py index c6ed73b5..56b33f64 100644 --- a/src/python/zigopt/observation/model.py +++ b/src/python/zigopt/observation/model.py @@ -114,8 +114,7 @@ def value_for_maximization(self, experiment, name): return self.data_proxy.value_for_maximization(experiment, name) def _within_metric_threshold(self, metric, experiment): - metric_value = self.metric_value(experiment, metric.name) - if metric_value is None: + if (metric_value := self.metric_value(experiment, metric.name)) is None: return False if metric.threshold is None: return True diff --git a/src/python/zigopt/optimization_aux/service.py b/src/python/zigopt/optimization_aux/service.py index ca47b6d9..a0dcf949 100644 --- a/src/python/zigopt/optimization_aux/service.py +++ b/src/python/zigopt/optimization_aux/service.py @@ -120,8 +120,7 @@ def persist_hyperparameters( hyperparameters, current_aux_date_updated, ): - experiment = self.services.experiment_service.find_by_id(experiment.id) - if experiment: + if experiment := self.services.experiment_service.find_by_id(experiment.id): self.services.database_service.upsert( ExperimentOptimizationAux( experiment_id=experiment.id, diff --git a/src/python/zigopt/optimize/queue.py b/src/python/zigopt/optimize/queue.py index 6c99ea7e..569388bb 100644 --- a/src/python/zigopt/optimize/queue.py +++ b/src/python/zigopt/optimize/queue.py @@ -62,8 +62,7 @@ def _create_importances_message(self, experiment, num_observations, force): @generator_to_list def _maybe_enqueue_importances(self, experiment, num_observations, force): - message = self._create_importances_message(experiment, num_observations, force) - if message is not None: + if (message := self._create_importances_message(experiment, num_observations, force)) is not None: if self.services.importances_service.should_update_importances(experiment, num_observations) or force: yield message diff --git a/src/python/zigopt/optimize/sources/base.py b/src/python/zigopt/optimize/sources/base.py index 911e3f3c..4208bcb9 100644 --- a/src/python/zigopt/optimize/sources/base.py +++ b/src/python/zigopt/optimize/sources/base.py @@ -163,21 +163,17 @@ def execute_gp_hyper_opt_call_based_on_lag(self, num_successful_observations): ] # 1 250 500 750 1000 1250 1500, N bounds - dimension_index = self._get_dimension_index(hyper_opt_dimension) - if dimension_index is None: + if (dimension_index := self._get_dimension_index(hyper_opt_dimension)) is None: return False - num_observed_index = self._get_num_observed_index(num_successful_observations) - if num_observed_index is None: + if (num_observed_index := self._get_num_observed_index(num_successful_observations)) is None: return False lag = lag_matrix[dimension_index][num_observed_index] - should_stagger_hyperopt = self.experiment.constraints - if should_stagger_hyperopt: + if should_stagger_hyperopt := self.experiment.constraints: lag = self._get_lag(hyper_opt_dimension) - lag_scale = self._get_lag_scale(num_successful_observations) - if lag_scale is None: + if (lag_scale := self._get_lag_scale(num_successful_observations)) is None: return False lag = int(lag_scale * lag) diff --git a/src/python/zigopt/optimize/sources/categorical.py b/src/python/zigopt/optimize/sources/categorical.py index a422a279..cb827d2b 100644 --- a/src/python/zigopt/optimize/sources/categorical.py +++ b/src/python/zigopt/optimize/sources/categorical.py @@ -112,8 +112,7 @@ def should_force_default_hyperparameters(self, optimization_args, hyperparameter def extract_hyperparameter_dict(self, optimization_args): multimetric_hyperparameter_dicts = [] - multimetric_hyperparameters = optimization_args.old_hyperparameters - if multimetric_hyperparameters is None: + if (multimetric_hyperparameters := optimization_args.old_hyperparameters) is None: return self.default_hyperparameter_dict(optimization_args) for hp in sorted(multimetric_hyperparameters.multimetric_hyperparameter_value, key=lambda val: val.metric_name): cat_hp_dict = self.extract_cat_hyperparameter_dict(optimization_args, hp.categorical_hyperparameters) diff --git a/src/python/zigopt/organization/service.py b/src/python/zigopt/organization/service.py index 10c1d2f5..ee3e3fea 100644 --- a/src/python/zigopt/organization/service.py +++ b/src/python/zigopt/organization/service.py @@ -81,8 +81,7 @@ def insert(self, organization: Organization) -> Organization: def delete(self, organization: Organization) -> bool: date_deleted = current_datetime() - did_update = self.delete_by_id(organization.id, date_deleted=date_deleted) - if did_update: + if did_update := self.delete_by_id(organization.id, date_deleted=date_deleted): organization.date_deleted = date_deleted return did_update @@ -206,8 +205,7 @@ def merge_organizations_into_destination( ) for i in invites_to_organization: - existing_invite = self.services.invite_service.find_by_email_and_organization(i.email, dest_organization.id) - if existing_invite: + if existing_invite := self.services.invite_service.find_by_email_and_organization(i.email, dest_organization.id): self.services.invite_service.delete_by_id(i.id) else: self.services.database_service.update( diff --git a/src/python/zigopt/pagination/lib.py b/src/python/zigopt/pagination/lib.py index fbb5277d..c4edad63 100644 --- a/src/python/zigopt/pagination/lib.py +++ b/src/python/zigopt/pagination/lib.py @@ -47,8 +47,7 @@ class DefinedField: def get_value_of_paging_symbol(symbol): assert isinstance(symbol, PagingSymbol) - which_oneof = symbol.WhichOneof("type") - if which_oneof is None: + if (which_oneof := symbol.WhichOneof("type")) is None: return None if which_oneof == "null_value": return None diff --git a/src/python/zigopt/parameters/from_json.py b/src/python/zigopt/parameters/from_json.py index 5bfce00d..45949d7b 100644 --- a/src/python/zigopt/parameters/from_json.py +++ b/src/python/zigopt/parameters/from_json.py @@ -84,8 +84,7 @@ def set_parameter_type_from_json(parameter, parameter_json): def set_bounds_from_json(parameter, parameter_json, experiment_type): - grid_values = get_opt_with_validation(parameter_json, "grid", ValidationType.arrayOf(ValidationType.number)) - if grid_values: + if grid_values := get_opt_with_validation(parameter_json, "grid", ValidationType.arrayOf(ValidationType.number)): if parameter_json.get("bounds") is not None: raise InvalidKeyError( "Parameters with grid cannot have bounds on parameters. They are inferred from the supplied grid." @@ -269,8 +268,7 @@ def set_grid_values_from_json(parameter, parameter_json): def set_default_value_from_json(parameter, parameter_json): - default_value = get_opt_with_validation(parameter_json, "default_value", ValidationType.assignment) - if default_value is not None: + if (default_value := get_opt_with_validation(parameter_json, "default_value", ValidationType.assignment)) is not None: value_to_set = get_assignment(parameter, default_value) parameter.replacement_value_if_missing = value_to_set @@ -308,8 +306,8 @@ def set_parameter_conditions_from_json(parameter, parameter_json, conditionals_m values_json = values_json if is_sequence(values_json) else [values_json] values = [] for value_json in values_json: - conditional_value = find(conditional.values, lambda c: c.name == value_json) # pylint: disable=cell-var-from-loop - if conditional_value is None: +# pylint: disable=cell-var-from-loop + if (conditional_value := find(conditional.values, lambda c: c.name == value_json)) is None: raise SigoptValidationError( f"Conditional {condition.name} on parameter {parameter.name} attempted to use non-existent value {value_json}" ) @@ -322,8 +320,7 @@ def set_transformation_from_json(parameter, parameter_json): if parameter_json.get("transformation") is not None: raise SigoptValidationError("Transformation is only valid for parameters of type `double`") - transformation_string = get_opt_with_validation(parameter_json, "transformation", ValidationType.string) - if transformation_string: + if transformation_string := get_opt_with_validation(parameter_json, "transformation", ValidationType.string): try: parameter.transformation = PARAMETER_TRANSFORMATION_NAME_TO_TYPE[transformation_string] except KeyError as e: diff --git a/src/python/zigopt/permission/service.py b/src/python/zigopt/permission/service.py index 9ae79831..5c54d81d 100644 --- a/src/python/zigopt/permission/service.py +++ b/src/python/zigopt/permission/service.py @@ -153,11 +153,10 @@ def upsert( requestor: User, role_for_logging: str, ) -> Permission: - membership = self.services.membership_service.find_by_user_and_organization( + if (membership := self.services.membership_service.find_by_user_and_organization( user_id=user.id, organization_id=client.organization_id, - ) - if membership is None: + )) is None: raise ValueError("Permission cannot be created without a membership.") if user and user.id and client and client.id: client_permissions = self.find_by_client_id(client.id) diff --git a/src/python/zigopt/protobuf/json.py b/src/python/zigopt/protobuf/json.py index fe59f14f..e262ae47 100644 --- a/src/python/zigopt/protobuf/json.py +++ b/src/python/zigopt/protobuf/json.py @@ -37,14 +37,12 @@ def __call__(self, serialized_value=_NO_ARG): if serialized_value is _NO_ARG: return self.descriptor.default_value assert is_string(serialized_value) - enum_value = self.descriptor.enum_type.values_by_name.get(serialized_value) - if enum_value is None: + if (enum_value := self.descriptor.enum_type.values_by_name.get(serialized_value)) is None: raise ValueError(f"Unknown name for enum {self.descriptor.enum_type.name}: {serialized_value}") return enum_value.number def serialize(self, value): - serialized_value = self.descriptor.enum_type.values_by_number.get(value) - if serialized_value is None: + if (serialized_value := self.descriptor.enum_type.values_by_number.get(value)) is None: raise ValueError(f"Unknown number for enum {self.descriptor.enum_type.name}: {value}") return serialized_value.name @@ -67,7 +65,7 @@ def next_descriptor_for_field_descriptor(descriptor): if descriptor.type == FieldDescriptor.TYPE_ENUM: return OurEnumDescriptor(descriptor) - python_type = ( + if (python_type := ( { FieldDescriptor.TYPE_BOOL: bool, FieldDescriptor.TYPE_DOUBLE: float, @@ -84,8 +82,7 @@ def next_descriptor_for_field_descriptor(descriptor): FieldDescriptor.TYPE_UINT32: int, FieldDescriptor.TYPE_UINT64: int, } - ).get(descriptor.type) - if python_type is None: + ).get(descriptor.type)) is None: raise NotImplementedError(f"Unknown message type: {descriptor.type}") return python_type @@ -101,12 +98,10 @@ def field_descriptor_to_scalar_descriptor(field_descriptor): def get_json_key_from_field_descriptor(descriptor, key): assert isinstance(descriptor, FieldDescriptor) - is_array_access = is_integer(key) - if is_array_access: + if is_array_access := is_integer(key): if descriptor.label != FieldDescriptor.LABEL_REPEATED: raise InvalidPathError(f"{key} is not a repeated field for {descriptor.full_name}") - reached_end = descriptor.message_type is None - if reached_end: + if reached_end := descriptor.message_type is None: return field_descriptor_to_scalar_descriptor(descriptor) return descriptor.message_type if descriptor.label == FieldDescriptor.LABEL_REPEATED: @@ -148,8 +143,7 @@ def get_json_key(descriptor, key, json=False): def _validate_array(value, descriptor, is_emit): - is_field_descriptor = isinstance(descriptor, FieldDescriptor) - if is_field_descriptor: + if is_field_descriptor := isinstance(descriptor, FieldDescriptor): is_repeated_field = (descriptor.label == FieldDescriptor.LABEL_REPEATED) and not IsMapEntry(descriptor) is_array = is_sequence(value) if is_array ^ is_repeated_field: @@ -176,8 +170,7 @@ def is_valid_scalar_descriptor_for_value(value, descriptor): def emit_json_with_descriptor(value, descriptor): # pylint: disable=too-many-return-statements - is_array = _validate_array(value, descriptor, is_emit=True) - if is_array: + if is_array := _validate_array(value, descriptor, is_emit=True): next_descriptor = next_descriptor_for_field_descriptor(descriptor) return [emit_json_with_descriptor(v, next_descriptor) for v in value] if isinstance(descriptor, Descriptor): @@ -206,8 +199,7 @@ def emit_json_with_descriptor(value, descriptor): def parse_json_with_descriptor(value, descriptor, ignore_unknown_fields): # pylint: disable=too-many-return-statements - is_array = _validate_array(value, descriptor, is_emit=False) - if is_array: + if is_array := _validate_array(value, descriptor, is_emit=False): next_descriptor = next_descriptor_for_field_descriptor(descriptor) return [parse_json_with_descriptor(v, next_descriptor, ignore_unknown_fields) for v in value] if isinstance(descriptor, Descriptor): diff --git a/src/python/zigopt/queue/monitor.py b/src/python/zigopt/queue/monitor.py index 07649c39..79ef833f 100644 --- a/src/python/zigopt/queue/monitor.py +++ b/src/python/zigopt/queue/monitor.py @@ -97,8 +97,7 @@ def _get_status(self, queue_name): def _update_status(self, queue_name, status): with self.services.exception_logger.tolerate_exceptions(RedisServiceError): - status = remove_nones_mapping(status) - if status: + if status := remove_nones_mapping(status): self.services.redis_service.set_hash_fields(self._queue_monitor_name(queue_name), status) def _can_send_alert(self, status, now): diff --git a/src/python/zigopt/queue/redis/message.py b/src/python/zigopt/queue/redis/message.py index 47ef650c..c800f179 100644 --- a/src/python/zigopt/queue/redis/message.py +++ b/src/python/zigopt/queue/redis/message.py @@ -34,8 +34,7 @@ def enqueue(self, queue_messages, group_key, enqueue_time, message_score): def dequeue(self, wait_time_seconds=None): wait_time_seconds = coalesce(wait_time_seconds, self.wait_time_seconds) - redis_body = self._pop_from_queue(self.redis_key, wait_time_seconds) - if redis_body is None: + if (redis_body := self._pop_from_queue(self.redis_key, wait_time_seconds)) is None: return redis_body message_with_name = self._parse_redis_body(redis_body) queue_message = self.services.message_router.deserialize_message( diff --git a/src/python/zigopt/queue/router.py b/src/python/zigopt/queue/router.py index 35d1c7c5..5e5e5db1 100644 --- a/src/python/zigopt/queue/router.py +++ b/src/python/zigopt/queue/router.py @@ -39,16 +39,14 @@ def get_worker_class_for_message(cls, message_type): @classmethod def deserialize_message(cls, message_type, serialized_body): - WorkerClass = cls.get_worker_class_for_message(message_type) - if WorkerClass: + if WorkerClass := cls.get_worker_class_for_message(message_type): message_body = WorkerClass.MessageBody.deserialize(serialized_body) return QueueMessage(message_type, message_body) return None @classmethod def make_queue_message(cls, _message_type, *args, **kwargs): - WorkerClass = cls.get_worker_class_for_message(_message_type) - if WorkerClass is None: + if (WorkerClass := cls.get_worker_class_for_message(_message_type)) is None: raise Exception( f"Could not find a worker for {_message_type}" f" Please make sure a worker has this MESSAGE_TYPE and has been added to {__file__}:WORKER_CLASSES." diff --git a/src/python/zigopt/queue/workers.py b/src/python/zigopt/queue/workers.py index bb46e55e..73a6ad76 100644 --- a/src/python/zigopt/queue/workers.py +++ b/src/python/zigopt/queue/workers.py @@ -72,8 +72,7 @@ def __init__(self, message_group, global_services, request_local_services_factor self._get_pull_queue_name() def _get_pull_queue_name(self): - queue_name = self.global_services.message_router.get_queue_name_from_message_group(self.message_group) - if queue_name is None: + if (queue_name := self.global_services.message_router.get_queue_name_from_message_group(self.message_group)) is None: raise Exception(f"Missing queue name for {self.message_group}") return queue_name @@ -93,11 +92,10 @@ def _check_stop_conditions(self, base_max_messages): self.logger.info("QueueWorkers killed by kill policy") raise WorkerKilledException() - max_messages_threshold = coalesce( + if (max_messages_threshold := coalesce( base_max_messages, self.global_services.config_broker.get(f"queue.{self.message_group.value}.max_messages"), - ) - if max_messages_threshold is None: + )) is None: should_process = True else: max_messages_threshold = max(max_messages_threshold, self.jitter * 2) diff --git a/src/python/zigopt/queued_suggestion/service.py b/src/python/zigopt/queued_suggestion/service.py index 338ce54b..69c220ac 100644 --- a/src/python/zigopt/queued_suggestion/service.py +++ b/src/python/zigopt/queued_suggestion/service.py @@ -39,8 +39,7 @@ def find_by_id(self, experiment_id: int, queued_id: int, include_deleted: bool = return self.services.database_service.one_or_none(q) def delete_by_id(self, experiment_id: int, queued_id: int) -> None: - suggestion = self.find_by_id(experiment_id, queued_id) - if suggestion: + if suggestion := self.find_by_id(experiment_id, queued_id): meta: SuggestionMeta = copy_protobuf(suggestion.meta) meta.deleted = True self.services.database_service.update_one( diff --git a/src/python/zigopt/redis/service.py b/src/python/zigopt/redis/service.py index be7cf2d3..3569b55f 100644 --- a/src/python/zigopt/redis/service.py +++ b/src/python/zigopt/redis/service.py @@ -227,8 +227,7 @@ def _make_redis_unix_socket(cls, config: Mapping[str, Any]) -> redis.Redis: @classmethod def make_redis(cls, config: Mapping[str, Any], **kwargs) -> redis.Redis: config = extend_dict({}, config, kwargs) - mode = config.get("connection_mode", "socket") - if mode == "socket": + if (mode := config.get("connection_mode", "socket")) == "socket": ret = cls._make_redis_unix_socket(config) else: assert mode == "tcp" diff --git a/src/python/zigopt/sigoptcompute/adapter.py b/src/python/zigopt/sigoptcompute/adapter.py index 95bb5c26..4c5c41ad 100644 --- a/src/python/zigopt/sigoptcompute/adapter.py +++ b/src/python/zigopt/sigoptcompute/adapter.py @@ -581,8 +581,7 @@ def pe_parameter_info(p): return var_type, elements, name def parameter_to_prior_info(p): - prior_type = p.prior.prior_type if p.prior.HasField("prior_type") else None - if prior_type == Prior.NORMAL: + if (prior_type := p.prior.prior_type if p.prior.HasField("prior_type") else None) == Prior.NORMAL: assert p.is_double name = ParameterPriorNames.NORMAL params = { diff --git a/src/python/zigopt/suggestion/broker/base.py b/src/python/zigopt/suggestion/broker/base.py index c3ef7162..1b05d468 100644 --- a/src/python/zigopt/suggestion/broker/base.py +++ b/src/python/zigopt/suggestion/broker/base.py @@ -19,10 +19,9 @@ class BaseBroker(Service): def serve_suggestion(self, experiment, processed_suggestion_meta, auth, automatic=False): - queued_suggestion = self.retrieve_queued_suggestions_if_exists( + if (queued_suggestion := self.retrieve_queued_suggestions_if_exists( experiment, processed_suggestion_meta, automatic=automatic - ) - if queued_suggestion is not None: + )) is not None: return queued_suggestion next_suggestion = self.next_suggestion(experiment, processed_suggestion_meta, automatic=automatic) diff --git a/src/python/zigopt/suggestion/service.py b/src/python/zigopt/suggestion/service.py index 11666151..0d296097 100644 --- a/src/python/zigopt/suggestion/service.py +++ b/src/python/zigopt/suggestion/service.py @@ -77,8 +77,7 @@ def find_open_by_experiment( unprocessed_suggestions_by_id = to_map_by_key(unprocessed_suggestions, lambda u: u.id) for p in processed_suggestions: - u = unprocessed_suggestions_by_id.get(p.suggestion_id) - if u: + if u := unprocessed_suggestions_by_id.get(p.suggestion_id): yield Suggestion(processed=p, unprocessed=u) else: self.services.exception_logger.soft_exception( diff --git a/src/python/zigopt/suggestion/unprocessed/service.py b/src/python/zigopt/suggestion/unprocessed/service.py index 23d9c766..7454f2cc 100644 --- a/src/python/zigopt/suggestion/unprocessed/service.py +++ b/src/python/zigopt/suggestion/unprocessed/service.py @@ -81,8 +81,7 @@ def delete_all_for_experiment(self, experiment: Experiment) -> None: self.services.database_service.update_all(UnprocessedSuggestion, updated_suggestions) def delete_by_id(self, experiment: Experiment, suggestion_id: int) -> None: - suggestion = self.find_by_id(suggestion_id) - if suggestion: + if suggestion := self.find_by_id(suggestion_id): new_meta: SuggestionMeta = copy_protobuf(suggestion.suggestion_meta) new_meta.deleted = True self.services.database_service.update_one( @@ -130,8 +129,7 @@ def count_by_experiment(self, experiment: Experiment, include_deleted: bool = Fa ) def insert_suggestions_to_be_processed(self, generated_suggestions: Sequence[UnprocessedSuggestion]) -> None: - generated_suggestions = list(generated_suggestions) - if generated_suggestions: + if generated_suggestions := list(generated_suggestions): generated_ids = self.services.database_service.reserve_ids( SUGGESTIONS_ID_SEQUENCE_NAME, len(generated_suggestions), @@ -200,8 +198,7 @@ def _truncate_suggestion_length(self, experiment_id: int, source: int, num_to_ke source, ) # suggestions are sorted by time created/added; this grabs the oldest ones - suggestions_to_drop = self.services.redis_service.get_sorted_set_range(suggestion_timestamp_key, 0, stop_index) - if suggestions_to_drop: + if suggestions_to_drop := self.services.redis_service.get_sorted_set_range(suggestion_timestamp_key, 0, stop_index): suggestion_protobuf_key = self.services.redis_key_service.create_suggestion_protobuf_key(experiment_id, source) self.services.redis_service.remove_from_hash(suggestion_protobuf_key, *suggestions_to_drop) self.services.redis_service.remove_from_sorted_set(suggestion_timestamp_key, *suggestions_to_drop) diff --git a/src/python/zigopt/token/service.py b/src/python/zigopt/token/service.py index 4a78a64a..eedbac6c 100644 --- a/src/python/zigopt/token/service.py +++ b/src/python/zigopt/token/service.py @@ -66,16 +66,14 @@ def _make_meta(self, session_expiration: int | None, token_type: str, can_renew: napply(session_expiration, lambda s: max(s - now, 0)), self.services.config_broker.get("external_authorization.token_ttl_seconds"), ] - ttl_seconds = min_option(remove_nones_sequence(ttl_options)) - if ttl_seconds is not None: + if (ttl_seconds := min_option(remove_nones_sequence(ttl_options))) is not None: meta.ttl_seconds = ttl_seconds return meta def _get_or_create_role_token(self, client_id: int, user_id: int, development: bool) -> Token: assert client_id is not None assert user_id is not None - existing = [token for token in self.find_by_client_and_user(client_id, user_id) if token.development == development] - if existing: + if existing := [token for token in self.find_by_client_and_user(client_id, user_id) if token.development == development]: return existing[0] token_type = TokenType.CLIENT_DEV if development else TokenType.CLIENT_API meta = self._make_meta(session_expiration=None, token_type=token_type, can_renew=False) @@ -112,8 +110,7 @@ def get_or_create_client_signup_token(self, client_id: int, creating_user_id: in assert client_id is not None assert creating_user_id is not None token_type = TokenType.GUEST - existing = self.get_client_signup_token(client_id, creating_user_id=creating_user_id) - if existing: + if existing := self.get_client_signup_token(client_id, creating_user_id=creating_user_id): return existing meta = self._make_meta(session_expiration=None, token_type=token_type, can_renew=False) meta.creating_user_id = creating_user_id @@ -182,11 +179,10 @@ def _create_user_token(self, user_id: int, session_expiration: int | None, can_r def renew_token(self, token: Token) -> Token | None: now = unix_timestamp() - updated = self.services.database_service.update_one_or_none( + if updated := self.services.database_service.update_one_or_none( self.services.database_service.query(Token).filter(Token.token == token.token).filter(~~Token.meta.can_renew), {Token.meta: jsonb_set(Token.meta, JsonPath(*unwind_json_path(Token.meta.date_renewed)), now)}, - ) - if updated: + ): meta: TokenMeta = copy_protobuf(token.meta) meta.date_renewed = now token.meta = meta @@ -196,11 +192,10 @@ def renew_token(self, token: Token) -> Token | None: def rotate_token(self, token: Token) -> Token | None: token_string = random_string() - updated = self.services.database_service.update_one_or_none( - self.services.database_service.query(Token).filter(Token.token == token.token), {Token.token: token_string} - ) - if updated: + if updated := self.services.database_service.update_one_or_none( + self.services.database_service.query(Token).filter(Token.token == token.token), {Token.token: token_string} + ): token.token = token_string return token return None diff --git a/src/python/zigopt/training_run/service.py b/src/python/zigopt/training_run/service.py index 266b31d4..2a43b9f6 100644 --- a/src/python/zigopt/training_run/service.py +++ b/src/python/zigopt/training_run/service.py @@ -68,8 +68,7 @@ def set_deleted(self, training_run_id: int, deleted: bool = True) -> None: ) if not tr: return - exp = self.services.experiment_service.find_by_id(tr.experiment_id, include_deleted=True) - if exp: + if exp := self.services.experiment_service.find_by_id(tr.experiment_id, include_deleted=True): if tr.suggestion_id: self.services.processed_suggestion_service.set_delete_by_ids(exp, [tr.suggestion_id], deleted=deleted) if tr.observation_id: @@ -100,7 +99,7 @@ def _readable_name(self, field): # pylint: disable=too-many-return-statements if "." not in field.name: return field.name - readable_name = ( + if readable_name := ( { "logs.stdout.content": "Output Logs", "logs.stderr.content": "Error Logs", @@ -108,8 +107,7 @@ def _readable_name(self, field): "source_code.content": "Source Code", "model.type": "Model Type", } - ).get(field.name) - if readable_name: + ).get(field.name): return readable_name name_parts = field.name.split(".") diff --git a/src/python/zigopt/user/service.py b/src/python/zigopt/user/service.py index c14a9592..a75003ad 100644 --- a/src/python/zigopt/user/service.py +++ b/src/python/zigopt/user/service.py @@ -78,8 +78,7 @@ def set_password_reset_code(self, user: User) -> str: return code def change_user_email_without_save(self, user: User, new_email: str) -> str: - existing_user = self.services.user_service.find_by_email(new_email) - if existing_user is not None: + if (existing_user := self.services.user_service.find_by_email(new_email)) is not None: raise SigoptValidationError("Unable to change email.") email_verification_code = self.services.email_verification_service.set_email_verification_code_without_save(user) From 20a69b6f2af03e44c1f6e89781309b17cf386d44 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Feb 2024 22:17:44 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/python/zigopt/experiment/segmenter.py | 2 +- src/python/zigopt/handlers/clients/tokens.py | 20 ++++++---- .../zigopt/handlers/experiments/create.py | 22 ++++++----- .../experiments/metric_importances/update.py | 10 +++-- .../experiments/observations/create.py | 8 ++-- .../experiments/queued_suggestions/base.py | 4 +- .../experiments/suggestions/update.py | 8 ++-- .../zigopt/handlers/experiments/update.py | 22 ++++++----- src/python/zigopt/handlers/projects/base.py | 10 +++-- .../zigopt/handlers/training_runs/tags.py | 20 ++++++---- src/python/zigopt/handlers/web_data/delete.py | 8 ++-- src/python/zigopt/handlers/web_data/update.py | 9 +++-- src/python/zigopt/organization/service.py | 4 +- src/python/zigopt/parameters/from_json.py | 2 +- src/python/zigopt/permission/service.py | 10 +++-- src/python/zigopt/protobuf/json.py | 38 ++++++++++--------- src/python/zigopt/queue/workers.py | 14 ++++--- src/python/zigopt/suggestion/broker/base.py | 8 ++-- src/python/zigopt/token/service.py | 5 ++- 19 files changed, 133 insertions(+), 91 deletions(-) diff --git a/src/python/zigopt/experiment/segmenter.py b/src/python/zigopt/experiment/segmenter.py index a97badf9..85530afd 100644 --- a/src/python/zigopt/experiment/segmenter.py +++ b/src/python/zigopt/experiment/segmenter.py @@ -68,7 +68,7 @@ def prune_intervals(self, experiment, with_assignments_maps, intervals): for has_assignments_map in with_assignments_maps: for name, assignment in has_assignments_map.get_assignments(experiment).items(): remaining_intervals = intervals.get(name, []) -# pylint: disable=cell-var-from-loop + # pylint: disable=cell-var-from-loop if to_remove := find(remaining_intervals, lambda i: assignment in i): remaining_intervals.remove(to_remove) diff --git a/src/python/zigopt/handlers/clients/tokens.py b/src/python/zigopt/handlers/clients/tokens.py index 446ab66f..9dd46d78 100644 --- a/src/python/zigopt/handlers/clients/tokens.py +++ b/src/python/zigopt/handlers/clients/tokens.py @@ -187,10 +187,12 @@ def ensure_includes_role_token(self, tokens): auth = self.auth assert self.client is not None client = self.client - if (role_token := find( - tokens, - lambda t: t.user_id == auth.current_user.id and t.client_id == client.id and t.development is False, - )) is None: + if ( + role_token := find( + tokens, + lambda t: t.user_id == auth.current_user.id and t.client_id == client.id and t.development is False, + ) + ) is None: role_token = self.services.token_service.get_or_create_role_token( self.client.id, self.auth.current_user.id, @@ -204,10 +206,12 @@ def ensure_includes_development_role_token(self, tokens): assert current_user is not None assert self.client is not None client = self.client - if (development_token := find( - tokens, - lambda t: t.user_id == current_user.id and t.client_id == client.id and t.development is True, - )) is None: + if ( + development_token := find( + tokens, + lambda t: t.user_id == current_user.id and t.client_id == client.id and t.development is True, + ) + ) is None: development_token = self.services.token_service.get_or_create_development_role_token( self.client.id, self.auth.current_user.id, diff --git a/src/python/zigopt/handlers/experiments/create.py b/src/python/zigopt/handlers/experiments/create.py index f8605054..6e103b7a 100644 --- a/src/python/zigopt/handlers/experiments/create.py +++ b/src/python/zigopt/handlers/experiments/create.py @@ -280,10 +280,12 @@ def make_experiment_meta_from_json( has_constraint_metrics = any(m.strategy == ExperimentMetric.CONSTRAINT for m in experiment_meta.metrics) has_optimization_metrics = len(optimized_metrics) > 0 - if (num_solutions := cls.get_num_solutions_from_json( - json_dict, - experiment_meta.all_parameters_unsorted, - )) is not None: + if ( + num_solutions := cls.get_num_solutions_from_json( + json_dict, + experiment_meta.all_parameters_unsorted, + ) + ) is not None: experiment_meta.num_solutions = num_solutions if (parallel_bandwidth := cls.get_parallel_bandwidth_from_json(json_dict)) is not None: @@ -379,11 +381,13 @@ def get_metric_strategy(cls, metric): @classmethod def get_metric_list_from_json(cls, json_dict): - if (metrics := get_opt_with_validation( - json_dict, - "metrics", - ValidationType.arrayOf(ValidationType.oneOf([ValidationType.string, ValidationType.object])), - )) is None: + if ( + metrics := get_opt_with_validation( + json_dict, + "metrics", + ValidationType.arrayOf(ValidationType.oneOf([ValidationType.string, ValidationType.object])), + ) + ) is None: assert MAX_METRICS_ANY_STRATEGY >= 1 assert MAX_OPTIMIZED_METRICS >= 1 return [ExperimentMetric()] diff --git a/src/python/zigopt/handlers/experiments/metric_importances/update.py b/src/python/zigopt/handlers/experiments/metric_importances/update.py index 313ff27b..d4f381a1 100644 --- a/src/python/zigopt/handlers/experiments/metric_importances/update.py +++ b/src/python/zigopt/handlers/experiments/metric_importances/update.py @@ -16,10 +16,12 @@ def handle(self): assert self.experiment is not None num_observations = self.services.observation_service.count_by_experiment(self.experiment) - if (q_msg := self.services.optimize_queue_service.always_enqueue_importances( - experiment=self.experiment, - num_observations=num_observations, - )) is None: + if ( + q_msg := self.services.optimize_queue_service.always_enqueue_importances( + experiment=self.experiment, + num_observations=num_observations, + ) + ) is None: raise UnprocessableEntityError( "Parameter importances update failed. (This experiment may not support importances.)" ) diff --git a/src/python/zigopt/handlers/experiments/observations/create.py b/src/python/zigopt/handlers/experiments/observations/create.py index c97d0bc8..04f79df5 100644 --- a/src/python/zigopt/handlers/experiments/observations/create.py +++ b/src/python/zigopt/handlers/experiments/observations/create.py @@ -70,9 +70,11 @@ def observation_from_json( elif observation.timestamp: observation_data.timestamp = observation.timestamp - if (client_provided_data := BaseExperimentsCreateHandler.get_client_provided_data( - json_dict, default=observation.client_provided_data - )) is not None: + if ( + client_provided_data := BaseExperimentsCreateHandler.get_client_provided_data( + json_dict, default=observation.client_provided_data + ) + ) is not None: observation_data.client_provided_data = client_provided_data else: if observation_data.HasField("client_provided_data"): diff --git a/src/python/zigopt/handlers/experiments/queued_suggestions/base.py b/src/python/zigopt/handlers/experiments/queued_suggestions/base.py index 8059d3dd..e76e9935 100644 --- a/src/python/zigopt/handlers/experiments/queued_suggestions/base.py +++ b/src/python/zigopt/handlers/experiments/queued_suggestions/base.py @@ -30,7 +30,9 @@ def can_act_on_objects(self, requested_permission, objects): ) def _find_queued_suggestion(self, queued_suggestion_id): - if (queued_suggestion := self.services.queued_suggestion_service.find_by_id(self.experiment_id, queued_suggestion_id)) is not None: + if ( + queued_suggestion := self.services.queued_suggestion_service.find_by_id(self.experiment_id, queued_suggestion_id) + ) is not None: if queued_suggestion.experiment_id == self.experiment_id: return queued_suggestion raise NotFoundError(f"No QueuedSuggestion {queued_suggestion_id} for experiment {self.experiment_id}") diff --git a/src/python/zigopt/handlers/experiments/suggestions/update.py b/src/python/zigopt/handlers/experiments/suggestions/update.py index f7a047f9..9bf17113 100644 --- a/src/python/zigopt/handlers/experiments/suggestions/update.py +++ b/src/python/zigopt/handlers/experiments/suggestions/update.py @@ -22,9 +22,11 @@ def handle(self, json_dict): assert self.suggestion is not None suggestion_meta = ProcessedSuggestionMeta() - if (client_provided_data := BaseExperimentsCreateHandler.get_client_provided_data( - json_dict, default=self.suggestion.client_provided_data - )) is not None: + if ( + client_provided_data := BaseExperimentsCreateHandler.get_client_provided_data( + json_dict, default=self.suggestion.client_provided_data + ) + ) is not None: suggestion_meta.client_provided_data = client_provided_data processed = self.suggestion.processed diff --git a/src/python/zigopt/handlers/experiments/update.py b/src/python/zigopt/handlers/experiments/update.py index 67426717..da1f09fe 100644 --- a/src/python/zigopt/handlers/experiments/update.py +++ b/src/python/zigopt/handlers/experiments/update.py @@ -354,10 +354,12 @@ def handle(self, json_dict): update_experiment_fields["date_updated"] = current_datetime() self.experiment.date_updated = update_experiment_fields["date_updated"] - if (update_count := self.services.database_service.update_one( - self.services.database_service.query(Experiment).filter(Experiment.id == self.experiment.id), - update_experiment_fields, - )) == 0: + if ( + update_count := self.services.database_service.update_one( + self.services.database_service.query(Experiment).filter(Experiment.id == self.experiment.id), + update_experiment_fields, + ) + ) == 0: raise NotFoundError(f"No experiment {self.experiment.id}") if original_project_id is not None: @@ -526,11 +528,13 @@ def _maybe_set_parameter_grid_values(self, parameter, parameter_json): set_grid_values_from_json(parameter, parameter_json) def _maybe_set_parameter_categorical_values(self, parameter, parameter_json): - if (categorical_values_json := get_opt_with_validation( - parameter_json, - "categorical_values", - ValidationType.arrayOf(ValidationType.oneOf([ValidationType.object, ValidationType.string])), - )) is None: + if ( + categorical_values_json := get_opt_with_validation( + parameter_json, + "categorical_values", + ValidationType.arrayOf(ValidationType.oneOf([ValidationType.object, ValidationType.string])), + ) + ) is None: return categorical_values_map = dict((c.name, c) for c in parameter.all_categorical_values) diff --git a/src/python/zigopt/handlers/projects/base.py b/src/python/zigopt/handlers/projects/base.py index 5197cce1..2305a2f9 100644 --- a/src/python/zigopt/handlers/projects/base.py +++ b/src/python/zigopt/handlers/projects/base.py @@ -24,10 +24,12 @@ def find_objects(self): ) def _find_project(self): - if (project := self.services.project_service.find_by_client_and_reference_id( - client_id=self.client_id, - reference_id=self.project_reference_id, - )) is None: + if ( + project := self.services.project_service.find_by_client_and_reference_id( + client_id=self.client_id, + reference_id=self.project_reference_id, + ) + ) is None: raise NotFoundError(f"No project {self.project_reference_id} in client {self.client_id}") return project diff --git a/src/python/zigopt/handlers/training_runs/tags.py b/src/python/zigopt/handlers/training_runs/tags.py index 86b3819a..ceb117b8 100644 --- a/src/python/zigopt/handlers/training_runs/tags.py +++ b/src/python/zigopt/handlers/training_runs/tags.py @@ -53,10 +53,12 @@ def handle(self, params): tag_id = params[self.ID_PARAM] - if (tag := self.services.tag_service.find_by_client_and_id( - client_id=self.training_run.client_id, - tag_id=tag_id, - )) is None: + if ( + tag := self.services.tag_service.find_by_client_and_id( + client_id=self.training_run.client_id, + tag_id=tag_id, + ) + ) is None: raise UnprocessableEntityError( f"The tag with id {tag_id} cannot be added to this training run because it does not exist." ) @@ -88,10 +90,12 @@ def __init__(self, *args, tag_id, **kwargs): def find_objects(self): objs = super().find_objects() - if (tag := self.services.tag_service.find_by_client_and_id( - client_id=objs["training_run"].client_id, - tag_id=self.tag_id, - )) is None: + if ( + tag := self.services.tag_service.find_by_client_and_id( + client_id=objs["training_run"].client_id, + tag_id=self.tag_id, + ) + ) is None: raise NotFoundError("Tag not found") objs["tag"] = tag return objs diff --git a/src/python/zigopt/handlers/web_data/delete.py b/src/python/zigopt/handlers/web_data/delete.py index ed5e1809..8d75f750 100644 --- a/src/python/zigopt/handlers/web_data/delete.py +++ b/src/python/zigopt/handlers/web_data/delete.py @@ -26,9 +26,11 @@ def handle(self, params): web_data_id = params["id"] # Ensure the id is for the right resource - if (web_data := self.services.web_data_service.find_by_parent_resource_id_and_id( - parent_resource, web_data_type, parent_resource_id, web_data_id - )) is None: + if ( + web_data := self.services.web_data_service.find_by_parent_resource_id_and_id( + parent_resource, web_data_type, parent_resource_id, web_data_id + ) + ) is None: raise NotFoundError( f"Cannot find web data of type: {web_data_type}, with parent resource: {parent_resource} and id: {web_data_id}." ) diff --git a/src/python/zigopt/handlers/web_data/update.py b/src/python/zigopt/handlers/web_data/update.py index 565b8019..5c20e226 100644 --- a/src/python/zigopt/handlers/web_data/update.py +++ b/src/python/zigopt/handlers/web_data/update.py @@ -28,11 +28,12 @@ def handle(self, params): payload = params["payload"] web_data_id = params["id"] - # Web Data cannot change parent resoruce - if (old_web_data := self.services.web_data_service.find_by_parent_resource_id_and_id( - parent_resource_type, web_data_type, parent_resource_id, web_data_id - )) is None: + if ( + old_web_data := self.services.web_data_service.find_by_parent_resource_id_and_id( + parent_resource_type, web_data_type, parent_resource_id, web_data_id + ) + ) is None: raise ForbiddenError(f"{web_data_type} cannot be moved between {parent_resource_type}.") update_dict = { diff --git a/src/python/zigopt/organization/service.py b/src/python/zigopt/organization/service.py index ee3e3fea..de046e7d 100644 --- a/src/python/zigopt/organization/service.py +++ b/src/python/zigopt/organization/service.py @@ -205,7 +205,9 @@ def merge_organizations_into_destination( ) for i in invites_to_organization: - if existing_invite := self.services.invite_service.find_by_email_and_organization(i.email, dest_organization.id): + if existing_invite := self.services.invite_service.find_by_email_and_organization( + i.email, dest_organization.id + ): self.services.invite_service.delete_by_id(i.id) else: self.services.database_service.update( diff --git a/src/python/zigopt/parameters/from_json.py b/src/python/zigopt/parameters/from_json.py index 45949d7b..b98f140c 100644 --- a/src/python/zigopt/parameters/from_json.py +++ b/src/python/zigopt/parameters/from_json.py @@ -306,7 +306,7 @@ def set_parameter_conditions_from_json(parameter, parameter_json, conditionals_m values_json = values_json if is_sequence(values_json) else [values_json] values = [] for value_json in values_json: -# pylint: disable=cell-var-from-loop + # pylint: disable=cell-var-from-loop if (conditional_value := find(conditional.values, lambda c: c.name == value_json)) is None: raise SigoptValidationError( f"Conditional {condition.name} on parameter {parameter.name} attempted to use non-existent value {value_json}" diff --git a/src/python/zigopt/permission/service.py b/src/python/zigopt/permission/service.py index 5c54d81d..10b3c992 100644 --- a/src/python/zigopt/permission/service.py +++ b/src/python/zigopt/permission/service.py @@ -153,10 +153,12 @@ def upsert( requestor: User, role_for_logging: str, ) -> Permission: - if (membership := self.services.membership_service.find_by_user_and_organization( - user_id=user.id, - organization_id=client.organization_id, - )) is None: + if ( + membership := self.services.membership_service.find_by_user_and_organization( + user_id=user.id, + organization_id=client.organization_id, + ) + ) is None: raise ValueError("Permission cannot be created without a membership.") if user and user.id and client and client.id: client_permissions = self.find_by_client_id(client.id) diff --git a/src/python/zigopt/protobuf/json.py b/src/python/zigopt/protobuf/json.py index e262ae47..be48e464 100644 --- a/src/python/zigopt/protobuf/json.py +++ b/src/python/zigopt/protobuf/json.py @@ -65,24 +65,26 @@ def next_descriptor_for_field_descriptor(descriptor): if descriptor.type == FieldDescriptor.TYPE_ENUM: return OurEnumDescriptor(descriptor) - if (python_type := ( - { - FieldDescriptor.TYPE_BOOL: bool, - FieldDescriptor.TYPE_DOUBLE: float, - FieldDescriptor.TYPE_FIXED32: int, - FieldDescriptor.TYPE_FIXED64: int, - FieldDescriptor.TYPE_FLOAT: float, - FieldDescriptor.TYPE_INT32: int, - FieldDescriptor.TYPE_INT64: int, - FieldDescriptor.TYPE_SFIXED32: int, - FieldDescriptor.TYPE_SFIXED64: int, - FieldDescriptor.TYPE_SINT32: int, - FieldDescriptor.TYPE_SINT64: int, - FieldDescriptor.TYPE_STRING: str, - FieldDescriptor.TYPE_UINT32: int, - FieldDescriptor.TYPE_UINT64: int, - } - ).get(descriptor.type)) is None: + if ( + python_type := ( + { + FieldDescriptor.TYPE_BOOL: bool, + FieldDescriptor.TYPE_DOUBLE: float, + FieldDescriptor.TYPE_FIXED32: int, + FieldDescriptor.TYPE_FIXED64: int, + FieldDescriptor.TYPE_FLOAT: float, + FieldDescriptor.TYPE_INT32: int, + FieldDescriptor.TYPE_INT64: int, + FieldDescriptor.TYPE_SFIXED32: int, + FieldDescriptor.TYPE_SFIXED64: int, + FieldDescriptor.TYPE_SINT32: int, + FieldDescriptor.TYPE_SINT64: int, + FieldDescriptor.TYPE_STRING: str, + FieldDescriptor.TYPE_UINT32: int, + FieldDescriptor.TYPE_UINT64: int, + } + ).get(descriptor.type) + ) is None: raise NotImplementedError(f"Unknown message type: {descriptor.type}") return python_type diff --git a/src/python/zigopt/queue/workers.py b/src/python/zigopt/queue/workers.py index 73a6ad76..a42546c7 100644 --- a/src/python/zigopt/queue/workers.py +++ b/src/python/zigopt/queue/workers.py @@ -72,7 +72,9 @@ def __init__(self, message_group, global_services, request_local_services_factor self._get_pull_queue_name() def _get_pull_queue_name(self): - if (queue_name := self.global_services.message_router.get_queue_name_from_message_group(self.message_group)) is None: + if ( + queue_name := self.global_services.message_router.get_queue_name_from_message_group(self.message_group) + ) is None: raise Exception(f"Missing queue name for {self.message_group}") return queue_name @@ -92,10 +94,12 @@ def _check_stop_conditions(self, base_max_messages): self.logger.info("QueueWorkers killed by kill policy") raise WorkerKilledException() - if (max_messages_threshold := coalesce( - base_max_messages, - self.global_services.config_broker.get(f"queue.{self.message_group.value}.max_messages"), - )) is None: + if ( + max_messages_threshold := coalesce( + base_max_messages, + self.global_services.config_broker.get(f"queue.{self.message_group.value}.max_messages"), + ) + ) is None: should_process = True else: max_messages_threshold = max(max_messages_threshold, self.jitter * 2) diff --git a/src/python/zigopt/suggestion/broker/base.py b/src/python/zigopt/suggestion/broker/base.py index 1b05d468..86ba5168 100644 --- a/src/python/zigopt/suggestion/broker/base.py +++ b/src/python/zigopt/suggestion/broker/base.py @@ -19,9 +19,11 @@ class BaseBroker(Service): def serve_suggestion(self, experiment, processed_suggestion_meta, auth, automatic=False): - if (queued_suggestion := self.retrieve_queued_suggestions_if_exists( - experiment, processed_suggestion_meta, automatic=automatic - )) is not None: + if ( + queued_suggestion := self.retrieve_queued_suggestions_if_exists( + experiment, processed_suggestion_meta, automatic=automatic + ) + ) is not None: return queued_suggestion next_suggestion = self.next_suggestion(experiment, processed_suggestion_meta, automatic=automatic) diff --git a/src/python/zigopt/token/service.py b/src/python/zigopt/token/service.py index eedbac6c..1f200498 100644 --- a/src/python/zigopt/token/service.py +++ b/src/python/zigopt/token/service.py @@ -73,7 +73,9 @@ def _make_meta(self, session_expiration: int | None, token_type: str, can_renew: def _get_or_create_role_token(self, client_id: int, user_id: int, development: bool) -> Token: assert client_id is not None assert user_id is not None - if existing := [token for token in self.find_by_client_and_user(client_id, user_id) if token.development == development]: + if existing := [ + token for token in self.find_by_client_and_user(client_id, user_id) if token.development == development + ]: return existing[0] token_type = TokenType.CLIENT_DEV if development else TokenType.CLIENT_API meta = self._make_meta(session_expiration=None, token_type=token_type, can_renew=False) @@ -192,7 +194,6 @@ def renew_token(self, token: Token) -> Token | None: def rotate_token(self, token: Token) -> Token | None: token_string = random_string() - if updated := self.services.database_service.update_one_or_none( self.services.database_service.query(Token).filter(Token.token == token.token), {Token.token: token_string} ):