diff --git a/src/dstack/_internal/core/backends/__init__.py b/src/dstack/_internal/core/backends/__init__.py index de5d4af91..55aafa0ca 100644 --- a/src/dstack/_internal/core/backends/__init__.py +++ b/src/dstack/_internal/core/backends/__init__.py @@ -15,3 +15,4 @@ BackendType.LAMBDA, BackendType.TENSORDOCK, ] +BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT = [BackendType.AWS] diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index 3bc3c834b..af6564fbf 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -204,6 +204,12 @@ def create_gateway( ] if settings.DSTACK_VERSION is not None: tags.append({"Key": "dstack_version", "Value": settings.DSTACK_VERSION}) + vpc_id, subnet_id = get_vpc_id_subnet_id_or_error( + ec2_client=ec2_client, + config=self.config, + region=configuration.region, + allocate_public_ip=configuration.public_ip, + ) response = ec2.create_instances( **aws_resources.create_instances_struct( disk_size=10, @@ -215,17 +221,24 @@ def create_gateway( security_group_id=aws_resources.create_gateway_security_group( ec2_client=ec2_client, project_id=configuration.project_name, + vpc_id=vpc_id, ), spot=False, + subnet_id=subnet_id, + allocate_public_ip=configuration.public_ip, ) ) instance = response[0] instance.wait_until_running() instance.reload() # populate instance.public_ip_address + if configuration.public_ip: + ip_address = instance.public_ip_address + else: + ip_address = instance.private_ip_address return LaunchedGatewayInfo( instance_id=instance.instance_id, region=configuration.region, - ip_address=instance.public_ip_address, + ip_address=ip_address, ) diff --git a/src/dstack/_internal/core/backends/aws/resources.py b/src/dstack/_internal/core/backends/aws/resources.py index bc8b1de4e..a22ff5141 100644 --- a/src/dstack/_internal/core/backends/aws/resources.py +++ b/src/dstack/_internal/core/backends/aws/resources.py @@ -171,20 +171,31 @@ def get_gateway_image_id(ec2_client: botocore.client.BaseClient) -> str: return image["ImageId"] -def create_gateway_security_group(ec2_client: botocore.client.BaseClient, project_id: str) -> str: +def create_gateway_security_group( + ec2_client: botocore.client.BaseClient, + project_id: str, + vpc_id: Optional[str], +) -> str: security_group_name = "dstack_gw_sg_" + project_id.replace("-", "_").lower() - - response = ec2_client.describe_security_groups( - Filters=[ + describe_security_groups_filters = [ + { + "Name": "group-name", + "Values": [security_group_name], + }, + ] + if vpc_id is not None: + describe_security_groups_filters.append( { - "Name": "group-name", - "Values": [security_group_name], - }, - ], - ) + "Name": "vpc-id", + "Values": [vpc_id], + } + ) + response = ec2_client.describe_security_groups(Filters=describe_security_groups_filters) if response.get("SecurityGroups"): return response["SecurityGroups"][0]["GroupId"] - + create_security_group_kwargs = {} + if vpc_id is not None: + create_security_group_kwargs["VpcId"] = vpc_id security_group = ec2_client.create_security_group( Description="Generated by dstack", GroupName=security_group_name, @@ -198,6 +209,7 @@ def create_gateway_security_group(ec2_client: botocore.client.BaseClient, projec ], }, ], + **create_security_group_kwargs, ) group_id = security_group["GroupId"] diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index 6df853b08..f58815473 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -17,7 +17,7 @@ class GatewayConfiguration(CoreModel): domain: Annotated[ Optional[str], Field(description="The gateway domain, e.g. `*.example.com`") ] = None - # public_ip: Annotated[bool, Field(description="Allocate public IP for the gateway")] = True + public_ip: Annotated[bool, Field(description="Allocate public IP for the gateway")] = True class GatewayComputeConfiguration(CoreModel): diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index ff960c7f3..597c05730 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -11,6 +11,7 @@ import dstack._internal.server.services.jobs as jobs_services import dstack._internal.utils.random_names as random_names +from dstack._internal.core.backends import BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT from dstack._internal.core.backends.base.compute import ( Compute, get_dstack_gateway_wheel, @@ -87,17 +88,27 @@ async def get_project_default_gateway( async def create_gateway_compute( + project_name: str, backend_compute: Compute, - configuration: GatewayComputeConfiguration, + configuration: GatewayConfiguration, backend_id: Optional[uuid.UUID] = None, ) -> GatewayComputeModel: private_bytes, public_bytes = generate_rsa_key_pair_bytes() gateway_ssh_private_key = private_bytes.decode() gateway_ssh_public_key = public_bytes.decode() + compute_configuration = GatewayComputeConfiguration( + project_name=project_name, + instance_name=configuration.name, + backend=configuration.backend, + region=configuration.region, + public_ip=configuration.public_ip, + ssh_key_pub=gateway_ssh_public_key, + ) + info = await run_async( backend_compute.create_gateway, - configuration, + compute_configuration, ) return GatewayComputeModel( @@ -122,6 +133,15 @@ async def create_gateway( else: raise ResourceNotExistsError() + if ( + not configuration.public_ip + and configuration.backend not in BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT + ): + raise GatewayError( + f"Private gateways are not supported for {configuration.backend.value} backend. " + f"Supported backends: {[b.value for b in BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT]}." + ) + if configuration.name is None: configuration.name = await generate_gateway_name(session=session, project=project) @@ -139,19 +159,11 @@ async def create_gateway( if project.default_gateway is None or configuration.default: await set_default_gateway(session=session, project=project, name=configuration.name) - compute_configuration = GatewayComputeConfiguration( - project_name=project.name, - instance_name=gateway.name, - backend=configuration.backend, - region=configuration.region, - public_ip=True, - ssh_key_pub=project.name, - ) - try: gateway.gateway_compute = await create_gateway_compute( backend_compute=backend.compute(), - configuration=compute_configuration, + project_name=project.name, + configuration=configuration, backend_id=backend_model.id, ) session.add(gateway) @@ -321,13 +333,6 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) -> async def register_service(session: AsyncSession, run_model: RunModel): run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) - service_https = run_spec.configuration.https - service_protocol = "https" if service_https else "http" - - # Currently, gateway endpoint is always https - gateway_https = True - gateway_protocol = "https" if gateway_https else "http" - # TODO(egor-s): allow to configure gateway name gateway_name: Optional[str] = None if gateway_name is None: @@ -343,6 +348,21 @@ async def register_service(session: AsyncSession, run_model: RunModel): if gateway.gateway_compute is None: raise ServerClientError("Gateway has no instance associated with it") + service_https = run_spec.configuration.https + service_protocol = "https" if service_https else "http" + + gateway_configuration = None + if gateway.configuration is not None: + gateway_configuration = GatewayConfiguration.__response__.parse_raw(gateway.configuration) + if service_https and not gateway_configuration.public_ip: + raise ServerClientError("Cannot run HTTPS service on gateway without public IP") + + gateway_https = True + if gateway_configuration is not None: + # Currently, https is always False for private gateways + gateway_https = gateway_configuration.public_ip + gateway_protocol = "https" if gateway_https else "http" + wildcard_domain = gateway.wildcard_domain.lstrip("*.") if gateway.wildcard_domain else None if wildcard_domain is None: raise ServerClientError("Domain is required for gateway") diff --git a/src/dstack/_internal/server/services/gateways/connection.py b/src/dstack/_internal/server/services/gateways/connection.py index 741407075..422a31677 100644 --- a/src/dstack/_internal/server/services/gateways/connection.py +++ b/src/dstack/_internal/server/services/gateways/connection.py @@ -5,6 +5,7 @@ import aiorwlock +from dstack._internal.core.services.ssh.ports import PortsLock from dstack._internal.server.services.gateways.client import ( GATEWAY_MANAGEMENT_PORT, GatewayClient, @@ -29,9 +30,10 @@ def __init__(self, ip_address: str, id_rsa: str, server_port: int): self._lock = aiorwlock.RWLock() self.stats: Dict[str, Dict[int, Stat]] = {} self.ip_address = ip_address - + self.ports_lock = PortsLock(restrictions={server_port: 0}).acquire() + local_port = self.ports_lock.dict()[server_port] args = ["-L", "{temp_dir}/gateway:localhost:%d" % GATEWAY_MANAGEMENT_PORT] - args += ["-R", f"localhost:8001:localhost:{server_port}"] + args += ["-R", f"localhost:{local_port}:localhost:{server_port}"] self.tunnel = AsyncSSHTunnel( f"ubuntu@{ip_address}", id_rsa, diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index 6a40947ab..76a5ffd22 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -80,6 +80,7 @@ async def test_list(self, test_db, session: AsyncSession): "region": gateway.region, "domain": gateway.wildcard_domain, "default": False, + "public_ip": True, }, } ] @@ -124,6 +125,7 @@ async def test_get(self, test_db, session: AsyncSession): "region": gateway.region, "domain": gateway.wildcard_domain, "default": False, + "public_ip": True, }, } @@ -203,6 +205,7 @@ async def test_create_gateway(self, test_db, session: AsyncSession): "region": "us", "domain": None, "default": True, + "public_ip": True, }, } @@ -257,6 +260,7 @@ async def test_create_gateway_without_name(self, test_db, session: AsyncSession) "region": "us", "domain": None, "default": True, + "public_ip": True, }, } @@ -391,6 +395,7 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession): "region": gateway.region, "domain": gateway.wildcard_domain, "default": True, + "public_ip": True, }, } @@ -498,6 +503,7 @@ def get_backend(_, backend_type): "region": gateway_gcp.region, "domain": gateway_gcp.wildcard_domain, "default": False, + "public_ip": True, }, } ] @@ -557,6 +563,7 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession): "region": gateway.region, "domain": "test.com", "default": False, + "public_ip": True, }, }