Skip to content
This repository has been archived by the owner on Feb 14, 2025. It is now read-only.

Commit

Permalink
keep simple walruses
Browse files Browse the repository at this point in the history
  • Loading branch information
tjs-intel committed Feb 26, 2024
1 parent 32bf309 commit fa8cb21
Show file tree
Hide file tree
Showing 38 changed files with 68 additions and 123 deletions.
6 changes: 2 additions & 4 deletions src/python/zigopt/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,10 @@ def _password_reset_authentication(services, request):
to the method as the first argument.
"""
email = napply(request.optional_param("email"), validate_email)
optional_api_token = request.optional_api_token()
if optional_api_token:
if optional_api_token := request.optional_api_token():
token = _validate_api_token(request.optional_user_token())
token_authorization = _do_api_token_authentication(services, request, token, mandatory=True)
auth_email = token_authorization.current_user and token_authorization.current_user.email
if auth_email:
if auth_email := token_authorization.current_user and token_authorization.current_user.email:
if email and auth_email != email:
raise BadParamError("Invalid email parameter when authenticating with API token")
email = auth_email
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/api/paging.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@

def decode_base64_marker_without_padding(serialized_marker):
# NOTE: "=" padding needs to be added to the marker in order for it to be correctly decoded
tail_len = len(serialized_marker) % 4
if tail_len:
if tail_len := len(serialized_marker) % 4:
serialized_marker += "=" * (4 - tail_len)
return base64.urlsafe_b64decode(serialized_marker)

Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/authentication/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ def authenticate_login(services, email: str, password: str, code: str | None) ->
delay_seconds_on_failure = min_delay_seconds + jitter_seconds * non_crypto_random.random()
failure_time = start_time + delay_seconds_on_failure
try:
user = services.user_service.find_by_email(email)
if user:
if user := services.user_service.find_by_email(email):
# NOTE: If the user cannot auth with a password, we can't reveal the error they have made.
# This would leak information about the account they are attempting to log in to, notably that
# the account exists and someone has logged into it successfully.
Expand Down
6 changes: 2 additions & 4 deletions src/python/zigopt/authorization/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,9 @@ def can_act_on_organization(self, services, requested_permission, organization):
return False

def _can_act_on_client_artifacts(self, services, requested_permission, client_id, owner_id_for_artifacts):
organization_id = self._infer_organization_id_from_client_id(services, client_id)
if organization_id:
if organization_id := self._infer_organization_id_from_client_id(services, client_id):
organization = self._infer_organization_from_organization_id(services, organization_id)
membership = self._membership(services, organization=organization)
if membership:
if membership := self._membership(services, organization=organization):
if membership.is_owner:
return True

Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/common/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ def partition(lis: Sequence[T], predicate: Callable[[T], bool]) -> tuple[list[T]
true_list = []
false_list = []
for l in as_list:
pred_value = predicate(l)
if pred_value:
if predicate(l):
true_list.append(l)
else:
false_list.append(l)
Expand Down
4 changes: 2 additions & 2 deletions src/python/zigopt/experiment/segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def prune_intervals(self, experiment, with_assignments_maps, intervals):
for has_assignments_map in with_assignments_maps:
for name, assignment in has_assignments_map.get_assignments(experiment).items():
remaining_intervals = intervals.get(name, [])
to_remove = find(remaining_intervals, lambda i: assignment in i) # pylint: disable=cell-var-from-loop
if to_remove:
# pylint: disable=cell-var-from-loop
if to_remove := find(remaining_intervals, lambda i: assignment in i):
remaining_intervals.remove(to_remove)

def pick_value(self, parameter, intervals):
Expand Down
10 changes: 4 additions & 6 deletions src/python/zigopt/experiment/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,13 @@ def delete(self, experiment: Experiment) -> None:
Not a true DB delete, just sets the deleted flag.
"""
timestamp = current_datetime()
update = self.services.database_service.update_one_or_none(
if self.services.database_service.update_one_or_none(
self.services.database_service.query(Experiment).filter(Experiment.id == experiment.id),
{
Experiment.deleted: True,
Experiment.date_updated: timestamp,
},
)
if update:
):
experiment.date_updated = timestamp
experiment.deleted = True
self.services.project_service.mark_as_updated_by_experiment(experiment)
Expand Down Expand Up @@ -271,15 +270,14 @@ def _include_deleted_clause(self, include_deleted: bool, q: Query) -> Query:
def mark_as_updated(self, experiment: Experiment, timestamp: datetime.datetime | None = None) -> None:
if timestamp is None:
timestamp = current_datetime()
did_update = self.services.database_service.update_one_or_none(
if self.services.database_service.update_one_or_none(
self.services.database_service.query(Experiment)
.filter(Experiment.id == experiment.id)
.filter(Experiment.date_updated < timestamp.replace(microsecond=0)),
{
Experiment.date_updated: timestamp,
},
)
if did_update:
):
experiment.date_updated = timestamp
self.services.project_service.mark_as_updated_by_experiment(experiment=experiment)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ def maybe_create_observation_from_params(self, training_run_params):

assignments = self.get_assignments_relevant_to_experiment(training_run_params)
values, failed = self.get_observation_values(training_run_params)
observation = self.maybe_create_observation_from_data(assignments, values, failed)
if observation:
if observation := self.maybe_create_observation_from_data(assignments, values, failed):
observation_assignments = observation.get_assignments(self.experiment)
return observation, assignments_json(self.experiment, observation_assignments)
return None, {}
Expand Down Expand Up @@ -119,8 +118,7 @@ def maybe_create_explicit_suggestion_from_assignments(self, assignments):
def maybe_create_suggestion(self, training_run_params):
assert self.experiment is not None

provided_experiment_assignments = self.get_assignments_relevant_to_experiment(training_run_params)
if provided_experiment_assignments:
if provided_experiment_assignments := self.get_assignments_relevant_to_experiment(training_run_params):
suggestion = self.maybe_create_explicit_suggestion_from_assignments(provided_experiment_assignments)
else:
suggestion = self.serve_suggestion()
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/handlers/clients/invite.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,8 @@ def handle(self, request):
raise ForbiddenError("You do not have permission to invite this email address.")

client_invites = [dict(id=client.id, role=role, old_role=old_role)]
existing_invite = self.services.invite_service.find_by_email_and_organization(email, client.organization_id)

if existing_invite:
if existing_invite := self.services.invite_service.find_by_email_and_organization(email, client.organization_id):
if existing_invite.membership_type == MembershipType.owner:
raise InvalidValueError("This email was already invited as an owner")

Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/handlers/clients/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def handle(self):
assert self.token is not None
if not self.token.expiration_timestamp and self.token.all_experiments:
raise ForbiddenError("Cannot delete root token")
success = self.services.token_service.delete_token(self.token)
if success:
if self.services.token_service.delete_token(self.token):
return {}
raise NotFoundError("Token not found")

Expand Down
9 changes: 4 additions & 5 deletions src/python/zigopt/handlers/experiments/list_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def _maybe_search(self, query_args, params, client_ids):
if params.search is None:
return
keyword = params.search.lower().strip()
matching_user_ids = set(

# TODO: Consider optimization of search ordering
if matching_user_ids := set(
u_id
for (u_id,) in flatten(
[
Expand All @@ -152,10 +154,7 @@ def _maybe_search(self, query_args, params, client_ids):
),
]
)
)

# TODO: Consider optimization of search ordering
if matching_user_ids:
):
search_filter = or_(Experiment.created_by.in_(matching_user_ids), Experiment.name.ilike(f"%{keyword}%"))
else:
search_filter = Experiment.name.ilike(f"%{keyword}%")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,7 @@ def handle(self, request):
no_optimize=self.get_no_optimize_for_observation_json(params),
)

bad_keys = params.keys() - observation_json.keys()
if bad_keys:
if bad_keys := params.keys() - observation_json.keys():
raise InvalidKeyError(f"Unknown keys were provided for observation create: {bad_keys}")

observation_json = remove_nones_mapping(observation_json)
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/handlers/organizations/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ def parse_params(self, request):
"allow_signup_from_email_domains",
ValidationType.boolean,
)
email_domains = get_opt_with_validation(data, "email_domains", ValidationType.arrayOf(ValidationType.string))
if email_domains:
if email_domains := get_opt_with_validation(data, "email_domains", ValidationType.arrayOf(ValidationType.string)):
email_domains = [validate_email_domain(domain) for domain in email_domains]

did_enable_allow_signup = (
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/handlers/training_runs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ def parse_params(self, request):
return self.parse_request(request)

def parse_request(self, request):
unaccepted_params = request.params().keys() - TrainingRunRequestParams.valid_fields
if unaccepted_params:
if unaccepted_params := request.params().keys() - TrainingRunRequestParams.valid_fields:
raise SigoptValidationError(f"Unknown parameters: {unaccepted_params}")

training_run_params = TrainingRunRequestParser().parse_params(request)
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/handlers/training_runs/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ def can_act_on_objects(self, requested_permission, objects):
def parse_params(self, request):
provided_params = request.params()
acceptable_params = [key for key, _ in self.REQUIRED_INPUT_PARAMS + self.OPTIONAL_INPUT_PARAMS]
unaccepted_params = provided_params.keys() - acceptable_params
if unaccepted_params:
if unaccepted_params := provided_params.keys() - acceptable_params:
raise SigoptValidationError(
f"Unknown parameters: {unaccepted_params}. Only the following parameters are accepted: {acceptable_params}"
)
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/handlers/training_runs/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ class TrainingRunsAddTagHandler(BaseTrainingRunsTagHandler):
def parse_params(self, request):
provided_params = request.params()
acceptable_params = [key for key, _ in self.INPUT_PARAMS]
unaccepted_params = provided_params.keys() - acceptable_params
if unaccepted_params:
if unaccepted_params := provided_params.keys() - acceptable_params:
raise SigoptValidationError(
f"Unknown parameters: {unaccepted_params}. Only the following parameters are accepted: {acceptable_params}"
)
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/handlers/training_runs/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ class TrainingRunsUpdateHandler(CreatesObservationsMixin, TrainingRunHandler):
training_run: TrainingRun | None

def parse_params(self, request):
unaccepted_params = request.params().keys() - TrainingRunRequestParams.valid_fields
if unaccepted_params:
if unaccepted_params := request.params().keys() - TrainingRunRequestParams.valid_fields:
raise SigoptValidationError(f"Unknown parameters: {unaccepted_params}")

method = request.method.lower()
Expand Down
6 changes: 2 additions & 4 deletions src/python/zigopt/handlers/users/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ def get_verified_invite(self, email, invite_code):
if invite.invite_code == invite_code or (not self.services.email_verification_service.enabled)
]
# TODO: This doesn't handle multiple invite codes. We probably don't need to
claimable_invite = list_get(claimable_invites, 0)
if claimable_invite:
if claimable_invite := list_get(claimable_invites, 0):
if self.services.invite_service.invite_is_valid(claimable_invite):
return claimable_invite
raise ForbiddenError("This invite is no longer valid.")
Expand Down Expand Up @@ -247,8 +246,7 @@ def parse_params(self, request):
)

def handle(self, params):
verified_invite = self.get_verified_invite(params.user_attributes.email, params.invite_code)
if verified_invite:
if verified_invite := self.get_verified_invite(params.user_attributes.email, params.invite_code):
has_verified_email = params.has_verified_email or verified_invite.invite_code == params.invite_code
user = self.create_user_by_invite(params.user_attributes, verified_invite, has_verified_email)
self.create_clients_and_permissions(user, verified_invite)
Expand Down
15 changes: 5 additions & 10 deletions src/python/zigopt/handlers/users/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ def handle(self, password):
self.ensure_no_orphaned_organizations()
self.ensure_no_orphaned_clients()

validated_request = password and password_matches(password, self.user.hashed_password)
if validated_request:
if validated_request := password and password_matches(password, self.user.hashed_password):
do_password_hash_work_factor_update(self.services, self.user, password)
if validated_request:
self.do_delete()
Expand Down Expand Up @@ -54,14 +53,12 @@ def ensure_no_orphaned_organizations(self):
)
still_owned_organization_ids = set(m.organization_id for m in other_owner_memberships)

unowned_organization_ids = organization_ids - still_owned_organization_ids
if unowned_organization_ids:
if unowned_organization_ids := organization_ids - still_owned_organization_ids:
other_non_owner_memberships = self.services.membership_service.organizations_with_other_non_owners(
list(unowned_organization_ids),
self.user.id,
)
orphaned_organization_ids = set(m.organization_id for m in other_non_owner_memberships)
if orphaned_organization_ids:
if orphaned_organization_ids := set(m.organization_id for m in other_non_owner_memberships):
raise ForbiddenError(
"This user cannot be deleted without assigning another owner or removing "
f"all other users from the following organizations: {orphaned_organization_ids}"
Expand All @@ -80,18 +77,16 @@ def ensure_no_orphaned_clients(self):
self.user.id,
)
other_owned_organizations = set(m.organization_id for m in organizations_with_other_owners)
client_ids = [p.client_id for p in permissions if p.organization_id not in other_owned_organizations]

if client_ids:
if client_ids := [p.client_id for p in permissions if p.organization_id not in other_owned_organizations]:
outstanding_permissions = self.services.database_service.all(
self.services.database_service.query(Permission)
.filter(Permission.client_id.in_(client_ids))
.filter(Permission.user_id != self.user.id)
)
outstanding_client_ids = [p.client_id for p in outstanding_permissions]
orphaned_clients = list(set(client_ids) - set(outstanding_client_ids))

if orphaned_clients:
if orphaned_clients := list(set(client_ids) - set(outstanding_client_ids)):
raise ForbiddenError(f"This user cannot be deleted without deleting the following clients: {orphaned_clients}")

def do_delete(self):
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/handlers/users/password.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ def parse_params(self, request):
return validate_email(request.required_param("email"))

def handle(self, email):
user = self.services.user_service.find_by_email(email)
if user:
if user := self.services.user_service.find_by_email(email):
if user.hashed_password:
code = self.services.user_service.set_password_reset_code(user)
self.services.email_router.send(
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/handlers/validate/training_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ def validate_assignments_meta(
if OPTIMIZED_ASSIGNMENT_SOURCE in sources:
raise SigoptValidationError(f" {OPTIMIZED_ASSIGNMENT_SOURCE} source is reserved and set automatically.")

metas_without_param = set(meta_keys) - set(params)
if metas_without_param:
if metas_without_param := set(meta_keys) - set(params):
raise SigoptValidationError(
f" Parameter meta exist for {metas_without_param} but there is no corresponding parameter."
)
3 changes: 1 addition & 2 deletions src/python/zigopt/handlers/validate/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@


def validate_user_name(name: Optional[str]) -> str:
name = validate_name(name)
if name:
if name := validate_name(name):
if len(name) >= User.NAME_MAX_LENGTH:
raise InvalidValueError(f"Name must be fewer than {User.NAME_MAX_LENGTH} characters")
return name
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/json/assignments.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ def assignments_json(experiment: Experiment, assignments_dict: dict[str, float])
},
}

unknown_keys = set(assignments_dict.keys()) - set(json_dict.keys())
if unknown_keys:
if unknown_keys := set(assignments_dict.keys()) - set(json_dict.keys()):
raise Exception(f"Attempting to render assignments, unknown keys: {', '.join(unknown_keys)}")

return json_dict
3 changes: 1 addition & 2 deletions src/python/zigopt/json/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ def render_param_value(parameter: ExperimentParameterProxy, assignment: float) -


def render_conditional_value(conditional: ExperimentConditional, assignment: float) -> str:
value = find(conditional.values, lambda c: c.enum_index == assignment)
if value:
if value := find(conditional.values, lambda c: c.enum_index == assignment):
return value.name
raise ConditionalParamRenderException(
f"Conditional {conditional.name} attempting to render unknown value {assignment}", conditional, assignment
Expand Down
6 changes: 2 additions & 4 deletions src/python/zigopt/log/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ def sensitive_filter(record):

def set_default_formatter(formatter):
root_logger = logging.getLogger()
default_handler = list_get(root_logger.handlers, 0)
if default_handler:
if default_handler := list_get(root_logger.handlers, 0):
default_handler.setFormatter(formatter)


Expand Down Expand Up @@ -131,8 +130,7 @@ def configure_warnings():
def configure_loggers(config_broker):
syslog_handler = syslog_logger_setup(config_broker)

force_level = config_broker.get("logging.force")
if force_level:
if force_level := config_broker.get("logging.force"):
LOG_LEVELS = {"": force_level}
else:
LOG_LEVELS = {
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/membership/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,7 @@ def insert(self, user_id: int, organization_id: int, **kwargs) -> Membership:
def create_if_not_exists(
self, user_id: int, organization_id: int, membership_type: MembershipType | None = None, **kwargs
) -> Membership:
existing = self.find_by_user_and_organization(user_id, organization_id)
if existing:
if existing := self.find_by_user_and_organization(user_id, organization_id):
return existing
return self.insert(user_id=user_id, organization_id=organization_id, membership_type=membership_type, **kwargs)

Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/net/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ def __init__(self, msg=None, token_status=None):

class EndpointNotFoundError(NotFoundError):
def __init__(self, path, msg=None):
parts = [p for p in path.lstrip("/").split("/") if p]
if parts:
if parts := [p for p in path.lstrip("/").split("/") if p]:
if parts[0] == "v1":
msg = f"Endpoint not found: {path}"
else:
Expand Down
3 changes: 1 addition & 2 deletions src/python/zigopt/observation/from_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ def set_observation_data_assignments_map_from_json(observation_data, json_dict,


def set_observation_data_task_from_json(observation_data, json_dict, experiment):
task = extract_task_from_json(experiment, json_dict)
if task:
if task := extract_task_from_json(experiment, json_dict):
observation_data.task.CopyFrom(task)


Expand Down
Loading

0 comments on commit fa8cb21

Please sign in to comment.