diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 70b520311..0cad6b225 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -159,7 +159,7 @@ jobs: run: pip install . - name: Generate json schema run: | - python -c "from dstack._internal.core.models.configurations import RunConfiguration; print(RunConfiguration.schema_json(indent=2))" > configuration.json + python -c "from dstack._internal.core.models.configurations import DstackConfiguration; print(DstackConfiguration.schema_json(indent=2))" > configuration.json python -c "from dstack._internal.core.models.profiles import ProfilesConfig; print(ProfilesConfig.schema_json(indent=2))" > profiles.json - name: Upload json schema to S3 run: | diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index be3c666b9..f70ecc176 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -223,7 +223,7 @@ jobs: run: pip install . - name: Generate json schema run: | - python -c "from dstack._internal.core.models.configurations import RunConfiguration; print(RunConfiguration.schema_json(indent=2))" > configuration.json + python -c "from dstack._internal.core.models.configurations import DstackConfiguration; print(DstackConfiguration.schema_json(indent=2))" > configuration.json python -c "from dstack._internal.core.models.profiles import ProfilesConfig; print(ProfilesConfig.schema_json(indent=2))" > profiles.json - name: Upload json schema to S3 run: | diff --git a/docs/docs/reference/cli/index.md b/docs/docs/reference/cli/index.md index 97b97f3bb..8c462bdfd 100644 --- a/docs/docs/reference/cli/index.md +++ b/docs/docs/reference/cli/index.md @@ -59,6 +59,24 @@ $ dstack run . --help If there are large files, consider creating a `.gitignore` file to exclude them for better performance. +### dstack apply + +This command applies a given configuration. If a resources does not exist, `dstack apply` creates the resource. +If a resource exists, `dstack apply` updates the resource in-place or re-creates the resource if the update is not possible. + +
+ +```shell +$ dstack apply --help +#GENERATE# +``` + +
+ +!!! info "NOTE:" + The `dstack apply` command currently supports only `gateway` configurations. + Support for other configuration types is coming soon. + ### dstack ps This command shows the status of runs. diff --git a/docs/docs/reference/dstack.yml/dev-environment.md b/docs/docs/reference/dstack.yml/dev-environment.md index 1aba1ed37..6768a65b6 100644 --- a/docs/docs/reference/dstack.yml/dev-environment.md +++ b/docs/docs/reference/dstack.yml/dev-environment.md @@ -2,9 +2,10 @@ The `dev-environment` configuration type allows running [dev environments](../../concepts/dev-environments.md). -> Configuration files must have a name ending with `.dstack.yml` (e.g., `.dstack.yml` or `dev.dstack.yml` are both acceptable) -> and can be located in the project's root directory or any nested folder. -> Any configuration can be run via [`dstack run . -f PATH`](../cli/index.md#dstack-run). +!!! info "Filename" + Configuration files must have a name ending with `.dstack.yml` (e.g., `.dstack.yml` or `serve.dstack.yml` are both acceptable) + and can be located in the project's root directory or any nested folder. + Any configuration can be run via [`dstack run`](../cli/index.md#dstack-run). ## Examples diff --git a/docs/docs/reference/dstack.yml/gateway.md b/docs/docs/reference/dstack.yml/gateway.md new file mode 100644 index 000000000..a72679286 --- /dev/null +++ b/docs/docs/reference/dstack.yml/gateway.md @@ -0,0 +1,31 @@ +# gateway + +The `gateway` configuration type allows creating and updating [gateways](../../concepts/services.md). + +!!! info "Filename" + Configuration files must have a name ending with `.dstack.yml` (e.g., `.dstack.yml` or `serve.dstack.yml` are both acceptable) + and can be located in the project's root directory or any nested folder. + Any configuration can be applied via [`dstack apply`](../cli/index.md#dstack-apply). + +## Examples + +
+ +```yaml +type: gateway +name: example-gateway +backend: aws +region: eu-west-1 +domain: '*.example.com' +``` + +
+ + +## Root reference + +#SCHEMA# dstack._internal.core.models.gateways.GatewayConfiguration + overrides: + show_root_heading: false + type: + required: true diff --git a/docs/docs/reference/dstack.yml/task.md b/docs/docs/reference/dstack.yml/task.md index 42282a161..cb080d162 100644 --- a/docs/docs/reference/dstack.yml/task.md +++ b/docs/docs/reference/dstack.yml/task.md @@ -2,9 +2,10 @@ The `task` configuration type allows running [tasks](../../concepts/tasks.md). -> Configuration files must have a name ending with `.dstack.yml` (e.g., `.dstack.yml` or `train.dstack.yml` are both acceptable) -> and can be located in the project's root directory or any nested folder. -> Any configuration can be run via [`dstack run . -f PATH`](../cli/index.md#dstack-run). +!!! info "Filename" + Configuration files must have a name ending with `.dstack.yml` (e.g., `.dstack.yml` or `serve.dstack.yml` are both acceptable) + and can be located in the project's root directory or any nested folder. + Any configuration can be run via [`dstack run`](../cli/index.md#dstack-run). ## Examples diff --git a/mkdocs.yml b/mkdocs.yml index 11a32bbf6..14aec5c19 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -199,6 +199,7 @@ nav: - dev-environment: docs/reference/dstack.yml/dev-environment.md - task: docs/reference/dstack.yml/task.md - service: docs/reference/dstack.yml/service.md + - gateway: docs/reference/dstack.yml/gateway.md - profiles.yml: docs/reference/profiles.yml.md - CLI: docs/reference/cli/index.md - server/config.yml: docs/reference/server/config.yml.md diff --git a/src/dstack/_internal/cli/commands/apply.py b/src/dstack/_internal/cli/commands/apply.py new file mode 100644 index 000000000..645baca76 --- /dev/null +++ b/src/dstack/_internal/cli/commands/apply.py @@ -0,0 +1,60 @@ +import argparse +from pathlib import Path + +import yaml + +from dstack._internal.cli.commands import APIBaseCommand +from dstack._internal.cli.services.configurators import ( + get_apply_configurator_class, +) +from dstack._internal.cli.utils.common import cli_error +from dstack._internal.core.errors import ConfigurationError +from dstack._internal.core.models.configurations import ( + AnyApplyConfiguration, + parse_apply_configuration, +) + + +class ApplyCommand(APIBaseCommand): + NAME = "apply" + DESCRIPTION = "Apply dstack configuration" + + def _register(self): + super()._register() + self._parser.add_argument( + "configuration_file", + help="The path to the configuration file", + ) + self._parser.add_argument( + "--force", + help="Force apply when no changes detected", + action="store_true", + ) + self._parser.add_argument( + "-y", + "--yes", + help="Do not ask for confirmation", + action="store_true", + ) + + def _command(self, args: argparse.Namespace): + super()._command(args) + try: + configuration = _load_configuration(args.configuration_file) + except ConfigurationError as e: + raise cli_error(e) + configurator_class = get_apply_configurator_class(configuration.type) + configurator = configurator_class(api_client=self.api) + configurator.apply_configuration(conf=configuration, args=args) + + +def _load_configuration(configuration_file: str) -> AnyApplyConfiguration: + configuration_path = Path(configuration_file) + if not configuration_path.exists(): + raise ConfigurationError(f"Configuration file {configuration_file} does not exist") + try: + with open(configuration_path, "r") as f: + conf = parse_apply_configuration(yaml.safe_load(f)) + except OSError: + raise ConfigurationError(f"Failed to load configuration from {configuration_path}") + return conf diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 4c570fd15..7e95f820d 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -11,7 +11,7 @@ from dstack._internal.cli.commands import APIBaseCommand from dstack._internal.cli.services.args import cpu_spec, disk_spec, gpu_spec, memory_spec -from dstack._internal.cli.services.configurators.profile import ( +from dstack._internal.cli.services.profile import ( apply_profile_args, register_profile_args, ) diff --git a/src/dstack/_internal/cli/commands/run.py b/src/dstack/_internal/cli/commands/run.py index 510c972b9..a706fa284 100644 --- a/src/dstack/_internal/cli/commands/run.py +++ b/src/dstack/_internal/cli/commands/run.py @@ -5,18 +5,18 @@ from typing import Optional from dstack._internal.cli.commands import APIBaseCommand -from dstack._internal.cli.services.configurators.profile import ( - apply_profile_args, - register_profile_args, -) from dstack._internal.cli.services.configurators.run import ( BaseRunConfigurator, run_configurators_mapping, ) +from dstack._internal.cli.services.profile import ( + apply_profile_args, + register_profile_args, +) from dstack._internal.cli.utils.common import confirm_ask, console from dstack._internal.cli.utils.run import print_run_plan from dstack._internal.core.errors import CLIError, ConfigurationError, ServerClientError -from dstack._internal.core.models.configurations import ConfigurationType +from dstack._internal.core.models.configurations import RunConfigurationType from dstack._internal.core.models.runs import JobTerminationReason from dstack._internal.core.services.configs import ConfigManager from dstack._internal.utils.logging import get_logger @@ -39,7 +39,7 @@ def _register(self): "-h", "--help", nargs="?", - type=ConfigurationType, + type=RunConfigurationType, default=NOTSET, help="Show this help message and exit. TYPE is one of [code]task[/], [code]dev-environment[/], [code]service[/]", dest="help", @@ -83,7 +83,7 @@ def _register(self): def _command(self, args: argparse.Namespace): if args.help is not NOTSET: if args.help is not None: - run_configurators_mapping[ConfigurationType(args.help)].register(self._parser) + run_configurators_mapping[RunConfigurationType(args.help)].register(self._parser) else: BaseRunConfigurator.register(self._parser) self._parser.print_help() @@ -102,7 +102,7 @@ def _command(self, args: argparse.Namespace): apply_profile_args(args, conf) logger.debug("Configuration loaded: %s", configuration_path) parser = argparse.ArgumentParser() - configurator = run_configurators_mapping[ConfigurationType(conf.type)] + configurator = run_configurators_mapping[RunConfigurationType(conf.type)] configurator.register(parser) known, unknown = parser.parse_known_args(args.unknown) configurator.apply(known, unknown, conf) @@ -176,7 +176,7 @@ def _command(self, args: argparse.Namespace): ) if run.status in (RunStatus.RUNNING, RunStatus.DONE): - if run._run.run_spec.configuration.type == ConfigurationType.SERVICE.value: + if run._run.run_spec.configuration.type == RunConfigurationType.SERVICE.value: console.print( f"Service is published at [link={run.service_url}]{run.service_url}[/]\n" ) diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py index c1dc1da1b..6d40b7d98 100644 --- a/src/dstack/_internal/cli/main.py +++ b/src/dstack/_internal/cli/main.py @@ -3,6 +3,7 @@ from rich.markup import escape from rich_argparse import RichHelpFormatter +from dstack._internal.cli.commands.apply import ApplyCommand from dstack._internal.cli.commands.config import ConfigCommand from dstack._internal.cli.commands.gateway import GatewayCommand from dstack._internal.cli.commands.init import InitCommand @@ -50,6 +51,7 @@ def main(): parser.set_defaults(func=lambda _: parser.print_help()) subparsers = parser.add_subparsers(metavar="COMMAND") + ApplyCommand.register(subparsers) ConfigCommand.register(subparsers) GatewayCommand.register(subparsers) PoolCommand.register(subparsers) diff --git a/src/dstack/_internal/cli/services/configurators/__init__.py b/src/dstack/_internal/cli/services/configurators/__init__.py index e69de29bb..152e0a3a7 100644 --- a/src/dstack/_internal/cli/services/configurators/__init__.py +++ b/src/dstack/_internal/cli/services/configurators/__init__.py @@ -0,0 +1,13 @@ +from typing import Dict, Type + +from dstack._internal.cli.services.configurators.base import BaseApplyConfigurator +from dstack._internal.cli.services.configurators.gateway import GatewayConfigurator +from dstack._internal.core.models.configurations import ApplyConfigurationType + +apply_configurators_mapping: Dict[ApplyConfigurationType, Type[BaseApplyConfigurator]] = { + cls.TYPE: cls for cls in [GatewayConfigurator] +} + + +def get_apply_configurator_class(configurator_type: str) -> Type[BaseApplyConfigurator]: + return apply_configurators_mapping[ApplyConfigurationType(configurator_type)] diff --git a/src/dstack/_internal/cli/services/configurators/base.py b/src/dstack/_internal/cli/services/configurators/base.py new file mode 100644 index 000000000..1067fa87a --- /dev/null +++ b/src/dstack/_internal/cli/services/configurators/base.py @@ -0,0 +1,28 @@ +import argparse +from abc import ABC, abstractmethod +from typing import List + +from dstack._internal.core.models.configurations import ( + AnyApplyConfiguration, + ApplyConfigurationType, +) +from dstack.api._public import Client + + +class BaseApplyConfigurator(ABC): + TYPE: ApplyConfigurationType + + def __init__(self, api_client: Client): + self.api_client = api_client + + @abstractmethod + def apply_configuration(self, conf: AnyApplyConfiguration, args: argparse.Namespace): + pass + + def register_args(self, parser: argparse.ArgumentParser): + pass + + def apply_args( + self, args: argparse.Namespace, unknown: List[str], conf: AnyApplyConfiguration + ): + pass diff --git a/src/dstack/_internal/cli/services/configurators/gateway.py b/src/dstack/_internal/cli/services/configurators/gateway.py new file mode 100644 index 000000000..e1920b5c7 --- /dev/null +++ b/src/dstack/_internal/cli/services/configurators/gateway.py @@ -0,0 +1,58 @@ +import argparse + +from dstack._internal.cli.services.configurators.base import BaseApplyConfigurator +from dstack._internal.cli.utils.common import confirm_ask, console +from dstack._internal.cli.utils.gateway import print_gateways_table +from dstack._internal.core.errors import ResourceNotExistsError +from dstack._internal.core.models.configurations import ApplyConfigurationType +from dstack._internal.core.models.gateways import GatewayConfiguration + + +class GatewayConfigurator(BaseApplyConfigurator): + TYPE: ApplyConfigurationType = ApplyConfigurationType.GATEWAY + + def apply_configuration(self, conf: GatewayConfiguration, args: argparse.Namespace): + # TODO: Show apply plan + # TODO: Update gateway in-place when domain/default change + confirmed = False + if conf.name is not None: + try: + gateway = self.api_client.client.gateways.get( + project_name=self.api_client.project, gateway_name=conf.name + ) + except ResourceNotExistsError: + pass + else: + if gateway.configuration == conf: + if not args.force: + console.print( + "Gateway configuration has not changed. Use --force to recreate the gateway." + ) + return + if not args.yes and not confirm_ask( + "Gateway configuration has not changed. Re-create the gateway?" + ): + console.print("\nExiting...") + return + elif not args.yes and not confirm_ask( + f"Gateway [code]{conf.name}[/] already exist. Re-create the gateway?" + ): + console.print("\nExiting...") + return + confirmed = True + with console.status("Deleting gateway..."): + self.api_client.client.gateways.delete( + project_name=self.api_client.project, gateways_names=[conf.name] + ) + if not confirmed and not args.yes: + if not confirm_ask( + f"Gateway [code]{conf.name}[/] does not exist yet. Create the gateway?" + ): + console.print("\nExiting...") + return + with console.status("Creating gateway..."): + gateway = self.api_client.client.gateways.create( + project_name=self.api_client.project, + configuration=conf, + ) + print_gateways_table([gateway]) diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index ba6687aaa..1dc1cd692 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -11,10 +11,10 @@ from dstack._internal.core.models.configurations import ( BaseConfiguration, BaseConfigurationWithPorts, - ConfigurationType, DevEnvironmentConfiguration, EnvSentinel, PortMapping, + RunConfigurationType, ServiceConfiguration, TaskConfiguration, ) @@ -22,7 +22,7 @@ class BaseRunConfigurator: - TYPE: ConfigurationType = None + TYPE: RunConfigurationType = None @classmethod def register(cls, parser: argparse.ArgumentParser): @@ -102,7 +102,7 @@ def apply(cls, args: argparse.Namespace, unknown: List[str], conf: BaseConfigura class TaskRunConfigurator(RunWithPortsConfigurator): - TYPE = ConfigurationType.TASK + TYPE = RunConfigurationType.TASK @classmethod def apply(cls, args: argparse.Namespace, unknown: List[str], conf: TaskConfiguration): @@ -112,7 +112,7 @@ def apply(cls, args: argparse.Namespace, unknown: List[str], conf: TaskConfigura class DevEnvironmentRunConfigurator(RunWithPortsConfigurator): - TYPE = ConfigurationType.DEV_ENVIRONMENT + TYPE = RunConfigurationType.DEV_ENVIRONMENT @classmethod def apply( @@ -130,7 +130,7 @@ def apply( class ServiceRunConfigurator(BaseRunConfigurator): - TYPE = ConfigurationType.SERVICE + TYPE = RunConfigurationType.SERVICE @classmethod def apply(cls, args: argparse.Namespace, unknown: List[str], conf: ServiceConfiguration): @@ -169,7 +169,7 @@ def _detect_vscode_version(exe: str = "code") -> Optional[str]: return None -run_configurators_mapping: Dict[ConfigurationType, Type[BaseRunConfigurator]] = { +run_configurators_mapping: Dict[RunConfigurationType, Type[BaseRunConfigurator]] = { cls.TYPE: cls for cls in [ TaskRunConfigurator, diff --git a/src/dstack/_internal/cli/services/configurators/profile.py b/src/dstack/_internal/cli/services/profile.py similarity index 100% rename from src/dstack/_internal/cli/services/configurators/profile.py rename to src/dstack/_internal/cli/services/profile.py diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index f2d0da344..3bc3c834b 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -19,6 +19,7 @@ from dstack._internal.core.models.backends.aws import AWSAccessKeyCreds from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import is_core_model_instance +from dstack._internal.core.models.gateways import GatewayComputeConfiguration from dstack._internal.core.models.instances import ( InstanceAvailability, InstanceConfiguration, @@ -192,17 +193,14 @@ def run_job( def create_gateway( self, - instance_name: str, - ssh_key_pub: str, - region: str, - project_id: str, - ) -> JobProvisioningData: - ec2 = self.session.resource("ec2", region_name=region) - ec2_client = self.session.client("ec2", region_name=region) + configuration: GatewayComputeConfiguration, + ) -> LaunchedGatewayInfo: + ec2 = self.session.resource("ec2", region_name=configuration.region) + ec2_client = self.session.client("ec2", region_name=configuration.region) tags = [ - {"Key": "Name", "Value": instance_name}, + {"Key": "Name", "Value": configuration.instance_name}, {"Key": "owner", "Value": "dstack"}, - {"Key": "dstack_project", "Value": project_id}, + {"Key": "dstack_project", "Value": configuration.project_name}, ] if settings.DSTACK_VERSION is not None: tags.append({"Key": "dstack_version", "Value": settings.DSTACK_VERSION}) @@ -212,11 +210,11 @@ def create_gateway( image_id=aws_resources.get_gateway_image_id(ec2_client), instance_type="t2.micro", iam_instance_profile_arn=None, - user_data=get_gateway_user_data(ssh_key_pub), + user_data=get_gateway_user_data(configuration.ssh_key_pub), tags=tags, security_group_id=aws_resources.create_gateway_security_group( ec2_client=ec2_client, - project_id=project_id, + project_id=configuration.project_name, ), spot=False, ) @@ -226,7 +224,7 @@ def create_gateway( instance.reload() # populate instance.public_ip_address return LaunchedGatewayInfo( instance_id=instance.instance_id, - region=region, + region=configuration.region, ip_address=instance.public_ip_address, ) diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py index 5b7f4b1b3..d53e17320 100644 --- a/src/dstack/_internal/core/backends/azure/compute.py +++ b/src/dstack/_internal/core/backends/azure/compute.py @@ -43,6 +43,7 @@ from dstack._internal.core.backends.base.offers import get_catalog_offers from dstack._internal.core.errors import NoCapacityError from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.gateways import GatewayComputeConfiguration from dstack._internal.core.models.instances import ( InstanceAvailability, InstanceConfiguration, @@ -183,35 +184,36 @@ def terminate_instance( def create_gateway( self, - instance_name: str, - ssh_key_pub: str, - region: str, - project_id: str, + configuration: GatewayComputeConfiguration, ) -> LaunchedGatewayInfo: - logger.info("Launching %s gateway instance in %s...", instance_name, region) + logger.info( + "Launching %s gateway instance in %s...", + configuration.instance_name, + configuration.region, + ) vm = _launch_instance( compute_client=self._compute_client, subscription_id=self.config.subscription_id, - location=region, + location=configuration.region, resource_group=self.config.resource_group, network_security_group=azure_utils.get_gateway_network_security_group_name( resource_group=self.config.resource_group, - location=region, + location=configuration.region, ), network=azure_utils.get_default_network_name( resource_group=self.config.resource_group, - location=region, + location=configuration.region, ), subnet=azure_utils.get_default_subnet_name( resource_group=self.config.resource_group, - location=region, + location=configuration.region, ), managed_identity=None, image_reference=_get_gateway_image_ref(), vm_size="Standard_B1s", - instance_name=instance_name, - user_data=get_gateway_user_data(ssh_key_pub), - ssh_pub_keys=[ssh_key_pub], + instance_name=configuration.instance_name, + user_data=get_gateway_user_data(configuration.ssh_key_pub), + ssh_pub_keys=[configuration.ssh_key_pub], spot=False, disk_size=30, computer_name="gatewayvm", @@ -225,7 +227,7 @@ def create_gateway( return LaunchedGatewayInfo( instance_id=vm.name, ip_address=public_ip, - region=region, + region=configuration.region, ) diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 02c2b6186..fc61754da 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -9,6 +9,7 @@ import yaml from dstack._internal import settings +from dstack._internal.core.models.gateways import GatewayComputeConfiguration from dstack._internal.core.models.instances import ( InstanceConfiguration, InstanceOfferWithAvailability, @@ -84,10 +85,7 @@ def update_provisioning_data( def create_gateway( self, - instance_name: str, - ssh_key_pub: str, - region: str, - project_id: str, + configuration: GatewayComputeConfiguration, ) -> LaunchedGatewayInfo: raise NotImplementedError() diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index dfcb798e8..a6cba7149 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -17,6 +17,7 @@ from dstack._internal.core.backends.gcp.config import GCPConfig from dstack._internal.core.errors import ComputeResourceNotFoundError, NoCapacityError from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.gateways import GatewayComputeConfiguration from dstack._internal.core.models.instances import ( InstanceAvailability, InstanceConfiguration, @@ -176,10 +177,7 @@ def run_job( def create_gateway( self, - instance_name: str, - ssh_key_pub: str, - region: str, - project_id: str, + configuration: GatewayComputeConfiguration, ) -> LaunchedGatewayInfo: gcp_resources.create_gateway_firewall_rules( firewalls_client=self.firewalls_client, @@ -187,7 +185,7 @@ def create_gateway( ) # e2-micro is available in every zone for i in self.regions_client.list(project=self.config.project_id): - if i.name == region: + if i.name == configuration.region: zone = i.zones[0].split("/")[-1] break else: @@ -202,24 +200,24 @@ def create_gateway( machine_type="e2-micro", accelerators=[], spot=False, - user_data=get_gateway_user_data(ssh_key_pub), - authorized_keys=[ssh_key_pub], + user_data=get_gateway_user_data(configuration.ssh_key_pub), + authorized_keys=[configuration.ssh_key_pub], labels={ "owner": "dstack", - "dstack_project": project_id, + "dstack_project": configuration.project_name, }, tags=[gcp_resources.DSTACK_GATEWAY_TAG], - instance_name=instance_name, + instance_name=configuration.instance_name, zone=zone, service_account=None, ) operation = self.instances_client.insert(request=request) gcp_resources.wait_for_extended_operation(operation, "instance creation") instance = self.instances_client.get( - project=self.config.project_id, zone=zone, instance=instance_name + project=self.config.project_id, zone=zone, instance=configuration.instance_name ) return LaunchedGatewayInfo( - instance_id=instance_name, + instance_id=configuration.instance_name, region=zone, # used for instance termination ip_address=instance.network_interfaces[0].access_configs[0].nat_i_p, ) diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 9810c1b5a..61d25ff6c 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -4,8 +4,7 @@ import time from typing import Dict, List, Optional -# TODO: update import as KNOWN_GPUS becomes public -from gpuhunt._internal.constraints import KNOWN_GPUS +from gpuhunt import KNOWN_GPUS from kubernetes import client from dstack._internal.core.backends.base.compute import ( @@ -22,6 +21,9 @@ ) from dstack._internal.core.errors import ComputeError, GatewayError from dstack._internal.core.models.backends.base import BackendType + +# TODO: update import as KNOWN_GPUS becomes public +from dstack._internal.core.models.gateways import GatewayComputeConfiguration from dstack._internal.core.models.instances import ( Disk, Gpu, @@ -205,7 +207,10 @@ def terminate_instance( if e.status != 404: raise - def create_gateway(self, instance_name: str, ssh_key_pub: str, region: str, project_id: str): + def create_gateway( + self, + configuration: GatewayComputeConfiguration, + ) -> LaunchedGatewayInfo: # Gateway creation is currently limited to Kubernetes with Load Balancer support. # If the cluster does not support Load Balancer, the service will be provisioned but # the external IP/hostname will never be allocated. @@ -215,7 +220,8 @@ def create_gateway(self, instance_name: str, ssh_key_pub: str, region: str, proj # TODO: By default EKS creates a Classic Load Balancer for Load Balancer services. # Consider deploying an NLB. It seems it requires some extra configuration on the cluster: # https://docs.aws.amazon.com/eks/latest/userguide/network-load-balancing.html - commands = _get_gateway_commands(authorized_keys=[ssh_key_pub]) + instance_name = configuration.instance_name + commands = _get_gateway_commands(authorized_keys=[configuration.ssh_key_pub]) self.api.create_namespaced_pod( namespace=DEFAULT_NAMESPACE, body=client.V1Pod( diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 781610017..591f1724e 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -7,7 +7,7 @@ from dstack._internal.core.errors import ConfigurationError from dstack._internal.core.models.common import CoreModel, Duration -from dstack._internal.core.models.gateways import AnyModel +from dstack._internal.core.models.gateways import AnyModel, GatewayConfiguration from dstack._internal.core.models.profiles import ProfileParams from dstack._internal.core.models.repos.base import Repo from dstack._internal.core.models.repos.virtual import VirtualRepo @@ -17,7 +17,7 @@ ValidPort = conint(gt=0, le=65536) -class ConfigurationType(str, Enum): +class RunConfigurationType(str, Enum): DEV_ENVIRONMENT = "dev-environment" TASK = "task" SERVICE = "service" @@ -319,13 +319,38 @@ class RunConfiguration(CoreModel): Field(discriminator="type"), ] - class Config: - schema_extra = {"$schema": "http://json-schema.org/draft-07/schema#"} - -def parse(data: dict) -> AnyRunConfiguration: +def parse_run_configuration(data: dict) -> AnyRunConfiguration: try: conf = RunConfiguration.parse_obj(data).__root__ except ValidationError as e: raise ConfigurationError(e) return conf + + +class ApplyConfigurationType(str, Enum): + GATEWAY = "gateway" + + +AnyApplyConfiguration = GatewayConfiguration + + +def parse_apply_configuration(data: dict) -> AnyApplyConfiguration: + try: + conf = GatewayConfiguration.parse_obj(data) + except ValidationError as e: + raise ConfigurationError(e) + return conf + + +AnyDstackConfiguration = Union[AnyRunConfiguration, GatewayConfiguration] + + +class DstackConfiguration(CoreModel): + __root__: Annotated[ + AnyDstackConfiguration, + Field(discriminator="type"), + ] + + class Config: + schema_extra = {"$schema": "http://json-schema.org/draft-07/schema#"} diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index 3831fa5e1..6df853b08 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -8,7 +8,30 @@ from dstack._internal.core.models.common import CoreModel +class GatewayConfiguration(CoreModel): + type: Literal["gateway"] = "gateway" + name: Annotated[Optional[str], Field(description="The gateway name")] = None + default: Annotated[bool, Field(description="Make the gateway default")] = False + backend: Annotated[BackendType, Field(description="The gateway backend")] + region: Annotated[str, Field(description="The gateway region")] + domain: Annotated[ + Optional[str], Field(description="The gateway domain, e.g. `*.example.com`") + ] = None + # public_ip: Annotated[bool, Field(description="Allocate public IP for the gateway")] = True + + +class GatewayComputeConfiguration(CoreModel): + project_name: str + instance_name: str + backend: BackendType + region: str + public_ip: bool + ssh_key_pub: str + + class Gateway(CoreModel): + # TODO: configuration fields are duplicated on top-level for backward compatibility with 0.18.x + # Remove in 0.19 name: str ip_address: Optional[str] instance_id: Optional[str] @@ -17,6 +40,7 @@ class Gateway(CoreModel): default: bool created_at: datetime.datetime backend: BackendType + configuration: GatewayConfiguration class BaseChatModel(CoreModel): diff --git a/src/dstack/_internal/server/migrations/versions/58aa5162dcc3_add_gatewaymodel_configuration.py b/src/dstack/_internal/server/migrations/versions/58aa5162dcc3_add_gatewaymodel_configuration.py new file mode 100644 index 000000000..5a25adaee --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/58aa5162dcc3_add_gatewaymodel_configuration.py @@ -0,0 +1,32 @@ +"""Add GatewayModel.configuration + +Revision ID: 58aa5162dcc3 +Revises: 1e3fb39ef74b +Create Date: 2024-05-15 11:04:58.848554 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "58aa5162dcc3" +down_revision = "1e3fb39ef74b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("gateways", schema=None) as batch_op: + batch_op.add_column(sa.Column("configuration", sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("gateways", schema=None) as batch_op: + batch_op.drop_column("configuration") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 1615f4a99..57ee37ff3 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -230,6 +230,7 @@ class GatewayModel(BaseModel): name: Mapped[str] = mapped_column(String(100)) region: Mapped[str] = mapped_column(String(100)) wildcard_domain: Mapped[str] = mapped_column(String(100), nullable=True) + configuration: Mapped[Optional[str]] = mapped_column(Text) created_at: Mapped[datetime] = mapped_column(DateTime, default=get_current_datetime) project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) diff --git a/src/dstack/_internal/server/routers/gateways.py b/src/dstack/_internal/server/routers/gateways.py index 6de22aa6a..4215fa90a 100644 --- a/src/dstack/_internal/server/routers/gateways.py +++ b/src/dstack/_internal/server/routers/gateways.py @@ -46,9 +46,7 @@ async def create_gateway( return await gateways.create_gateway( session=session, project=project, - name=body.name, - backend_type=body.backend_type, - region=body.region, + configuration=body.configuration, ) diff --git a/src/dstack/_internal/server/schemas/gateways.py b/src/dstack/_internal/server/schemas/gateways.py index 4183b862f..99930b95e 100644 --- a/src/dstack/_internal/server/schemas/gateways.py +++ b/src/dstack/_internal/server/schemas/gateways.py @@ -1,13 +1,34 @@ -from typing import List, Optional +from typing import Dict, List, Optional + +from pydantic import root_validator from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel +from dstack._internal.core.models.gateways import GatewayConfiguration class CreateGatewayRequest(CoreModel): name: Optional[str] - backend_type: BackendType - region: str + backend_type: Optional[BackendType] + region: Optional[str] + configuration: Optional[GatewayConfiguration] + + @root_validator + def fill_configuration(cls, values: Dict) -> Dict: + if values.get("configuration", None) is not None: + return values + backend_type = values.get("backend_type", None) + region = values.get("region", None) + if backend_type is None: + raise ValueError("backend_type must be specified") + if region is None: + raise ValueError("region must be specified") + values["configuration"] = GatewayConfiguration( + name=values.get("name", None), + backend=backend_type, + region=region, + ) + return values class GetGatewayRequest(CoreModel): diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 78cb3b76d..ff960c7f3 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -23,7 +23,11 @@ SSHError, ) from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.gateways import Gateway +from dstack._internal.core.models.gateways import ( + Gateway, + GatewayComputeConfiguration, + GatewayConfiguration, +) from dstack._internal.core.models.runs import ( Run, RunSpec, @@ -84,9 +88,7 @@ async def get_project_default_gateway( async def create_gateway_compute( backend_compute: Compute, - instance_name: str, - region: str, - project_id: str, + configuration: GatewayComputeConfiguration, backend_id: Optional[uuid.UUID] = None, ) -> GatewayComputeModel: private_bytes, public_bytes = generate_rsa_key_pair_bytes() @@ -95,10 +97,7 @@ async def create_gateway_compute( info = await run_async( backend_compute.create_gateway, - instance_name, - gateway_ssh_public_key, - region, - project_id, + configuration, ) return GatewayComputeModel( @@ -114,38 +113,45 @@ async def create_gateway_compute( async def create_gateway( session: AsyncSession, project: ProjectModel, - name: Optional[str], - backend_type: BackendType, - region: str, + configuration: GatewayConfiguration, ) -> Gateway: # TODO: Gateay creation may take significant time. Make it asynchronous. for backend_model, backend in await get_project_backends_with_models(project): - if backend_model.type == backend_type: + if backend_model.type == configuration.backend: break else: raise ResourceNotExistsError() - if name is None: - name = await generate_gateway_name(session=session, project=project) + if configuration.name is None: + configuration.name = await generate_gateway_name(session=session, project=project) gateway = GatewayModel( # reserve name - name=name, - region=region, + name=configuration.name, + region=configuration.region, project_id=project.id, backend_id=backend_model.id, + wildcard_domain=configuration.domain, + configuration=configuration.json(), ) session.add(gateway) await session.commit() - if project.default_gateway is None: - await set_default_gateway(session=session, project=project, name=name) + if project.default_gateway is None or configuration.default: + await set_default_gateway(session=session, project=project, name=configuration.name) + + compute_configuration = GatewayComputeConfiguration( + project_name=project.name, + instance_name=gateway.name, + backend=configuration.backend, + region=configuration.region, + public_ip=True, + ssh_key_pub=project.name, + ) try: gateway.gateway_compute = await create_gateway_compute( backend_compute=backend.compute(), - instance_name=gateway.name, - region=region, - project_id=project.name, + configuration=compute_configuration, backend_id=backend_model.id, ) session.add(gateway) @@ -155,7 +161,7 @@ async def create_gateway( await session.execute( delete(GatewayModel).where( GatewayModel.project_id == project.id, - GatewayModel.name == name, + GatewayModel.name == configuration.name, ) ) await session.commit() @@ -545,6 +551,18 @@ def gateway_model_to_gateway(gateway_model: GatewayModel) -> Gateway: backend_type = gateway_model.backend.type if gateway_model.backend.type == BackendType.DSTACK: backend_type = BackendType.AWS + if gateway_model.configuration is not None: + configuration = GatewayConfiguration.__response__.parse_raw(gateway_model.configuration) + else: + # Handle gateways created before GatewayConfiguration was introduced + configuration = GatewayConfiguration( + name=gateway_model.name, + default=False, + backend=gateway_model.backend.type, + region=gateway_model.region, + domain=gateway_model.wildcard_domain, + ) + configuration.default = gateway_model.project.default_gateway_id == gateway_model.id return Gateway( name=gateway_model.name, ip_address=ip_address, @@ -554,4 +572,5 @@ def gateway_model_to_gateway(gateway_model: GatewayModel) -> Gateway: default=gateway_model.project.default_gateway_id == gateway_model.id, created_at=gateway_model.created_at.replace(tzinfo=timezone.utc), backend=backend_type, + configuration=configuration, ) diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 6c23695b0..9833903b3 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -12,7 +12,7 @@ import dstack._internal.server.services.gateways as gateways from dstack._internal.core.errors import ComputeResourceNotFoundError, SSHError from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.configurations import ConfigurationType +from dstack._internal.core.models.configurations import RunConfigurationType from dstack._internal.core.models.instances import RemoteConnectionInfo from dstack._internal.core.models.runs import ( InstanceStatus, @@ -133,7 +133,7 @@ def delay_job_instance_termination(job_model: JobModel): def _get_job_configurator(run_spec: RunSpec) -> JobConfigurator: - configuration_type = ConfigurationType(run_spec.configuration.type) + configuration_type = RunConfigurationType(run_spec.configuration.type) configurator_class = _configuration_type_to_configurator_class_map[configuration_type] return configurator_class(run_spec) diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index 7a9625bd0..ddfdb3011 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -8,10 +8,10 @@ import dstack.version as version from dstack._internal.core.errors import DockerRegistryError, ServerClientError from dstack._internal.core.models.configurations import ( - ConfigurationType, PortMapping, PythonVersion, RegistryAuth, + RunConfigurationType, ) from dstack._internal.core.models.profiles import SpotPolicy from dstack._internal.core.models.runs import ( @@ -44,7 +44,7 @@ def get_default_image(python_version: str) -> str: class JobConfigurator(ABC): - TYPE: ConfigurationType + TYPE: RunConfigurationType def __init__(self, run_spec: RunSpec): self.run_spec = run_spec diff --git a/src/dstack/_internal/server/services/jobs/configurators/dev.py b/src/dstack/_internal/server/services/jobs/configurators/dev.py index 760390bff..2dbb4e21b 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/dev.py +++ b/src/dstack/_internal/server/services/jobs/configurators/dev.py @@ -1,6 +1,6 @@ from typing import List, Optional -from dstack._internal.core.models.configurations import ConfigurationType, PortMapping +from dstack._internal.core.models.configurations import PortMapping, RunConfigurationType from dstack._internal.core.models.profiles import ProfileRetryPolicy, SpotPolicy from dstack._internal.core.models.runs import RetryPolicy, RunSpec from dstack._internal.server.services.jobs.configurators.base import JobConfigurator @@ -15,7 +15,7 @@ class DevEnvironmentJobConfigurator(JobConfigurator): - TYPE: ConfigurationType = ConfigurationType.DEV_ENVIRONMENT + TYPE: RunConfigurationType = RunConfigurationType.DEV_ENVIRONMENT def __init__(self, run_spec: RunSpec): self.ide = VSCodeDesktop( diff --git a/src/dstack/_internal/server/services/jobs/configurators/service.py b/src/dstack/_internal/server/services/jobs/configurators/service.py index ed783fa58..718caabe0 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/service.py +++ b/src/dstack/_internal/server/services/jobs/configurators/service.py @@ -1,13 +1,13 @@ from typing import List, Optional -from dstack._internal.core.models.configurations import ConfigurationType, PortMapping +from dstack._internal.core.models.configurations import PortMapping, RunConfigurationType from dstack._internal.core.models.profiles import ProfileRetryPolicy, SpotPolicy from dstack._internal.core.models.runs import RetryPolicy from dstack._internal.server.services.jobs.configurators.base import JobConfigurator class ServiceJobConfigurator(JobConfigurator): - TYPE: ConfigurationType = ConfigurationType.SERVICE + TYPE: RunConfigurationType = RunConfigurationType.SERVICE def _shell_commands(self) -> List[str]: return self.run_spec.configuration.commands diff --git a/src/dstack/_internal/server/services/jobs/configurators/task.py b/src/dstack/_internal/server/services/jobs/configurators/task.py index d1eb2d97e..de8fad9be 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/task.py +++ b/src/dstack/_internal/server/services/jobs/configurators/task.py @@ -1,6 +1,6 @@ from typing import List, Optional -from dstack._internal.core.models.configurations import ConfigurationType, PortMapping +from dstack._internal.core.models.configurations import PortMapping, RunConfigurationType from dstack._internal.core.models.profiles import ProfileRetryPolicy, SpotPolicy from dstack._internal.core.models.runs import JobSpec, RetryPolicy from dstack._internal.server.services.jobs.configurators.base import JobConfigurator @@ -9,7 +9,7 @@ class TaskJobConfigurator(JobConfigurator): - TYPE: ConfigurationType = ConfigurationType.TASK + TYPE: RunConfigurationType = RunConfigurationType.TASK async def get_job_specs(self, replica_num: int) -> List[JobSpec]: job_specs = [] diff --git a/src/dstack/api/server/_gateways.py b/src/dstack/api/server/_gateways.py index 3d77f208d..f2800c2fe 100644 --- a/src/dstack/api/server/_gateways.py +++ b/src/dstack/api/server/_gateways.py @@ -3,7 +3,7 @@ from pydantic import parse_obj_as from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.gateways import Gateway +from dstack._internal.core.models.gateways import Gateway, GatewayConfiguration from dstack._internal.server.schemas.gateways import ( CreateGatewayRequest, DeleteGatewaysRequest, @@ -24,14 +24,22 @@ def get(self, project_name: str, gateway_name: str) -> Gateway: resp = self._request(f"/api/project/{project_name}/gateways/get", body=body.json()) return parse_obj_as(Gateway.__response__, resp.json()) + # gateway_name, backend_type, region are left for backward-compatibility with 0.18.x + # TODO: Remove in 0.19 def create( self, project_name: str, - gateway_name: Optional[str], - backend_type: BackendType, - region: str, + gateway_name: Optional[str] = None, + backend_type: Optional[BackendType] = None, + region: Optional[str] = None, + configuration: Optional[GatewayConfiguration] = None, ) -> Gateway: - body = CreateGatewayRequest(name=gateway_name, backend_type=backend_type, region=region) + body = CreateGatewayRequest( + name=gateway_name, + backend_type=backend_type, + region=region, + configuration=configuration, + ) resp = self._request(f"/api/project/{project_name}/gateways/create", body=body.json()) return parse_obj_as(Gateway.__response__, resp.json()) diff --git a/src/dstack/api/utils.py b/src/dstack/api/utils.py index 19e074131..9471ce74b 100644 --- a/src/dstack/api/utils.py +++ b/src/dstack/api/utils.py @@ -5,7 +5,9 @@ from dstack._internal.core.errors import ConfigurationError from dstack._internal.core.models.configurations import AnyRunConfiguration -from dstack._internal.core.models.configurations import parse as parse_configuration +from dstack._internal.core.models.configurations import ( + parse_run_configuration as parse_configuration, +) from dstack._internal.core.models.profiles import Profile, ProfilesConfig from dstack._internal.utils.common import get_dstack_dir from dstack._internal.utils.path import PathLike, path_in_dir diff --git a/src/tests/_internal/cli/services/configurators/test_profile.py b/src/tests/_internal/cli/services/configurators/test_profile.py index ead9568d7..d9a047bdf 100644 --- a/src/tests/_internal/cli/services/configurators/test_profile.py +++ b/src/tests/_internal/cli/services/configurators/test_profile.py @@ -1,7 +1,7 @@ import argparse from typing import List, Tuple -from dstack._internal.cli.services.configurators.profile import ( +from dstack._internal.cli.services.profile import ( apply_profile_args, register_profile_args, ) diff --git a/src/tests/_internal/core/models/test_configurations.py b/src/tests/_internal/core/models/test_configurations.py index 92601e356..90dfdc11e 100644 --- a/src/tests/_internal/core/models/test_configurations.py +++ b/src/tests/_internal/core/models/test_configurations.py @@ -3,7 +3,7 @@ import pytest from dstack._internal.core.errors import ConfigurationError -from dstack._internal.core.models.configurations import RegistryAuth, parse +from dstack._internal.core.models.configurations import RegistryAuth, parse_run_configuration from dstack._internal.core.models.resources import Range @@ -20,15 +20,15 @@ def test_conf(replicas: Any, scaling: Optional[Any] = None): conf["scaling"] = scaling return conf - assert parse(test_conf(1)).replicas == Range(min=1, max=1) - assert parse(test_conf("2")).replicas == Range(min=2, max=2) - assert parse(test_conf("3..3")).replicas == Range(min=3, max=3) + assert parse_run_configuration(test_conf(1)).replicas == Range(min=1, max=1) + assert parse_run_configuration(test_conf("2")).replicas == Range(min=2, max=2) + assert parse_run_configuration(test_conf("3..3")).replicas == Range(min=3, max=3) with pytest.raises( ConfigurationError, match="When you set `replicas` to a range, ensure to specify `scaling`", ): - parse(test_conf("0..10")) - assert parse( + parse_run_configuration(test_conf("0..10")) + assert parse_run_configuration( test_conf( "0..10", { @@ -41,7 +41,7 @@ def test_conf(replicas: Any, scaling: Optional[Any] = None): ConfigurationError, match="When you set `replicas` to a range, ensure to specify `scaling`", ): - parse( + parse_run_configuration( test_conf( "0..10", { diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index 7fe394da3..6a40947ab 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -73,6 +73,14 @@ async def test_list(self, test_db, session: AsyncSession): "name": gateway.name, "region": gateway.region, "wildcard_domain": gateway.wildcard_domain, + "configuration": { + "type": "gateway", + "name": gateway.name, + "backend": backend.type.value, + "region": gateway.region, + "domain": gateway.wildcard_domain, + "default": False, + }, } ] @@ -109,6 +117,14 @@ async def test_get(self, test_db, session: AsyncSession): "name": gateway.name, "region": gateway.region, "wildcard_domain": gateway.wildcard_domain, + "configuration": { + "type": "gateway", + "name": gateway.name, + "backend": backend.type.value, + "region": gateway.region, + "domain": gateway.wildcard_domain, + "default": False, + }, } @pytest.mark.asyncio @@ -180,6 +196,14 @@ async def test_create_gateway(self, test_db, session: AsyncSession): "wildcard_domain": None, "default": True, "created_at": response.json()["created_at"], + "configuration": { + "type": "gateway", + "name": "test", + "backend": backend.type.value, + "region": "us", + "domain": None, + "default": True, + }, } @pytest.mark.asyncio @@ -226,6 +250,14 @@ async def test_create_gateway_without_name(self, test_db, session: AsyncSession) "wildcard_domain": None, "default": True, "created_at": response.json()["created_at"], + "configuration": { + "type": "gateway", + "name": "random-name", + "backend": backend.type.value, + "region": "us", + "domain": None, + "default": True, + }, } @pytest.mark.asyncio @@ -352,6 +384,14 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession): "name": gateway.name, "region": gateway.region, "wildcard_domain": gateway.wildcard_domain, + "configuration": { + "type": "gateway", + "name": gateway.name, + "backend": backend.type.value, + "region": gateway.region, + "domain": gateway.wildcard_domain, + "default": True, + }, } @pytest.mark.asyncio @@ -451,6 +491,14 @@ def get_backend(_, backend_type): "name": gateway_gcp.name, "region": gateway_gcp.region, "wildcard_domain": gateway_gcp.wildcard_domain, + "configuration": { + "type": "gateway", + "name": gateway_gcp.name, + "backend": backend_gcp.type.value, + "region": gateway_gcp.region, + "domain": gateway_gcp.wildcard_domain, + "default": False, + }, } ] @@ -502,6 +550,14 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession): "name": gateway.name, "region": gateway.region, "wildcard_domain": "test.com", + "configuration": { + "type": "gateway", + "name": gateway.name, + "backend": backend.type.value, + "region": gateway.region, + "domain": "test.com", + "default": False, + }, } @pytest.mark.asyncio