Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support gateways without public IPs on AWS #1224

Merged
merged 1 commit into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
BackendType.LAMBDA,
BackendType.TENSORDOCK,
]
BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT = [BackendType.AWS]
15 changes: 14 additions & 1 deletion src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ def create_gateway(
]
if settings.DSTACK_VERSION is not None:
tags.append({"Key": "dstack_version", "Value": settings.DSTACK_VERSION})
vpc_id, subnet_id = get_vpc_id_subnet_id_or_error(
ec2_client=ec2_client,
config=self.config,
region=configuration.region,
allocate_public_ip=configuration.public_ip,
)
response = ec2.create_instances(
**aws_resources.create_instances_struct(
disk_size=10,
Expand All @@ -215,17 +221,24 @@ def create_gateway(
security_group_id=aws_resources.create_gateway_security_group(
ec2_client=ec2_client,
project_id=configuration.project_name,
vpc_id=vpc_id,
),
spot=False,
subnet_id=subnet_id,
allocate_public_ip=configuration.public_ip,
)
)
instance = response[0]
instance.wait_until_running()
instance.reload() # populate instance.public_ip_address
if configuration.public_ip:
ip_address = instance.public_ip_address
else:
ip_address = instance.private_ip_address
return LaunchedGatewayInfo(
instance_id=instance.instance_id,
region=configuration.region,
ip_address=instance.public_ip_address,
ip_address=ip_address,
)


Expand Down
32 changes: 22 additions & 10 deletions src/dstack/_internal/core/backends/aws/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,31 @@ def get_gateway_image_id(ec2_client: botocore.client.BaseClient) -> str:
return image["ImageId"]


def create_gateway_security_group(ec2_client: botocore.client.BaseClient, project_id: str) -> str:
def create_gateway_security_group(
ec2_client: botocore.client.BaseClient,
project_id: str,
vpc_id: Optional[str],
) -> str:
security_group_name = "dstack_gw_sg_" + project_id.replace("-", "_").lower()

response = ec2_client.describe_security_groups(
Filters=[
describe_security_groups_filters = [
{
"Name": "group-name",
"Values": [security_group_name],
},
]
if vpc_id is not None:
describe_security_groups_filters.append(
{
"Name": "group-name",
"Values": [security_group_name],
},
],
)
"Name": "vpc-id",
"Values": [vpc_id],
}
)
response = ec2_client.describe_security_groups(Filters=describe_security_groups_filters)
if response.get("SecurityGroups"):
return response["SecurityGroups"][0]["GroupId"]

create_security_group_kwargs = {}
if vpc_id is not None:
create_security_group_kwargs["VpcId"] = vpc_id
security_group = ec2_client.create_security_group(
Description="Generated by dstack",
GroupName=security_group_name,
Expand All @@ -198,6 +209,7 @@ def create_gateway_security_group(ec2_client: botocore.client.BaseClient, projec
],
},
],
**create_security_group_kwargs,
)
group_id = security_group["GroupId"]

Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/core/models/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class GatewayConfiguration(CoreModel):
domain: Annotated[
Optional[str], Field(description="The gateway domain, e.g. `*.example.com`")
] = None
# public_ip: Annotated[bool, Field(description="Allocate public IP for the gateway")] = True
public_ip: Annotated[bool, Field(description="Allocate public IP for the gateway")] = True


class GatewayComputeConfiguration(CoreModel):
Expand Down
58 changes: 39 additions & 19 deletions src/dstack/_internal/server/services/gateways/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import dstack._internal.server.services.jobs as jobs_services
import dstack._internal.utils.random_names as random_names
from dstack._internal.core.backends import BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT
from dstack._internal.core.backends.base.compute import (
Compute,
get_dstack_gateway_wheel,
Expand Down Expand Up @@ -87,17 +88,27 @@ async def get_project_default_gateway(


async def create_gateway_compute(
project_name: str,
backend_compute: Compute,
configuration: GatewayComputeConfiguration,
configuration: GatewayConfiguration,
backend_id: Optional[uuid.UUID] = None,
) -> GatewayComputeModel:
private_bytes, public_bytes = generate_rsa_key_pair_bytes()
gateway_ssh_private_key = private_bytes.decode()
gateway_ssh_public_key = public_bytes.decode()

compute_configuration = GatewayComputeConfiguration(
project_name=project_name,
instance_name=configuration.name,
backend=configuration.backend,
region=configuration.region,
public_ip=configuration.public_ip,
ssh_key_pub=gateway_ssh_public_key,
)

info = await run_async(
backend_compute.create_gateway,
configuration,
compute_configuration,
)

return GatewayComputeModel(
Expand All @@ -122,6 +133,15 @@ async def create_gateway(
else:
raise ResourceNotExistsError()

if (
not configuration.public_ip
and configuration.backend not in BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT
):
raise GatewayError(
f"Private gateways are not supported for {configuration.backend.value} backend. "
f"Supported backends: {[b.value for b in BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT]}."
)

if configuration.name is None:
configuration.name = await generate_gateway_name(session=session, project=project)

Expand All @@ -139,19 +159,11 @@ async def create_gateway(
if project.default_gateway is None or configuration.default:
await set_default_gateway(session=session, project=project, name=configuration.name)

compute_configuration = GatewayComputeConfiguration(
project_name=project.name,
instance_name=gateway.name,
backend=configuration.backend,
region=configuration.region,
public_ip=True,
ssh_key_pub=project.name,
)

try:
gateway.gateway_compute = await create_gateway_compute(
backend_compute=backend.compute(),
configuration=compute_configuration,
project_name=project.name,
configuration=configuration,
backend_id=backend_model.id,
)
session.add(gateway)
Expand Down Expand Up @@ -321,13 +333,6 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) ->
async def register_service(session: AsyncSession, run_model: RunModel):
run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)

service_https = run_spec.configuration.https
service_protocol = "https" if service_https else "http"

# Currently, gateway endpoint is always https
gateway_https = True
gateway_protocol = "https" if gateway_https else "http"

# TODO(egor-s): allow to configure gateway name
gateway_name: Optional[str] = None
if gateway_name is None:
Expand All @@ -343,6 +348,21 @@ async def register_service(session: AsyncSession, run_model: RunModel):
if gateway.gateway_compute is None:
raise ServerClientError("Gateway has no instance associated with it")

service_https = run_spec.configuration.https
service_protocol = "https" if service_https else "http"

gateway_configuration = None
if gateway.configuration is not None:
gateway_configuration = GatewayConfiguration.__response__.parse_raw(gateway.configuration)
if service_https and not gateway_configuration.public_ip:
raise ServerClientError("Cannot run HTTPS service on gateway without public IP")

gateway_https = True
if gateway_configuration is not None:
# Currently, https is always False for private gateways
gateway_https = gateway_configuration.public_ip
gateway_protocol = "https" if gateway_https else "http"

wildcard_domain = gateway.wildcard_domain.lstrip("*.") if gateway.wildcard_domain else None
if wildcard_domain is None:
raise ServerClientError("Domain is required for gateway")
Expand Down
6 changes: 4 additions & 2 deletions src/dstack/_internal/server/services/gateways/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import aiorwlock

from dstack._internal.core.services.ssh.ports import PortsLock
from dstack._internal.server.services.gateways.client import (
GATEWAY_MANAGEMENT_PORT,
GatewayClient,
Expand All @@ -29,9 +30,10 @@ def __init__(self, ip_address: str, id_rsa: str, server_port: int):
self._lock = aiorwlock.RWLock()
self.stats: Dict[str, Dict[int, Stat]] = {}
self.ip_address = ip_address

self.ports_lock = PortsLock(restrictions={server_port: 0}).acquire()
local_port = self.ports_lock.dict()[server_port]
args = ["-L", "{temp_dir}/gateway:localhost:%d" % GATEWAY_MANAGEMENT_PORT]
args += ["-R", f"localhost:8001:localhost:{server_port}"]
args += ["-R", f"localhost:{local_port}:localhost:{server_port}"]
self.tunnel = AsyncSSHTunnel(
f"ubuntu@{ip_address}",
id_rsa,
Expand Down
7 changes: 7 additions & 0 deletions src/tests/_internal/server/routers/test_gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ async def test_list(self, test_db, session: AsyncSession):
"region": gateway.region,
"domain": gateway.wildcard_domain,
"default": False,
"public_ip": True,
},
}
]
Expand Down Expand Up @@ -124,6 +125,7 @@ async def test_get(self, test_db, session: AsyncSession):
"region": gateway.region,
"domain": gateway.wildcard_domain,
"default": False,
"public_ip": True,
},
}

Expand Down Expand Up @@ -203,6 +205,7 @@ async def test_create_gateway(self, test_db, session: AsyncSession):
"region": "us",
"domain": None,
"default": True,
"public_ip": True,
},
}

Expand Down Expand Up @@ -257,6 +260,7 @@ async def test_create_gateway_without_name(self, test_db, session: AsyncSession)
"region": "us",
"domain": None,
"default": True,
"public_ip": True,
},
}

Expand Down Expand Up @@ -391,6 +395,7 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession):
"region": gateway.region,
"domain": gateway.wildcard_domain,
"default": True,
"public_ip": True,
},
}

Expand Down Expand Up @@ -498,6 +503,7 @@ def get_backend(_, backend_type):
"region": gateway_gcp.region,
"domain": gateway_gcp.wildcard_domain,
"default": False,
"public_ip": True,
},
}
]
Expand Down Expand Up @@ -557,6 +563,7 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession):
"region": gateway.region,
"domain": "test.com",
"default": False,
"public_ip": True,
},
}

Expand Down
Loading