Skip to content

Commit

Permalink
✨Clusters-keeper: terminate broken EC2s🚨 (#5851)
Browse files Browse the repository at this point in the history
  • Loading branch information
sanderegg authored May 26, 2024
1 parent b9022f6 commit 09866f1
Show file tree
Hide file tree
Showing 12 changed files with 360 additions and 100 deletions.
5 changes: 4 additions & 1 deletion packages/aws-library/src/aws_library/ec2/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import logging
from collections.abc import Iterable
from dataclasses import dataclass
from typing import cast

Expand Down Expand Up @@ -281,7 +282,9 @@ async def get_instances(
)
return all_instances

async def terminate_instances(self, instance_datas: list[EC2InstanceData]) -> None:
async def terminate_instances(
self, instance_datas: Iterable[EC2InstanceData]
) -> None:
try:
with log_context(
_logger,
Expand Down
18 changes: 16 additions & 2 deletions packages/aws-library/src/aws_library/ec2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,28 @@ class AWSTagValue(ConstrainedStr):
@dataclass(frozen=True)
class EC2InstanceData:
    """Immutable snapshot of an AWS EC2 instance as returned by the EC2 client.

    NOTE: the diff residue duplicating the ``id``/``type`` fields (old lines with
    ``# noqa: A003``) is removed: each field is declared exactly once.
    """

    launch_time: datetime.datetime
    id: str
    aws_private_dns: InstancePrivateDNSName
    aws_public_ip: str | None
    type: InstanceTypeType
    state: InstanceStateNameType
    resources: Resources
    tags: EC2Tags

    def __hash__(self) -> int:
        # NOTE: tags is a dict (unhashable), so the dataclass-generated __hash__
        # cannot be used; sort the tag items to get a deterministic,
        # insertion-order-independent hash
        return hash(
            (
                self.launch_time,
                self.id,
                self.aws_private_dns,
                self.aws_public_ip,
                self.type,
                self.state,
                self.resources,
                tuple(sorted(self.tags.items())),
            )
        )


@dataclass(frozen=True)
class EC2InstanceConfig:
Expand Down
40 changes: 39 additions & 1 deletion packages/aws-library/tests/test_ec2_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@


import pytest
from aws_library.ec2.models import AWSTagKey, AWSTagValue, Resources
from aws_library.ec2.models import AWSTagKey, AWSTagValue, EC2InstanceData, Resources
from faker import Faker
from pydantic import ByteSize, ValidationError, parse_obj_as


Expand Down Expand Up @@ -132,3 +133,40 @@ def test_aws_tag_key_invalid(ec2_tag_key: str):

# for a value it does not
parse_obj_as(AWSTagValue, ec2_tag_key)


def test_ec2_instance_data_hashable(faker: Faker):
    """EC2InstanceData must be hashable so instances can live in sets."""

    def _random_instance() -> EC2InstanceData:
        # same construction (and same faker call order) as before, factored out
        return EC2InstanceData(
            faker.date_time(),
            faker.pystr(),
            faker.pystr(),
            f"{faker.ipv4()}",
            "g4dn.xlarge",
            "running",
            Resources(
                cpus=faker.pyfloat(min_value=0.1),
                ram=ByteSize(faker.pyint(min_value=123)),
            ),
            {AWSTagKey("mytagkey"): AWSTagValue("mytagvalue")},
        )

    first_set_of_ec2s = {_random_instance()}
    second_set_of_ec2s = {_random_instance()}

    union_of_sets = first_set_of_ec2s | second_set_of_ec2s
    assert next(iter(first_set_of_ec2s)) in union_of_sets
    assert next(iter(second_set_of_ec2s)) in union_of_sets
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,8 @@ def check_valid_instance_names(
) -> dict[str, EC2InstanceBootSpecific]:
# NOTE: needed because of a flaw in BaseCustomSettings
# issubclass raises TypeError if used on Aliases
if all(parse_obj_as(InstanceTypeType, key) for key in value):
return value

msg = "Invalid instance type name"
raise ValueError(msg)
parse_obj_as(list[InstanceTypeType], list(value))
return value


class PrimaryEC2InstancesSettings(BaseCustomSettings):
Expand Down Expand Up @@ -177,18 +174,23 @@ class PrimaryEC2InstancesSettings(BaseCustomSettings):
..., description="Password for accessing prometheus data"
)

PRIMARY_EC2_INSTANCES_MAX_START_TIME: datetime.timedelta = Field(
default=datetime.timedelta(minutes=2),
description="Usual time taken an EC2 instance with the given AMI takes to startup and be ready to receive jobs "
"(default to seconds, or see https://pydantic-docs.helpmanual.io/usage/types/#datetime-types for string formating)."
"NOTE: be careful that this time should always be a factor larger than the real time, as EC2 instances"
"that take longer than this time will be terminated as sometimes it happens that EC2 machine fail on start.",
)

@validator("PRIMARY_EC2_INSTANCES_ALLOWED_TYPES")
@classmethod
def check_valid_instance_names(
cls, value: dict[str, EC2InstanceBootSpecific]
) -> dict[str, EC2InstanceBootSpecific]:
# NOTE: needed because of a flaw in BaseCustomSettings
# issubclass raises TypeError if used on Aliases
if all(parse_obj_as(InstanceTypeType, key) for key in value):
return value

msg = "Invalid instance type name"
raise ValueError(msg)
parse_obj_as(list[InstanceTypeType], list(value))
return value

@validator("PRIMARY_EC2_INSTANCES_ALLOWED_TYPES")
@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import logging
from collections.abc import Iterable

import arrow
from aws_library.ec2.client import SimcoreEC2API
from aws_library.ec2.models import (
AWSTagKey,
Expand Down Expand Up @@ -96,15 +97,17 @@ async def create_cluster(
return new_ec2_instance_data


async def get_all_clusters(app: FastAPI) -> set[EC2InstanceData]:
    """Return the primary EC2 instances of all created clusters that are running.

    Returns a set (EC2InstanceData is hashable) so callers can use set arithmetic
    to split connected/disconnected instances.
    """
    app_settings = get_application_settings(app)
    assert app_settings.CLUSTERS_KEEPER_PRIMARY_EC2_INSTANCES  # nosec
    ec2_instance_data: set[EC2InstanceData] = set(
        await get_ec2_client(app).get_instances(
            key_names=[
                app_settings.CLUSTERS_KEEPER_PRIMARY_EC2_INSTANCES.PRIMARY_EC2_INSTANCES_KEY_NAME
            ],
            tags=all_created_ec2_instances_filter(app_settings),
            state_names=["running"],
        )
    )
    return ec2_instance_data

Expand Down Expand Up @@ -159,9 +162,11 @@ async def set_instance_heartbeat(app: FastAPI, *, instance: EC2InstanceData) ->
ec2_client = get_ec2_client(app)
await ec2_client.set_instances_tags(
[instance],
tags={HEARTBEAT_TAG_KEY: f"{datetime.datetime.now(datetime.timezone.utc)}"},
tags={HEARTBEAT_TAG_KEY: AWSTagValue(arrow.utcnow().datetime.isoformat())},
)


async def delete_clusters(
    app: FastAPI, *, instances: Iterable[EC2InstanceData]
) -> None:
    """Terminate the given cluster EC2 instances (accepts any iterable)."""
    await get_ec2_client(app).terminate_instances(instances)
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import datetime
import logging
from collections.abc import Iterable
from typing import Final

import arrow
from aws_library.ec2.models import EC2InstanceData
from aws_library.ec2.models import AWSTagKey, EC2InstanceData
from fastapi import FastAPI
from models_library.users import UserID
from models_library.wallets import WalletID
from pydantic import parse_obj_as
from servicelib.logging_utils import log_catch

from ..core.settings import get_application_settings
from ..modules.clusters import (
Expand All @@ -21,16 +25,44 @@
_logger = logging.getLogger(__name__)


def _get_instance_last_heartbeat(instance: EC2InstanceData) -> datetime.datetime | None:
    """Return the instance's last heartbeat time, or None if no heartbeat tag is set.

    A None return means the cluster never reported activity (e.g. it is still
    starting up, or broke before its first heartbeat).
    """
    if last_heartbeat := instance.tags.get(
        HEARTBEAT_TAG_KEY,
    ):
        last_heartbeat_time: datetime.datetime = arrow.get(last_heartbeat).datetime
        return last_heartbeat_time

    # diff residue removed: the superseded fallback returning launch_time is gone;
    # callers now distinguish "never heartbeated" via None
    return None


_USER_ID_TAG_KEY: Final[AWSTagKey] = parse_obj_as(AWSTagKey, "user_id")
_WALLET_ID_TAG_KEY: Final[AWSTagKey] = parse_obj_as(AWSTagKey, "wallet_id")


async def _get_all_associated_worker_instances(
    app: FastAPI,
    primary_instances: Iterable[EC2InstanceData],
) -> list[EC2InstanceData]:
    """Collect the worker EC2 instances belonging to the given primary instances.

    Each primary instance carries user_id/wallet_id tags identifying its cluster;
    the corresponding workers are fetched per (user_id, wallet_id) pair.
    """
    worker_instances: list[EC2InstanceData] = []
    for instance in primary_instances:
        # consistency fix: use the same tag-key constants in the asserts as in the
        # lookups (previously the asserts used bare "user_id"/"wallet_id" literals)
        assert _USER_ID_TAG_KEY in instance.tags  # nosec
        user_id = UserID(instance.tags[_USER_ID_TAG_KEY])
        assert _WALLET_ID_TAG_KEY in instance.tags  # nosec
        # NOTE: wallet_id can be None
        wallet_id = (
            WalletID(instance.tags[_WALLET_ID_TAG_KEY])
            if instance.tags[_WALLET_ID_TAG_KEY] != "None"
            else None
        )

        worker_instances.extend(
            await get_cluster_workers(app, user_id=user_id, wallet_id=wallet_id)
        )
    return worker_instances


async def _find_terminateable_instances(
app: FastAPI, instances: list[EC2InstanceData]
app: FastAPI, instances: Iterable[EC2InstanceData]
) -> list[EC2InstanceData]:
app_settings = get_application_settings(app)
assert app_settings.CLUSTERS_KEEPER_PRIMARY_EC2_INSTANCES # nosec
Expand All @@ -42,61 +74,89 @@ async def _find_terminateable_instances(
app_settings.CLUSTERS_KEEPER_MAX_MISSED_HEARTBEATS_BEFORE_CLUSTER_TERMINATION
* app_settings.SERVICE_TRACKING_HEARTBEAT
)
startup_delay = (
app_settings.CLUSTERS_KEEPER_PRIMARY_EC2_INSTANCES.PRIMARY_EC2_INSTANCES_MAX_START_TIME
)
for instance in instances:
last_heartbeat = _get_instance_last_heartbeat(instance)

elapsed_time_since_heartbeat = (
datetime.datetime.now(datetime.timezone.utc) - last_heartbeat
)
_logger.info(
"%s has still %ss before being terminateable",
f"{instance.id=}",
f"{(time_to_wait_before_termination - elapsed_time_since_heartbeat).total_seconds()}",
)
if elapsed_time_since_heartbeat >= time_to_wait_before_termination:
# let's terminate that one
terminateable_instances.append(instance)
if last_heartbeat := _get_instance_last_heartbeat(instance):
elapsed_time_since_heartbeat = arrow.utcnow().datetime - last_heartbeat
allowed_time_to_wait = time_to_wait_before_termination
if elapsed_time_since_heartbeat >= allowed_time_to_wait:
terminateable_instances.append(instance)
else:
_logger.info(
"%s has still %ss before being terminateable",
f"{instance.id=}",
f"{(allowed_time_to_wait - elapsed_time_since_heartbeat).total_seconds()}",
)
else:
elapsed_time_since_startup = arrow.utcnow().datetime - instance.launch_time
allowed_time_to_wait = startup_delay
if elapsed_time_since_startup >= allowed_time_to_wait:
terminateable_instances.append(instance)

# get all terminateable instances associated worker instances
worker_instances = []
for instance in terminateable_instances:
assert "user_id" in instance.tags # nosec
user_id = UserID(instance.tags["user_id"])
assert "wallet_id" in instance.tags # nosec
# NOTE: wallet_id can be None
wallet_id = (
WalletID(instance.tags["wallet_id"])
if instance.tags["wallet_id"] != "None"
else None
)

worker_instances.extend(
await get_cluster_workers(app, user_id=user_id, wallet_id=wallet_id)
)
worker_instances = await _get_all_associated_worker_instances(
app, terminateable_instances
)

return terminateable_instances + worker_instances


async def check_clusters(app: FastAPI) -> None:
    """Periodic task: heartbeat healthy clusters and terminate stale/broken ones.

    Flow:
      1. ping every primary instance's scheduler to split connected/disconnected,
      2. refresh the heartbeat tag of connected, busy clusters,
      3. terminate connected clusters whose heartbeat is too old,
      4. terminate disconnected clusters that never heartbeated and exceeded the
         allowed startup time,
      5. terminate the remaining disconnected (previously-connected, now broken)
         clusters.
    """
    primary_instances = await get_all_clusters(app)

    # diff residue removed: superseded list-comprehension version is gone;
    # a set is needed for the set arithmetic below
    connected_instances = {
        instance
        for instance in primary_instances
        if await ping_scheduler(get_scheduler_url(instance), get_scheduler_auth(app))
    }

    for instance in connected_instances:
        with log_catch(_logger, reraise=False):
            # NOTE: some connected instance could in theory break between these 2 calls, therefore this is silenced and will
            # be handled in the next call to check_clusters
            if await is_scheduler_busy(
                get_scheduler_url(instance), get_scheduler_auth(app)
            ):
                _logger.info(
                    "%s is running tasks",
                    f"{instance.id=} for {instance.tags=}",
                )
                await set_instance_heartbeat(app, instance=instance)
    if terminateable_instances := await _find_terminateable_instances(
        app, connected_instances
    ):
        await delete_clusters(app, instances=terminateable_instances)

    # analyse disconnected instances (currently starting or broken)
    disconnected_instances = primary_instances - connected_instances

    # starting instances do not have a heartbeat set but sometimes might fail and should be terminated
    starting_instances = {
        instance
        for instance in disconnected_instances
        if _get_instance_last_heartbeat(instance) is None
    }

    if terminateable_instances := await _find_terminateable_instances(
        app, starting_instances
    ):
        _logger.warning(
            "The following clusters'primary EC2 were starting for too long and will be terminated now "
            "(either because a cluster was started and is not needed anymore, or there is an issue): '%s",
            f"{[i.id for i in terminateable_instances]}",
        )
        await delete_clusters(app, instances=terminateable_instances)

    # the other instances are broken (they were at some point connected but now not anymore)
    broken_instances = disconnected_instances - starting_instances
    if terminateable_instances := await _find_terminateable_instances(
        app, broken_instances
    ):
        _logger.error(
            "The following clusters'primary EC2 were found as unresponsive "
            "(TIP: there is something wrong here, please inform support) and will be terminated now: '%s",
            f"{[i.id for i in terminateable_instances]}",
        )
        await delete_clusters(app, instances=terminateable_instances)
Loading

0 comments on commit 09866f1

Please sign in to comment.