Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Assignment Expression (Walrus) In Conditional #518

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions src/python/zigopt/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,10 @@ def _password_reset_authentication(services, request):
to the method as the first argument.
"""
email = napply(request.optional_param("email"), validate_email)
optional_api_token = request.optional_api_token()
if optional_api_token:
if optional_api_token := request.optional_api_token():
token = _validate_api_token(request.optional_user_token())
token_authorization = _do_api_token_authentication(services, request, token, mandatory=True)
auth_email = token_authorization.current_user and token_authorization.current_user.email
if auth_email:
if auth_email := token_authorization.current_user and token_authorization.current_user.email:
if email and auth_email != email:
raise BadParamError("Invalid email parameter when authenticating with API token")
email = auth_email
Expand Down Expand Up @@ -258,8 +256,7 @@ def _login_authentication(services, request):
rate_limit_identifier = login_rate_limit.increment_and_check_rate_limit(services, request)
authentication = authenticate_login(services, email, password, code)
login_rate_limit.reset_rate_limit(services, rate_limit_identifier)
user = authentication.user
if user is not None:
if (user := authentication.user) is not None:
authenticated_from_email_link = authentication.authenticated_from_email_link
user_token = services.token_service.create_temporary_user_token(user.id)
return UserLoginAuthorization(
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/api/paging.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@

def decode_base64_marker_without_padding(serialized_marker):
# NOTE: "=" padding needs to be added to the marker in order for it to be correctly decoded
tail_len = len(serialized_marker) % 4
if tail_len:
if tail_len := len(serialized_marker) % 4:
serialized_marker += "=" * (4 - tail_len)
return base64.urlsafe_b64decode(serialized_marker)

Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/api/ratelimit.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ class _ApiTokenRateLimit(_BaseRateLimit):
NON_MUTATING_METHODS = {"GET", "HEAD", "OPTIONS"}

def _get_identifier(self, services, identifying_object):
token = identifying_object.optional_client_token()
if token is not None:
if (token := identifying_object.optional_client_token()) is not None:
return token, TOKEN_IDENTIFIER
return _UNUSED, _UNUSED

Expand Down
15 changes: 5 additions & 10 deletions src/python/zigopt/api/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,7 @@ def on_json_loading_failed(self, e):
raise SigoptValidationError("Malformed json in request body") from e

def params(self):
ret = self._params
if ret is self._MISSING:
if (ret := self._params) is self._MISSING:
if self.method in ("GET", "DELETE"):
self._params = {validate_api_input(k): validate_api_input(v) for k, v in self.args.items()}
elif self.method in ("PUT", "POST", "MERGE"):
Expand Down Expand Up @@ -259,8 +258,7 @@ def optional_param(self, name):

def required_param(self, name):
"""Retrieve ``name`` from flask HTTP request and *fail* if ``name`` is not found."""
value = self.optional_param(name)
if value is not None:
if (value := self.optional_param(name)) is not None:
return value
raise MissingParamError(name)

Expand All @@ -283,8 +281,7 @@ def optional_list_param(self, name):
"""
assert self.method in ("GET", "POST", "DELETE"), "Must use get_with_validation for non-query arguments"
param_string = self.params().get(name)
param_value = param_string.split(",") if param_string else None
if param_value is None:
if (param_value := param_string.split(",") if param_string else None) is None:
return None
if not isinstance(param_value, list):
raise SigoptValidationError("Invalid list value: " + param_value)
Expand All @@ -296,14 +293,12 @@ def get_sort(self, default_field, default_ascending=False):
return SortRequest(field, ascending)

def _parse_marker(self, name):
serialized_marker = self.optional_param(name)
if serialized_marker is None:
if (serialized_marker := self.optional_param(name)) is None:
return None
return deserialize_paging_marker(serialized_marker)

def get_paging(self, max_limit=DEFAULT_PAGING_MAX_LIMIT):
limit = self.optional_int_param("limit")
if limit is None:
if (limit := self.optional_int_param("limit")) is None:
limit = max_limit
if limit < 0:
raise InvalidValueError(f"Invalid limit: {limit}")
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/assignments/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def assignments_map(self):
raise Exception("Do not access .assignments_map directly - prefer .get_assignments(experiment)")

def _get_required_parameter(self, array, parameter, index):
ret = array[index]
if ret is None:
if (ret := array[index]) is None:
raise ValueError(f"Parameter has no replacement value: {parameter.name}")
return ret

Expand Down
6 changes: 2 additions & 4 deletions src/python/zigopt/authentication/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,12 @@ def authenticate_login(services, email: str, password: str, code: str | None) ->
delay_seconds_on_failure = min_delay_seconds + jitter_seconds * non_crypto_random.random()
failure_time = start_time + delay_seconds_on_failure
try:
user = services.user_service.find_by_email(email)
if user:
if user := services.user_service.find_by_email(email):
# NOTE: If the user cannot auth with a password, we can't reveal the error they have made.
# This would leak information about the account they are attempting to log in to, notably that
# the account exists and someone has logged into it successfully.
# The user must go through the "Forgot Password" flow
can_auth_with_password = bool(user.hashed_password)
if can_auth_with_password:
if can_auth_with_password := bool(user.hashed_password):
if password is not None:
return authenticate_password(services, user, password)
if code is not None:
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/authentication/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@

@deal.raises(ForbiddenError, TypeError)
def authenticate_token(services, token: str) -> AuthenticationResult:
token_obj = services.token_service.find_by_token(token, include_expired=True)
if token_obj is None:
if (token_obj := services.token_service.find_by_token(token, include_expired=True)) is None:
return AuthenticationResult()
if token_obj.expired:
raise ForbiddenError("Your API token has expired.", token_status=TokenStatus.EXPIRED)
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/authorization/owner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ def can_act_on_organization(self, services, requested_permission, organization):
)

def _can_act_on_client_artifacts(self, services, requested_permission, client_id, owner_id_for_artifacts):
client = services.client_service.find_by_id(client_id)
if client is None:
if (client := services.client_service.find_by_id(client_id)) is None:
return False
return (
self.can_act_on_client(services, requested_permission, client)
Expand Down
6 changes: 2 additions & 4 deletions src/python/zigopt/authorization/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,9 @@ def can_act_on_organization(self, services, requested_permission, organization):
return False

def _can_act_on_client_artifacts(self, services, requested_permission, client_id, owner_id_for_artifacts):
organization_id = self._infer_organization_id_from_client_id(services, client_id)
if organization_id:
if organization_id := self._infer_organization_id_from_client_id(services, client_id):
organization = self._infer_organization_from_organization_id(services, organization_id)
membership = self._membership(services, organization=organization)
if membership:
if membership := self._membership(services, organization=organization):
if membership.is_owner:
return True

Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/common/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ def partition(lis: Sequence[T], predicate: Callable[[T], bool]) -> tuple[list[T]
true_list = []
false_list = []
for l in as_list:
pred_value = predicate(l)
if pred_value:
if pred_value := predicate(l):
true_list.append(l)
else:
false_list.append(l)
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/db/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,7 @@ def one(self, q: Query) -> Any:
@sanitize_errors
@retry_on_error
def one_or_none(self, q: Query) -> Any | None:
ret = q.one_or_none()
if ret is not None:
if (ret := q.one_or_none()) is not None:
self._expunge_one(ret)
self._rollback()
return ret
Expand Down
6 changes: 2 additions & 4 deletions src/python/zigopt/experiment/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,14 +374,12 @@ def costliest_task(self):
return max_option(self.tasks, key=lambda x: x.cost)

def get_task_by_name(self, task_name):
matching_task = find(self.tasks, lambda t: t.name == task_name)
if matching_task is None:
if (matching_task := find(self.tasks, lambda t: t.name == task_name)) is None:
raise ValueError(f"No task named {task_name} for experiment {self.id}.")
return matching_task

def get_task_by_cost(self, task_cost):
matching_task = find(self.tasks, lambda t: t.cost == task_cost)
if matching_task is None:
if (matching_task := find(self.tasks, lambda t: t.cost == task_cost)) is None:
raise ValueError(f"No task with cost {task_cost} for experiment {self.id}.")
return matching_task

Expand Down
4 changes: 2 additions & 2 deletions src/python/zigopt/experiment/segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def prune_intervals(self, experiment, with_assignments_maps, intervals):
for has_assignments_map in with_assignments_maps:
for name, assignment in has_assignments_map.get_assignments(experiment).items():
remaining_intervals = intervals.get(name, [])
to_remove = find(remaining_intervals, lambda i: assignment in i) # pylint: disable=cell-var-from-loop
if to_remove:
# pylint: disable=cell-var-from-loop
if to_remove := find(remaining_intervals, lambda i: assignment in i):
remaining_intervals.remove(to_remove)

def pick_value(self, parameter, intervals):
Expand Down
10 changes: 4 additions & 6 deletions src/python/zigopt/experiment/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,13 @@ def delete(self, experiment: Experiment) -> None:
Not a true DB delete, just sets the deleted flag.
"""
timestamp = current_datetime()
update = self.services.database_service.update_one_or_none(
if update := self.services.database_service.update_one_or_none(
self.services.database_service.query(Experiment).filter(Experiment.id == experiment.id),
{
Experiment.deleted: True,
Experiment.date_updated: timestamp,
},
)
if update:
):
experiment.date_updated = timestamp
experiment.deleted = True
self.services.project_service.mark_as_updated_by_experiment(experiment)
Expand Down Expand Up @@ -271,15 +270,14 @@ def _include_deleted_clause(self, include_deleted: bool, q: Query) -> Query:
def mark_as_updated(self, experiment: Experiment, timestamp: datetime.datetime | None = None) -> None:
if timestamp is None:
timestamp = current_datetime()
did_update = self.services.database_service.update_one_or_none(
if did_update := self.services.database_service.update_one_or_none(
self.services.database_service.query(Experiment)
.filter(Experiment.id == experiment.id)
.filter(Experiment.date_updated < timestamp.replace(microsecond=0)),
{
Experiment.date_updated: timestamp,
},
)
if did_update:
):
experiment.date_updated = timestamp
self.services.project_service.mark_as_updated_by_experiment(experiment=experiment)

Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/file/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ def get_download_filename(self):
"text/plain": ".txt",
}
mime_type = self.data.content_type
extension = overrides.get(mime_type)
if extension is None:
if (extension := overrides.get(mime_type)) is None:
extension = mimetypes.guess_extension(self.data.content_type)
if extension is None:
extension = ""
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/handlers/aiexperiments/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,7 @@ def get_metric_list_from_json(cls, json_dict):

@classmethod
def get_metric_name(cls, metric, seen_names, num_metrics):
name = super().get_metric_name(metric, seen_names, num_metrics)
if name is None:
if (name := super().get_metric_name(metric, seen_names, num_metrics)) is None:
raise MissingJsonKeyError("name", "All metrics require a name")
return name

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ def maybe_create_observation_from_params(self, training_run_params):

assignments = self.get_assignments_relevant_to_experiment(training_run_params)
values, failed = self.get_observation_values(training_run_params)
observation = self.maybe_create_observation_from_data(assignments, values, failed)
if observation:
if observation := self.maybe_create_observation_from_data(assignments, values, failed):
observation_assignments = observation.get_assignments(self.experiment)
return observation, assignments_json(self.experiment, observation_assignments)
return None, {}
Expand Down Expand Up @@ -119,8 +118,7 @@ def maybe_create_explicit_suggestion_from_assignments(self, assignments):
def maybe_create_suggestion(self, training_run_params):
assert self.experiment is not None

provided_experiment_assignments = self.get_assignments_relevant_to_experiment(training_run_params)
if provided_experiment_assignments:
if provided_experiment_assignments := self.get_assignments_relevant_to_experiment(training_run_params):
suggestion = self.maybe_create_explicit_suggestion_from_assignments(provided_experiment_assignments)
else:
suggestion = self.serve_suggestion()
Expand Down Expand Up @@ -156,8 +154,7 @@ def handle(self, params):
params.training_run_params.training_run_data.assignments_struct.update(assignments)

training_run_data = params.training_run_params.training_run_data
assignments_meta = training_run_data.assignments_meta
if assignments_meta is not None:
if (assignments_meta := training_run_data.assignments_meta) is not None:
validate_assignments_meta(training_run_data.assignments_struct, assignments_meta, None)
if suggestion:
for assignment in assignments.keys():
Expand Down
5 changes: 2 additions & 3 deletions src/python/zigopt/handlers/clients/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@ def handle(self):
assert self.auth is not None
assert self.client is not None

user_can_delete = self.services.membership_service.user_is_owner_for_organization(
if user_can_delete := self.services.membership_service.user_is_owner_for_organization(
user_id=self.auth.current_user.id,
organization_id=self.client.organization_id,
)
if user_can_delete:
):
self.do_delete()
return {}
raise ForbiddenError("You cannot delete this client.")
Expand Down
9 changes: 3 additions & 6 deletions src/python/zigopt/handlers/clients/invite.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,8 @@ def handle(self, request):
raise ForbiddenError("You do not have permission to invite this email address.")

client_invites = [dict(id=client.id, role=role, old_role=old_role)]
existing_invite = self.services.invite_service.find_by_email_and_organization(email, client.organization_id)

if existing_invite:
if existing_invite := self.services.invite_service.find_by_email_and_organization(email, client.organization_id):
if existing_invite.membership_type == MembershipType.owner:
raise InvalidValueError("This email was already invited as an owner")

Expand Down Expand Up @@ -153,9 +152,7 @@ def handle(self, email):
)

self.services.pending_permission_service.delete_by_email_and_client(email, self.client)
invite = self.services.invite_service.find_by_email_and_organization(email, self.client.organization_id)
if invite:
num_pending_permissions = self.services.pending_permission_service.count_by_invite_id(invite.id)
if num_pending_permissions == 0:
if invite := self.services.invite_service.find_by_email_and_organization(email, self.client.organization_id):
if (num_pending_permissions := self.services.pending_permission_service.count_by_invite_id(invite.id)) == 0:
self.services.invite_service.delete_by_email_and_organization(email, self.client.organization_id)
return {}
31 changes: 15 additions & 16 deletions src/python/zigopt/handlers/clients/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def handle(self):
assert self.token is not None
if not self.token.expiration_timestamp and self.token.all_experiments:
raise ForbiddenError("Cannot delete root token")
success = self.services.token_service.delete_token(self.token)
if success:
if success := self.services.token_service.delete_token(self.token):
return {}
raise NotFoundError("Token not found")

Expand All @@ -81,17 +80,15 @@ def handle(self, params):
assert self.auth is not None
assert self.token is not None

new_token_value = params.token
if new_token_value is not None:
if (new_token_value := params.token) is not None:
if new_token_value != "rotate":
raise SigoptValidationError('Token must equal "rotate"')
self.services.token_service.rotate_token(self.token)
if params.lasts_forever is not None:
meta = copy_protobuf(self.token.meta)
meta.lasts_forever = params.lasts_forever
self.services.token_service.update_meta(self.token, meta)
new_expires_value = params.expires
if new_expires_value is not None:
if (new_expires_value := params.expires) is not None:
if new_expires_value != "renew":
raise SigoptValidationError('Expires must equal "renew"')
if self.token.meta.can_renew:
Expand Down Expand Up @@ -190,11 +187,12 @@ def ensure_includes_role_token(self, tokens):
auth = self.auth
assert self.client is not None
client = self.client
role_token = find(
tokens,
lambda t: t.user_id == auth.current_user.id and t.client_id == client.id and t.development is False,
)
if role_token is None:
if (
role_token := find(
tokens,
lambda t: t.user_id == auth.current_user.id and t.client_id == client.id and t.development is False,
)
) is None:
role_token = self.services.token_service.get_or_create_role_token(
self.client.id,
self.auth.current_user.id,
Expand All @@ -208,11 +206,12 @@ def ensure_includes_development_role_token(self, tokens):
assert current_user is not None
assert self.client is not None
client = self.client
development_token = find(
tokens,
lambda t: t.user_id == current_user.id and t.client_id == client.id and t.development is True,
)
if development_token is None:
if (
development_token := find(
tokens,
lambda t: t.user_id == current_user.id and t.client_id == client.id and t.development is True,
)
) is None:
development_token = self.services.token_service.get_or_create_development_role_token(
self.client.id,
self.auth.current_user.id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ def _parse_values(self, data):
observation_value.name = get_with_validation(value_dict, "name", ValidationType.string)
observation_value.value = get_with_validation(value_dict, "value", ValidationType.number)
value_stddev: float | None = get_opt_with_validation(value_dict, "value_stddev", ValidationType.number)
value_var = napply(value_stddev, lambda stddev: stddev * stddev)
if value_var is not None:
if (value_var := napply(value_stddev, lambda stddev: stddev * stddev)) is not None:
observation_value.value_var = value_var
yield observation_value

Expand Down
Loading
Loading