diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 70b520311..0cad6b225 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -159,7 +159,7 @@ jobs:
run: pip install .
- name: Generate json schema
run: |
- python -c "from dstack._internal.core.models.configurations import RunConfiguration; print(RunConfiguration.schema_json(indent=2))" > configuration.json
+ python -c "from dstack._internal.core.models.configurations import DstackConfiguration; print(DstackConfiguration.schema_json(indent=2))" > configuration.json
python -c "from dstack._internal.core.models.profiles import ProfilesConfig; print(ProfilesConfig.schema_json(indent=2))" > profiles.json
- name: Upload json schema to S3
run: |
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index be3c666b9..f70ecc176 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -223,7 +223,7 @@ jobs:
run: pip install .
- name: Generate json schema
run: |
- python -c "from dstack._internal.core.models.configurations import RunConfiguration; print(RunConfiguration.schema_json(indent=2))" > configuration.json
+ python -c "from dstack._internal.core.models.configurations import DstackConfiguration; print(DstackConfiguration.schema_json(indent=2))" > configuration.json
python -c "from dstack._internal.core.models.profiles import ProfilesConfig; print(ProfilesConfig.schema_json(indent=2))" > profiles.json
- name: Upload json schema to S3
run: |
diff --git a/docs/docs/reference/cli/index.md b/docs/docs/reference/cli/index.md
index 97b97f3bb..8c462bdfd 100644
--- a/docs/docs/reference/cli/index.md
+++ b/docs/docs/reference/cli/index.md
@@ -59,6 +59,24 @@ $ dstack run . --help
If there are large files, consider creating a `.gitignore` file to exclude them for better performance.
+### dstack apply
+
+This command applies a given configuration. If a resources does not exist, `dstack apply` creates the resource.
+If a resource exists, `dstack apply` updates the resource in-place or re-creates the resource if the update is not possible.
+
+
+
+```shell
+$ dstack apply --help
+#GENERATE#
+```
+
+
+
+!!! info "NOTE:"
+ The `dstack apply` command currently supports only `gateway` configurations.
+ Support for other configuration types is coming soon.
+
### dstack ps
This command shows the status of runs.
diff --git a/docs/docs/reference/dstack.yml/dev-environment.md b/docs/docs/reference/dstack.yml/dev-environment.md
index 1aba1ed37..6768a65b6 100644
--- a/docs/docs/reference/dstack.yml/dev-environment.md
+++ b/docs/docs/reference/dstack.yml/dev-environment.md
@@ -2,9 +2,10 @@
The `dev-environment` configuration type allows running [dev environments](../../concepts/dev-environments.md).
-> Configuration files must have a name ending with `.dstack.yml` (e.g., `.dstack.yml` or `dev.dstack.yml` are both acceptable)
-> and can be located in the project's root directory or any nested folder.
-> Any configuration can be run via [`dstack run . -f PATH`](../cli/index.md#dstack-run).
+!!! info "Filename"
+ Configuration files must have a name ending with `.dstack.yml` (e.g., `.dstack.yml` or `serve.dstack.yml` are both acceptable)
+ and can be located in the project's root directory or any nested folder.
+ Any configuration can be run via [`dstack run`](../cli/index.md#dstack-run).
## Examples
diff --git a/docs/docs/reference/dstack.yml/gateway.md b/docs/docs/reference/dstack.yml/gateway.md
new file mode 100644
index 000000000..a72679286
--- /dev/null
+++ b/docs/docs/reference/dstack.yml/gateway.md
@@ -0,0 +1,31 @@
+# gateway
+
+The `gateway` configuration type allows creating and updating [gateways](../../concepts/services.md).
+
+!!! info "Filename"
+ Configuration files must have a name ending with `.dstack.yml` (e.g., `.dstack.yml` or `serve.dstack.yml` are both acceptable)
+ and can be located in the project's root directory or any nested folder.
+ Any configuration can be applied via [`dstack apply`](../cli/index.md#dstack-apply).
+
+## Examples
+
+
+
+```yaml
+type: gateway
+name: example-gateway
+backend: aws
+region: eu-west-1
+domain: '*.example.com'
+```
+
+
+
+
+## Root reference
+
+#SCHEMA# dstack._internal.core.models.gateways.GatewayConfiguration
+ overrides:
+ show_root_heading: false
+ type:
+ required: true
diff --git a/docs/docs/reference/dstack.yml/task.md b/docs/docs/reference/dstack.yml/task.md
index 42282a161..cb080d162 100644
--- a/docs/docs/reference/dstack.yml/task.md
+++ b/docs/docs/reference/dstack.yml/task.md
@@ -2,9 +2,10 @@
The `task` configuration type allows running [tasks](../../concepts/tasks.md).
-> Configuration files must have a name ending with `.dstack.yml` (e.g., `.dstack.yml` or `train.dstack.yml` are both acceptable)
-> and can be located in the project's root directory or any nested folder.
-> Any configuration can be run via [`dstack run . -f PATH`](../cli/index.md#dstack-run).
+!!! info "Filename"
+ Configuration files must have a name ending with `.dstack.yml` (e.g., `.dstack.yml` or `serve.dstack.yml` are both acceptable)
+ and can be located in the project's root directory or any nested folder.
+ Any configuration can be run via [`dstack run`](../cli/index.md#dstack-run).
## Examples
diff --git a/mkdocs.yml b/mkdocs.yml
index 11a32bbf6..14aec5c19 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -199,6 +199,7 @@ nav:
- dev-environment: docs/reference/dstack.yml/dev-environment.md
- task: docs/reference/dstack.yml/task.md
- service: docs/reference/dstack.yml/service.md
+ - gateway: docs/reference/dstack.yml/gateway.md
- profiles.yml: docs/reference/profiles.yml.md
- CLI: docs/reference/cli/index.md
- server/config.yml: docs/reference/server/config.yml.md
diff --git a/src/dstack/_internal/cli/commands/apply.py b/src/dstack/_internal/cli/commands/apply.py
new file mode 100644
index 000000000..645baca76
--- /dev/null
+++ b/src/dstack/_internal/cli/commands/apply.py
@@ -0,0 +1,60 @@
+import argparse
+from pathlib import Path
+
+import yaml
+
+from dstack._internal.cli.commands import APIBaseCommand
+from dstack._internal.cli.services.configurators import (
+ get_apply_configurator_class,
+)
+from dstack._internal.cli.utils.common import cli_error
+from dstack._internal.core.errors import ConfigurationError
+from dstack._internal.core.models.configurations import (
+ AnyApplyConfiguration,
+ parse_apply_configuration,
+)
+
+
+class ApplyCommand(APIBaseCommand):
+ NAME = "apply"
+ DESCRIPTION = "Apply dstack configuration"
+
+ def _register(self):
+ super()._register()
+ self._parser.add_argument(
+ "configuration_file",
+ help="The path to the configuration file",
+ )
+ self._parser.add_argument(
+ "--force",
+ help="Force apply when no changes detected",
+ action="store_true",
+ )
+ self._parser.add_argument(
+ "-y",
+ "--yes",
+ help="Do not ask for confirmation",
+ action="store_true",
+ )
+
+ def _command(self, args: argparse.Namespace):
+ super()._command(args)
+ try:
+ configuration = _load_configuration(args.configuration_file)
+ except ConfigurationError as e:
+ raise cli_error(e)
+ configurator_class = get_apply_configurator_class(configuration.type)
+ configurator = configurator_class(api_client=self.api)
+ configurator.apply_configuration(conf=configuration, args=args)
+
+
+def _load_configuration(configuration_file: str) -> AnyApplyConfiguration:
+ configuration_path = Path(configuration_file)
+ if not configuration_path.exists():
+ raise ConfigurationError(f"Configuration file {configuration_file} does not exist")
+ try:
+ with open(configuration_path, "r") as f:
+ conf = parse_apply_configuration(yaml.safe_load(f))
+ except OSError:
+ raise ConfigurationError(f"Failed to load configuration from {configuration_path}")
+ return conf
diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py
index 4c570fd15..7e95f820d 100644
--- a/src/dstack/_internal/cli/commands/pool.py
+++ b/src/dstack/_internal/cli/commands/pool.py
@@ -11,7 +11,7 @@
from dstack._internal.cli.commands import APIBaseCommand
from dstack._internal.cli.services.args import cpu_spec, disk_spec, gpu_spec, memory_spec
-from dstack._internal.cli.services.configurators.profile import (
+from dstack._internal.cli.services.profile import (
apply_profile_args,
register_profile_args,
)
diff --git a/src/dstack/_internal/cli/commands/run.py b/src/dstack/_internal/cli/commands/run.py
index 510c972b9..a706fa284 100644
--- a/src/dstack/_internal/cli/commands/run.py
+++ b/src/dstack/_internal/cli/commands/run.py
@@ -5,18 +5,18 @@
from typing import Optional
from dstack._internal.cli.commands import APIBaseCommand
-from dstack._internal.cli.services.configurators.profile import (
- apply_profile_args,
- register_profile_args,
-)
from dstack._internal.cli.services.configurators.run import (
BaseRunConfigurator,
run_configurators_mapping,
)
+from dstack._internal.cli.services.profile import (
+ apply_profile_args,
+ register_profile_args,
+)
from dstack._internal.cli.utils.common import confirm_ask, console
from dstack._internal.cli.utils.run import print_run_plan
from dstack._internal.core.errors import CLIError, ConfigurationError, ServerClientError
-from dstack._internal.core.models.configurations import ConfigurationType
+from dstack._internal.core.models.configurations import RunConfigurationType
from dstack._internal.core.models.runs import JobTerminationReason
from dstack._internal.core.services.configs import ConfigManager
from dstack._internal.utils.logging import get_logger
@@ -39,7 +39,7 @@ def _register(self):
"-h",
"--help",
nargs="?",
- type=ConfigurationType,
+ type=RunConfigurationType,
default=NOTSET,
help="Show this help message and exit. TYPE is one of [code]task[/], [code]dev-environment[/], [code]service[/]",
dest="help",
@@ -83,7 +83,7 @@ def _register(self):
def _command(self, args: argparse.Namespace):
if args.help is not NOTSET:
if args.help is not None:
- run_configurators_mapping[ConfigurationType(args.help)].register(self._parser)
+ run_configurators_mapping[RunConfigurationType(args.help)].register(self._parser)
else:
BaseRunConfigurator.register(self._parser)
self._parser.print_help()
@@ -102,7 +102,7 @@ def _command(self, args: argparse.Namespace):
apply_profile_args(args, conf)
logger.debug("Configuration loaded: %s", configuration_path)
parser = argparse.ArgumentParser()
- configurator = run_configurators_mapping[ConfigurationType(conf.type)]
+ configurator = run_configurators_mapping[RunConfigurationType(conf.type)]
configurator.register(parser)
known, unknown = parser.parse_known_args(args.unknown)
configurator.apply(known, unknown, conf)
@@ -176,7 +176,7 @@ def _command(self, args: argparse.Namespace):
)
if run.status in (RunStatus.RUNNING, RunStatus.DONE):
- if run._run.run_spec.configuration.type == ConfigurationType.SERVICE.value:
+ if run._run.run_spec.configuration.type == RunConfigurationType.SERVICE.value:
console.print(
f"Service is published at [link={run.service_url}]{run.service_url}[/]\n"
)
diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py
index c1dc1da1b..6d40b7d98 100644
--- a/src/dstack/_internal/cli/main.py
+++ b/src/dstack/_internal/cli/main.py
@@ -3,6 +3,7 @@
from rich.markup import escape
from rich_argparse import RichHelpFormatter
+from dstack._internal.cli.commands.apply import ApplyCommand
from dstack._internal.cli.commands.config import ConfigCommand
from dstack._internal.cli.commands.gateway import GatewayCommand
from dstack._internal.cli.commands.init import InitCommand
@@ -50,6 +51,7 @@ def main():
parser.set_defaults(func=lambda _: parser.print_help())
subparsers = parser.add_subparsers(metavar="COMMAND")
+ ApplyCommand.register(subparsers)
ConfigCommand.register(subparsers)
GatewayCommand.register(subparsers)
PoolCommand.register(subparsers)
diff --git a/src/dstack/_internal/cli/services/configurators/__init__.py b/src/dstack/_internal/cli/services/configurators/__init__.py
index e69de29bb..152e0a3a7 100644
--- a/src/dstack/_internal/cli/services/configurators/__init__.py
+++ b/src/dstack/_internal/cli/services/configurators/__init__.py
@@ -0,0 +1,13 @@
+from typing import Dict, Type
+
+from dstack._internal.cli.services.configurators.base import BaseApplyConfigurator
+from dstack._internal.cli.services.configurators.gateway import GatewayConfigurator
+from dstack._internal.core.models.configurations import ApplyConfigurationType
+
+apply_configurators_mapping: Dict[ApplyConfigurationType, Type[BaseApplyConfigurator]] = {
+ cls.TYPE: cls for cls in [GatewayConfigurator]
+}
+
+
+def get_apply_configurator_class(configurator_type: str) -> Type[BaseApplyConfigurator]:
+ return apply_configurators_mapping[ApplyConfigurationType(configurator_type)]
diff --git a/src/dstack/_internal/cli/services/configurators/base.py b/src/dstack/_internal/cli/services/configurators/base.py
new file mode 100644
index 000000000..1067fa87a
--- /dev/null
+++ b/src/dstack/_internal/cli/services/configurators/base.py
@@ -0,0 +1,28 @@
+import argparse
+from abc import ABC, abstractmethod
+from typing import List
+
+from dstack._internal.core.models.configurations import (
+ AnyApplyConfiguration,
+ ApplyConfigurationType,
+)
+from dstack.api._public import Client
+
+
+class BaseApplyConfigurator(ABC):
+ TYPE: ApplyConfigurationType
+
+ def __init__(self, api_client: Client):
+ self.api_client = api_client
+
+ @abstractmethod
+ def apply_configuration(self, conf: AnyApplyConfiguration, args: argparse.Namespace):
+ pass
+
+ def register_args(self, parser: argparse.ArgumentParser):
+ pass
+
+ def apply_args(
+ self, args: argparse.Namespace, unknown: List[str], conf: AnyApplyConfiguration
+ ):
+ pass
diff --git a/src/dstack/_internal/cli/services/configurators/gateway.py b/src/dstack/_internal/cli/services/configurators/gateway.py
new file mode 100644
index 000000000..e1920b5c7
--- /dev/null
+++ b/src/dstack/_internal/cli/services/configurators/gateway.py
@@ -0,0 +1,58 @@
+import argparse
+
+from dstack._internal.cli.services.configurators.base import BaseApplyConfigurator
+from dstack._internal.cli.utils.common import confirm_ask, console
+from dstack._internal.cli.utils.gateway import print_gateways_table
+from dstack._internal.core.errors import ResourceNotExistsError
+from dstack._internal.core.models.configurations import ApplyConfigurationType
+from dstack._internal.core.models.gateways import GatewayConfiguration
+
+
+class GatewayConfigurator(BaseApplyConfigurator):
+ TYPE: ApplyConfigurationType = ApplyConfigurationType.GATEWAY
+
+ def apply_configuration(self, conf: GatewayConfiguration, args: argparse.Namespace):
+ # TODO: Show apply plan
+ # TODO: Update gateway in-place when domain/default change
+ confirmed = False
+ if conf.name is not None:
+ try:
+ gateway = self.api_client.client.gateways.get(
+ project_name=self.api_client.project, gateway_name=conf.name
+ )
+ except ResourceNotExistsError:
+ pass
+ else:
+ if gateway.configuration == conf:
+ if not args.force:
+ console.print(
+ "Gateway configuration has not changed. Use --force to recreate the gateway."
+ )
+ return
+ if not args.yes and not confirm_ask(
+ "Gateway configuration has not changed. Re-create the gateway?"
+ ):
+ console.print("\nExiting...")
+ return
+ elif not args.yes and not confirm_ask(
+ f"Gateway [code]{conf.name}[/] already exist. Re-create the gateway?"
+ ):
+ console.print("\nExiting...")
+ return
+ confirmed = True
+ with console.status("Deleting gateway..."):
+ self.api_client.client.gateways.delete(
+ project_name=self.api_client.project, gateways_names=[conf.name]
+ )
+ if not confirmed and not args.yes:
+ if not confirm_ask(
+ f"Gateway [code]{conf.name}[/] does not exist yet. Create the gateway?"
+ ):
+ console.print("\nExiting...")
+ return
+ with console.status("Creating gateway..."):
+ gateway = self.api_client.client.gateways.create(
+ project_name=self.api_client.project,
+ configuration=conf,
+ )
+ print_gateways_table([gateway])
diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py
index ba6687aaa..1dc1cd692 100644
--- a/src/dstack/_internal/cli/services/configurators/run.py
+++ b/src/dstack/_internal/cli/services/configurators/run.py
@@ -11,10 +11,10 @@
from dstack._internal.core.models.configurations import (
BaseConfiguration,
BaseConfigurationWithPorts,
- ConfigurationType,
DevEnvironmentConfiguration,
EnvSentinel,
PortMapping,
+ RunConfigurationType,
ServiceConfiguration,
TaskConfiguration,
)
@@ -22,7 +22,7 @@
class BaseRunConfigurator:
- TYPE: ConfigurationType = None
+ TYPE: RunConfigurationType = None
@classmethod
def register(cls, parser: argparse.ArgumentParser):
@@ -102,7 +102,7 @@ def apply(cls, args: argparse.Namespace, unknown: List[str], conf: BaseConfigura
class TaskRunConfigurator(RunWithPortsConfigurator):
- TYPE = ConfigurationType.TASK
+ TYPE = RunConfigurationType.TASK
@classmethod
def apply(cls, args: argparse.Namespace, unknown: List[str], conf: TaskConfiguration):
@@ -112,7 +112,7 @@ def apply(cls, args: argparse.Namespace, unknown: List[str], conf: TaskConfigura
class DevEnvironmentRunConfigurator(RunWithPortsConfigurator):
- TYPE = ConfigurationType.DEV_ENVIRONMENT
+ TYPE = RunConfigurationType.DEV_ENVIRONMENT
@classmethod
def apply(
@@ -130,7 +130,7 @@ def apply(
class ServiceRunConfigurator(BaseRunConfigurator):
- TYPE = ConfigurationType.SERVICE
+ TYPE = RunConfigurationType.SERVICE
@classmethod
def apply(cls, args: argparse.Namespace, unknown: List[str], conf: ServiceConfiguration):
@@ -169,7 +169,7 @@ def _detect_vscode_version(exe: str = "code") -> Optional[str]:
return None
-run_configurators_mapping: Dict[ConfigurationType, Type[BaseRunConfigurator]] = {
+run_configurators_mapping: Dict[RunConfigurationType, Type[BaseRunConfigurator]] = {
cls.TYPE: cls
for cls in [
TaskRunConfigurator,
diff --git a/src/dstack/_internal/cli/services/configurators/profile.py b/src/dstack/_internal/cli/services/profile.py
similarity index 100%
rename from src/dstack/_internal/cli/services/configurators/profile.py
rename to src/dstack/_internal/cli/services/profile.py
diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py
index f2d0da344..3bc3c834b 100644
--- a/src/dstack/_internal/core/backends/aws/compute.py
+++ b/src/dstack/_internal/core/backends/aws/compute.py
@@ -19,6 +19,7 @@
from dstack._internal.core.models.backends.aws import AWSAccessKeyCreds
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import is_core_model_instance
+from dstack._internal.core.models.gateways import GatewayComputeConfiguration
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceConfiguration,
@@ -192,17 +193,14 @@ def run_job(
def create_gateway(
self,
- instance_name: str,
- ssh_key_pub: str,
- region: str,
- project_id: str,
- ) -> JobProvisioningData:
- ec2 = self.session.resource("ec2", region_name=region)
- ec2_client = self.session.client("ec2", region_name=region)
+ configuration: GatewayComputeConfiguration,
+ ) -> LaunchedGatewayInfo:
+ ec2 = self.session.resource("ec2", region_name=configuration.region)
+ ec2_client = self.session.client("ec2", region_name=configuration.region)
tags = [
- {"Key": "Name", "Value": instance_name},
+ {"Key": "Name", "Value": configuration.instance_name},
{"Key": "owner", "Value": "dstack"},
- {"Key": "dstack_project", "Value": project_id},
+ {"Key": "dstack_project", "Value": configuration.project_name},
]
if settings.DSTACK_VERSION is not None:
tags.append({"Key": "dstack_version", "Value": settings.DSTACK_VERSION})
@@ -212,11 +210,11 @@ def create_gateway(
image_id=aws_resources.get_gateway_image_id(ec2_client),
instance_type="t2.micro",
iam_instance_profile_arn=None,
- user_data=get_gateway_user_data(ssh_key_pub),
+ user_data=get_gateway_user_data(configuration.ssh_key_pub),
tags=tags,
security_group_id=aws_resources.create_gateway_security_group(
ec2_client=ec2_client,
- project_id=project_id,
+ project_id=configuration.project_name,
),
spot=False,
)
@@ -226,7 +224,7 @@ def create_gateway(
instance.reload() # populate instance.public_ip_address
return LaunchedGatewayInfo(
instance_id=instance.instance_id,
- region=region,
+ region=configuration.region,
ip_address=instance.public_ip_address,
)
diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py
index 5b7f4b1b3..d53e17320 100644
--- a/src/dstack/_internal/core/backends/azure/compute.py
+++ b/src/dstack/_internal/core/backends/azure/compute.py
@@ -43,6 +43,7 @@
from dstack._internal.core.backends.base.offers import get_catalog_offers
from dstack._internal.core.errors import NoCapacityError
from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.gateways import GatewayComputeConfiguration
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceConfiguration,
@@ -183,35 +184,36 @@ def terminate_instance(
def create_gateway(
self,
- instance_name: str,
- ssh_key_pub: str,
- region: str,
- project_id: str,
+ configuration: GatewayComputeConfiguration,
) -> LaunchedGatewayInfo:
- logger.info("Launching %s gateway instance in %s...", instance_name, region)
+ logger.info(
+ "Launching %s gateway instance in %s...",
+ configuration.instance_name,
+ configuration.region,
+ )
vm = _launch_instance(
compute_client=self._compute_client,
subscription_id=self.config.subscription_id,
- location=region,
+ location=configuration.region,
resource_group=self.config.resource_group,
network_security_group=azure_utils.get_gateway_network_security_group_name(
resource_group=self.config.resource_group,
- location=region,
+ location=configuration.region,
),
network=azure_utils.get_default_network_name(
resource_group=self.config.resource_group,
- location=region,
+ location=configuration.region,
),
subnet=azure_utils.get_default_subnet_name(
resource_group=self.config.resource_group,
- location=region,
+ location=configuration.region,
),
managed_identity=None,
image_reference=_get_gateway_image_ref(),
vm_size="Standard_B1s",
- instance_name=instance_name,
- user_data=get_gateway_user_data(ssh_key_pub),
- ssh_pub_keys=[ssh_key_pub],
+ instance_name=configuration.instance_name,
+ user_data=get_gateway_user_data(configuration.ssh_key_pub),
+ ssh_pub_keys=[configuration.ssh_key_pub],
spot=False,
disk_size=30,
computer_name="gatewayvm",
@@ -225,7 +227,7 @@ def create_gateway(
return LaunchedGatewayInfo(
instance_id=vm.name,
ip_address=public_ip,
- region=region,
+ region=configuration.region,
)
diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py
index 02c2b6186..fc61754da 100644
--- a/src/dstack/_internal/core/backends/base/compute.py
+++ b/src/dstack/_internal/core/backends/base/compute.py
@@ -9,6 +9,7 @@
import yaml
from dstack._internal import settings
+from dstack._internal.core.models.gateways import GatewayComputeConfiguration
from dstack._internal.core.models.instances import (
InstanceConfiguration,
InstanceOfferWithAvailability,
@@ -84,10 +85,7 @@ def update_provisioning_data(
def create_gateway(
self,
- instance_name: str,
- ssh_key_pub: str,
- region: str,
- project_id: str,
+ configuration: GatewayComputeConfiguration,
) -> LaunchedGatewayInfo:
raise NotImplementedError()
diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py
index dfcb798e8..a6cba7149 100644
--- a/src/dstack/_internal/core/backends/gcp/compute.py
+++ b/src/dstack/_internal/core/backends/gcp/compute.py
@@ -17,6 +17,7 @@
from dstack._internal.core.backends.gcp.config import GCPConfig
from dstack._internal.core.errors import ComputeResourceNotFoundError, NoCapacityError
from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.gateways import GatewayComputeConfiguration
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceConfiguration,
@@ -176,10 +177,7 @@ def run_job(
def create_gateway(
self,
- instance_name: str,
- ssh_key_pub: str,
- region: str,
- project_id: str,
+ configuration: GatewayComputeConfiguration,
) -> LaunchedGatewayInfo:
gcp_resources.create_gateway_firewall_rules(
firewalls_client=self.firewalls_client,
@@ -187,7 +185,7 @@ def create_gateway(
)
# e2-micro is available in every zone
for i in self.regions_client.list(project=self.config.project_id):
- if i.name == region:
+ if i.name == configuration.region:
zone = i.zones[0].split("/")[-1]
break
else:
@@ -202,24 +200,24 @@ def create_gateway(
machine_type="e2-micro",
accelerators=[],
spot=False,
- user_data=get_gateway_user_data(ssh_key_pub),
- authorized_keys=[ssh_key_pub],
+ user_data=get_gateway_user_data(configuration.ssh_key_pub),
+ authorized_keys=[configuration.ssh_key_pub],
labels={
"owner": "dstack",
- "dstack_project": project_id,
+ "dstack_project": configuration.project_name,
},
tags=[gcp_resources.DSTACK_GATEWAY_TAG],
- instance_name=instance_name,
+ instance_name=configuration.instance_name,
zone=zone,
service_account=None,
)
operation = self.instances_client.insert(request=request)
gcp_resources.wait_for_extended_operation(operation, "instance creation")
instance = self.instances_client.get(
- project=self.config.project_id, zone=zone, instance=instance_name
+ project=self.config.project_id, zone=zone, instance=configuration.instance_name
)
return LaunchedGatewayInfo(
- instance_id=instance_name,
+ instance_id=configuration.instance_name,
region=zone, # used for instance termination
ip_address=instance.network_interfaces[0].access_configs[0].nat_i_p,
)
diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py
index 9810c1b5a..61d25ff6c 100644
--- a/src/dstack/_internal/core/backends/kubernetes/compute.py
+++ b/src/dstack/_internal/core/backends/kubernetes/compute.py
@@ -4,8 +4,7 @@
import time
from typing import Dict, List, Optional
-# TODO: update import as KNOWN_GPUS becomes public
-from gpuhunt._internal.constraints import KNOWN_GPUS
+from gpuhunt import KNOWN_GPUS
from kubernetes import client
from dstack._internal.core.backends.base.compute import (
@@ -22,6 +21,9 @@
)
from dstack._internal.core.errors import ComputeError, GatewayError
from dstack._internal.core.models.backends.base import BackendType
+
+# TODO: update import as KNOWN_GPUS becomes public
+from dstack._internal.core.models.gateways import GatewayComputeConfiguration
from dstack._internal.core.models.instances import (
Disk,
Gpu,
@@ -205,7 +207,10 @@ def terminate_instance(
if e.status != 404:
raise
- def create_gateway(self, instance_name: str, ssh_key_pub: str, region: str, project_id: str):
+ def create_gateway(
+ self,
+ configuration: GatewayComputeConfiguration,
+ ) -> LaunchedGatewayInfo:
# Gateway creation is currently limited to Kubernetes with Load Balancer support.
# If the cluster does not support Load Balancer, the service will be provisioned but
# the external IP/hostname will never be allocated.
@@ -215,7 +220,8 @@ def create_gateway(self, instance_name: str, ssh_key_pub: str, region: str, proj
# TODO: By default EKS creates a Classic Load Balancer for Load Balancer services.
# Consider deploying an NLB. It seems it requires some extra configuration on the cluster:
# https://docs.aws.amazon.com/eks/latest/userguide/network-load-balancing.html
- commands = _get_gateway_commands(authorized_keys=[ssh_key_pub])
+ instance_name = configuration.instance_name
+ commands = _get_gateway_commands(authorized_keys=[configuration.ssh_key_pub])
self.api.create_namespaced_pod(
namespace=DEFAULT_NAMESPACE,
body=client.V1Pod(
diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py
index 781610017..591f1724e 100644
--- a/src/dstack/_internal/core/models/configurations.py
+++ b/src/dstack/_internal/core/models/configurations.py
@@ -7,7 +7,7 @@
from dstack._internal.core.errors import ConfigurationError
from dstack._internal.core.models.common import CoreModel, Duration
-from dstack._internal.core.models.gateways import AnyModel
+from dstack._internal.core.models.gateways import AnyModel, GatewayConfiguration
from dstack._internal.core.models.profiles import ProfileParams
from dstack._internal.core.models.repos.base import Repo
from dstack._internal.core.models.repos.virtual import VirtualRepo
@@ -17,7 +17,7 @@
ValidPort = conint(gt=0, le=65536)
-class ConfigurationType(str, Enum):
+class RunConfigurationType(str, Enum):
DEV_ENVIRONMENT = "dev-environment"
TASK = "task"
SERVICE = "service"
@@ -319,13 +319,38 @@ class RunConfiguration(CoreModel):
Field(discriminator="type"),
]
- class Config:
- schema_extra = {"$schema": "http://json-schema.org/draft-07/schema#"}
-
-def parse(data: dict) -> AnyRunConfiguration:
+def parse_run_configuration(data: dict) -> AnyRunConfiguration:
try:
conf = RunConfiguration.parse_obj(data).__root__
except ValidationError as e:
raise ConfigurationError(e)
return conf
+
+
+class ApplyConfigurationType(str, Enum):
+ GATEWAY = "gateway"
+
+
+AnyApplyConfiguration = GatewayConfiguration
+
+
+def parse_apply_configuration(data: dict) -> AnyApplyConfiguration:
+ try:
+ conf = GatewayConfiguration.parse_obj(data)
+ except ValidationError as e:
+ raise ConfigurationError(e)
+ return conf
+
+
+AnyDstackConfiguration = Union[AnyRunConfiguration, GatewayConfiguration]
+
+
+class DstackConfiguration(CoreModel):
+ __root__: Annotated[
+ AnyDstackConfiguration,
+ Field(discriminator="type"),
+ ]
+
+ class Config:
+ schema_extra = {"$schema": "http://json-schema.org/draft-07/schema#"}
diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py
index 3831fa5e1..6df853b08 100644
--- a/src/dstack/_internal/core/models/gateways.py
+++ b/src/dstack/_internal/core/models/gateways.py
@@ -8,7 +8,30 @@
from dstack._internal.core.models.common import CoreModel
+class GatewayConfiguration(CoreModel):
+ type: Literal["gateway"] = "gateway"
+ name: Annotated[Optional[str], Field(description="The gateway name")] = None
+ default: Annotated[bool, Field(description="Make the gateway default")] = False
+ backend: Annotated[BackendType, Field(description="The gateway backend")]
+ region: Annotated[str, Field(description="The gateway region")]
+ domain: Annotated[
+ Optional[str], Field(description="The gateway domain, e.g. `*.example.com`")
+ ] = None
+ # public_ip: Annotated[bool, Field(description="Allocate public IP for the gateway")] = True
+
+
+class GatewayComputeConfiguration(CoreModel):
+ project_name: str
+ instance_name: str
+ backend: BackendType
+ region: str
+ public_ip: bool
+ ssh_key_pub: str
+
+
class Gateway(CoreModel):
+ # TODO: configuration fields are duplicated on top-level for backward compatibility with 0.18.x
+ # Remove in 0.19
name: str
ip_address: Optional[str]
instance_id: Optional[str]
@@ -17,6 +40,7 @@ class Gateway(CoreModel):
default: bool
created_at: datetime.datetime
backend: BackendType
+ configuration: GatewayConfiguration
class BaseChatModel(CoreModel):
diff --git a/src/dstack/_internal/server/migrations/versions/58aa5162dcc3_add_gatewaymodel_configuration.py b/src/dstack/_internal/server/migrations/versions/58aa5162dcc3_add_gatewaymodel_configuration.py
new file mode 100644
index 000000000..5a25adaee
--- /dev/null
+++ b/src/dstack/_internal/server/migrations/versions/58aa5162dcc3_add_gatewaymodel_configuration.py
@@ -0,0 +1,32 @@
+"""Add GatewayModel.configuration
+
+Revision ID: 58aa5162dcc3
+Revises: 1e3fb39ef74b
+Create Date: 2024-05-15 11:04:58.848554
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "58aa5162dcc3"
+down_revision = "1e3fb39ef74b"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table("gateways", schema=None) as batch_op:
+ batch_op.add_column(sa.Column("configuration", sa.Text(), nullable=True))
+
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table("gateways", schema=None) as batch_op:
+ batch_op.drop_column("configuration")
+
+ # ### end Alembic commands ###
diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py
index 1615f4a99..57ee37ff3 100644
--- a/src/dstack/_internal/server/models.py
+++ b/src/dstack/_internal/server/models.py
@@ -230,6 +230,7 @@ class GatewayModel(BaseModel):
name: Mapped[str] = mapped_column(String(100))
region: Mapped[str] = mapped_column(String(100))
wildcard_domain: Mapped[str] = mapped_column(String(100), nullable=True)
+ configuration: Mapped[Optional[str]] = mapped_column(Text)
created_at: Mapped[datetime] = mapped_column(DateTime, default=get_current_datetime)
project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE"))
diff --git a/src/dstack/_internal/server/routers/gateways.py b/src/dstack/_internal/server/routers/gateways.py
index 6de22aa6a..4215fa90a 100644
--- a/src/dstack/_internal/server/routers/gateways.py
+++ b/src/dstack/_internal/server/routers/gateways.py
@@ -46,9 +46,7 @@ async def create_gateway(
return await gateways.create_gateway(
session=session,
project=project,
- name=body.name,
- backend_type=body.backend_type,
- region=body.region,
+ configuration=body.configuration,
)
diff --git a/src/dstack/_internal/server/schemas/gateways.py b/src/dstack/_internal/server/schemas/gateways.py
index 4183b862f..99930b95e 100644
--- a/src/dstack/_internal/server/schemas/gateways.py
+++ b/src/dstack/_internal/server/schemas/gateways.py
@@ -1,13 +1,34 @@
-from typing import List, Optional
+from typing import Dict, List, Optional
+
+from pydantic import root_validator
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel
+from dstack._internal.core.models.gateways import GatewayConfiguration
class CreateGatewayRequest(CoreModel):
name: Optional[str]
- backend_type: BackendType
- region: str
+ backend_type: Optional[BackendType]
+ region: Optional[str]
+ configuration: Optional[GatewayConfiguration]
+
+ @root_validator
+ def fill_configuration(cls, values: Dict) -> Dict:
+ if values.get("configuration", None) is not None:
+ return values
+ backend_type = values.get("backend_type", None)
+ region = values.get("region", None)
+ if backend_type is None:
+ raise ValueError("backend_type must be specified")
+ if region is None:
+ raise ValueError("region must be specified")
+ values["configuration"] = GatewayConfiguration(
+ name=values.get("name", None),
+ backend=backend_type,
+ region=region,
+ )
+ return values
class GetGatewayRequest(CoreModel):
diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py
index 78cb3b76d..ff960c7f3 100644
--- a/src/dstack/_internal/server/services/gateways/__init__.py
+++ b/src/dstack/_internal/server/services/gateways/__init__.py
@@ -23,7 +23,11 @@
SSHError,
)
from dstack._internal.core.models.backends.base import BackendType
-from dstack._internal.core.models.gateways import Gateway
+from dstack._internal.core.models.gateways import (
+ Gateway,
+ GatewayComputeConfiguration,
+ GatewayConfiguration,
+)
from dstack._internal.core.models.runs import (
Run,
RunSpec,
@@ -84,9 +88,7 @@ async def get_project_default_gateway(
async def create_gateway_compute(
backend_compute: Compute,
- instance_name: str,
- region: str,
- project_id: str,
+ configuration: GatewayComputeConfiguration,
backend_id: Optional[uuid.UUID] = None,
) -> GatewayComputeModel:
private_bytes, public_bytes = generate_rsa_key_pair_bytes()
@@ -95,10 +97,7 @@ async def create_gateway_compute(
info = await run_async(
backend_compute.create_gateway,
- instance_name,
- gateway_ssh_public_key,
- region,
- project_id,
+ configuration,
)
return GatewayComputeModel(
@@ -114,38 +113,45 @@ async def create_gateway_compute(
async def create_gateway(
session: AsyncSession,
project: ProjectModel,
- name: Optional[str],
- backend_type: BackendType,
- region: str,
+ configuration: GatewayConfiguration,
) -> Gateway:
# TODO: Gateay creation may take significant time. Make it asynchronous.
for backend_model, backend in await get_project_backends_with_models(project):
- if backend_model.type == backend_type:
+ if backend_model.type == configuration.backend:
break
else:
raise ResourceNotExistsError()
- if name is None:
- name = await generate_gateway_name(session=session, project=project)
+ if configuration.name is None:
+ configuration.name = await generate_gateway_name(session=session, project=project)
gateway = GatewayModel( # reserve name
- name=name,
- region=region,
+ name=configuration.name,
+ region=configuration.region,
project_id=project.id,
backend_id=backend_model.id,
+ wildcard_domain=configuration.domain,
+ configuration=configuration.json(),
)
session.add(gateway)
await session.commit()
- if project.default_gateway is None:
- await set_default_gateway(session=session, project=project, name=name)
+ if project.default_gateway is None or configuration.default:
+ await set_default_gateway(session=session, project=project, name=configuration.name)
+
+ compute_configuration = GatewayComputeConfiguration(
+ project_name=project.name,
+ instance_name=gateway.name,
+ backend=configuration.backend,
+ region=configuration.region,
+ public_ip=True,
+ ssh_key_pub=project.name,
+ )
try:
gateway.gateway_compute = await create_gateway_compute(
backend_compute=backend.compute(),
- instance_name=gateway.name,
- region=region,
- project_id=project.name,
+ configuration=compute_configuration,
backend_id=backend_model.id,
)
session.add(gateway)
@@ -155,7 +161,7 @@ async def create_gateway(
await session.execute(
delete(GatewayModel).where(
GatewayModel.project_id == project.id,
- GatewayModel.name == name,
+ GatewayModel.name == configuration.name,
)
)
await session.commit()
@@ -545,6 +551,18 @@ def gateway_model_to_gateway(gateway_model: GatewayModel) -> Gateway:
backend_type = gateway_model.backend.type
if gateway_model.backend.type == BackendType.DSTACK:
backend_type = BackendType.AWS
+ if gateway_model.configuration is not None:
+ configuration = GatewayConfiguration.__response__.parse_raw(gateway_model.configuration)
+ else:
+ # Handle gateways created before GatewayConfiguration was introduced
+ configuration = GatewayConfiguration(
+ name=gateway_model.name,
+ default=False,
+ backend=gateway_model.backend.type,
+ region=gateway_model.region,
+ domain=gateway_model.wildcard_domain,
+ )
+ configuration.default = gateway_model.project.default_gateway_id == gateway_model.id
return Gateway(
name=gateway_model.name,
ip_address=ip_address,
@@ -554,4 +572,5 @@ def gateway_model_to_gateway(gateway_model: GatewayModel) -> Gateway:
default=gateway_model.project.default_gateway_id == gateway_model.id,
created_at=gateway_model.created_at.replace(tzinfo=timezone.utc),
backend=backend_type,
+ configuration=configuration,
)
diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py
index 6c23695b0..9833903b3 100644
--- a/src/dstack/_internal/server/services/jobs/__init__.py
+++ b/src/dstack/_internal/server/services/jobs/__init__.py
@@ -12,7 +12,7 @@
import dstack._internal.server.services.gateways as gateways
from dstack._internal.core.errors import ComputeResourceNotFoundError, SSHError
from dstack._internal.core.models.backends.base import BackendType
-from dstack._internal.core.models.configurations import ConfigurationType
+from dstack._internal.core.models.configurations import RunConfigurationType
from dstack._internal.core.models.instances import RemoteConnectionInfo
from dstack._internal.core.models.runs import (
InstanceStatus,
@@ -133,7 +133,7 @@ def delay_job_instance_termination(job_model: JobModel):
def _get_job_configurator(run_spec: RunSpec) -> JobConfigurator:
- configuration_type = ConfigurationType(run_spec.configuration.type)
+ configuration_type = RunConfigurationType(run_spec.configuration.type)
configurator_class = _configuration_type_to_configurator_class_map[configuration_type]
return configurator_class(run_spec)
diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py
index 7a9625bd0..ddfdb3011 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/base.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/base.py
@@ -8,10 +8,10 @@
import dstack.version as version
from dstack._internal.core.errors import DockerRegistryError, ServerClientError
from dstack._internal.core.models.configurations import (
- ConfigurationType,
PortMapping,
PythonVersion,
RegistryAuth,
+ RunConfigurationType,
)
from dstack._internal.core.models.profiles import SpotPolicy
from dstack._internal.core.models.runs import (
@@ -44,7 +44,7 @@ def get_default_image(python_version: str) -> str:
class JobConfigurator(ABC):
- TYPE: ConfigurationType
+ TYPE: RunConfigurationType
def __init__(self, run_spec: RunSpec):
self.run_spec = run_spec
diff --git a/src/dstack/_internal/server/services/jobs/configurators/dev.py b/src/dstack/_internal/server/services/jobs/configurators/dev.py
index 760390bff..2dbb4e21b 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/dev.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/dev.py
@@ -1,6 +1,6 @@
from typing import List, Optional
-from dstack._internal.core.models.configurations import ConfigurationType, PortMapping
+from dstack._internal.core.models.configurations import PortMapping, RunConfigurationType
from dstack._internal.core.models.profiles import ProfileRetryPolicy, SpotPolicy
from dstack._internal.core.models.runs import RetryPolicy, RunSpec
from dstack._internal.server.services.jobs.configurators.base import JobConfigurator
@@ -15,7 +15,7 @@
class DevEnvironmentJobConfigurator(JobConfigurator):
- TYPE: ConfigurationType = ConfigurationType.DEV_ENVIRONMENT
+ TYPE: RunConfigurationType = RunConfigurationType.DEV_ENVIRONMENT
def __init__(self, run_spec: RunSpec):
self.ide = VSCodeDesktop(
diff --git a/src/dstack/_internal/server/services/jobs/configurators/service.py b/src/dstack/_internal/server/services/jobs/configurators/service.py
index ed783fa58..718caabe0 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/service.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/service.py
@@ -1,13 +1,13 @@
from typing import List, Optional
-from dstack._internal.core.models.configurations import ConfigurationType, PortMapping
+from dstack._internal.core.models.configurations import PortMapping, RunConfigurationType
from dstack._internal.core.models.profiles import ProfileRetryPolicy, SpotPolicy
from dstack._internal.core.models.runs import RetryPolicy
from dstack._internal.server.services.jobs.configurators.base import JobConfigurator
class ServiceJobConfigurator(JobConfigurator):
- TYPE: ConfigurationType = ConfigurationType.SERVICE
+ TYPE: RunConfigurationType = RunConfigurationType.SERVICE
def _shell_commands(self) -> List[str]:
return self.run_spec.configuration.commands
diff --git a/src/dstack/_internal/server/services/jobs/configurators/task.py b/src/dstack/_internal/server/services/jobs/configurators/task.py
index d1eb2d97e..de8fad9be 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/task.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/task.py
@@ -1,6 +1,6 @@
from typing import List, Optional
-from dstack._internal.core.models.configurations import ConfigurationType, PortMapping
+from dstack._internal.core.models.configurations import PortMapping, RunConfigurationType
from dstack._internal.core.models.profiles import ProfileRetryPolicy, SpotPolicy
from dstack._internal.core.models.runs import JobSpec, RetryPolicy
from dstack._internal.server.services.jobs.configurators.base import JobConfigurator
@@ -9,7 +9,7 @@
class TaskJobConfigurator(JobConfigurator):
- TYPE: ConfigurationType = ConfigurationType.TASK
+ TYPE: RunConfigurationType = RunConfigurationType.TASK
async def get_job_specs(self, replica_num: int) -> List[JobSpec]:
job_specs = []
diff --git a/src/dstack/api/server/_gateways.py b/src/dstack/api/server/_gateways.py
index 3d77f208d..f2800c2fe 100644
--- a/src/dstack/api/server/_gateways.py
+++ b/src/dstack/api/server/_gateways.py
@@ -3,7 +3,7 @@
from pydantic import parse_obj_as
from dstack._internal.core.models.backends.base import BackendType
-from dstack._internal.core.models.gateways import Gateway
+from dstack._internal.core.models.gateways import Gateway, GatewayConfiguration
from dstack._internal.server.schemas.gateways import (
CreateGatewayRequest,
DeleteGatewaysRequest,
@@ -24,14 +24,22 @@ def get(self, project_name: str, gateway_name: str) -> Gateway:
resp = self._request(f"/api/project/{project_name}/gateways/get", body=body.json())
return parse_obj_as(Gateway.__response__, resp.json())
+ # gateway_name, backend_type, region are left for backward-compatibility with 0.18.x
+ # TODO: Remove in 0.19
def create(
self,
project_name: str,
- gateway_name: Optional[str],
- backend_type: BackendType,
- region: str,
+ gateway_name: Optional[str] = None,
+ backend_type: Optional[BackendType] = None,
+ region: Optional[str] = None,
+ configuration: Optional[GatewayConfiguration] = None,
) -> Gateway:
- body = CreateGatewayRequest(name=gateway_name, backend_type=backend_type, region=region)
+ body = CreateGatewayRequest(
+ name=gateway_name,
+ backend_type=backend_type,
+ region=region,
+ configuration=configuration,
+ )
resp = self._request(f"/api/project/{project_name}/gateways/create", body=body.json())
return parse_obj_as(Gateway.__response__, resp.json())
diff --git a/src/dstack/api/utils.py b/src/dstack/api/utils.py
index 19e074131..9471ce74b 100644
--- a/src/dstack/api/utils.py
+++ b/src/dstack/api/utils.py
@@ -5,7 +5,9 @@
from dstack._internal.core.errors import ConfigurationError
from dstack._internal.core.models.configurations import AnyRunConfiguration
-from dstack._internal.core.models.configurations import parse as parse_configuration
+from dstack._internal.core.models.configurations import (
+ parse_run_configuration as parse_configuration,
+)
from dstack._internal.core.models.profiles import Profile, ProfilesConfig
from dstack._internal.utils.common import get_dstack_dir
from dstack._internal.utils.path import PathLike, path_in_dir
diff --git a/src/tests/_internal/cli/services/configurators/test_profile.py b/src/tests/_internal/cli/services/configurators/test_profile.py
index ead9568d7..d9a047bdf 100644
--- a/src/tests/_internal/cli/services/configurators/test_profile.py
+++ b/src/tests/_internal/cli/services/configurators/test_profile.py
@@ -1,7 +1,7 @@
import argparse
from typing import List, Tuple
-from dstack._internal.cli.services.configurators.profile import (
+from dstack._internal.cli.services.profile import (
apply_profile_args,
register_profile_args,
)
diff --git a/src/tests/_internal/core/models/test_configurations.py b/src/tests/_internal/core/models/test_configurations.py
index 92601e356..90dfdc11e 100644
--- a/src/tests/_internal/core/models/test_configurations.py
+++ b/src/tests/_internal/core/models/test_configurations.py
@@ -3,7 +3,7 @@
import pytest
from dstack._internal.core.errors import ConfigurationError
-from dstack._internal.core.models.configurations import RegistryAuth, parse
+from dstack._internal.core.models.configurations import RegistryAuth, parse_run_configuration
from dstack._internal.core.models.resources import Range
@@ -20,15 +20,15 @@ def test_conf(replicas: Any, scaling: Optional[Any] = None):
conf["scaling"] = scaling
return conf
- assert parse(test_conf(1)).replicas == Range(min=1, max=1)
- assert parse(test_conf("2")).replicas == Range(min=2, max=2)
- assert parse(test_conf("3..3")).replicas == Range(min=3, max=3)
+ assert parse_run_configuration(test_conf(1)).replicas == Range(min=1, max=1)
+ assert parse_run_configuration(test_conf("2")).replicas == Range(min=2, max=2)
+ assert parse_run_configuration(test_conf("3..3")).replicas == Range(min=3, max=3)
with pytest.raises(
ConfigurationError,
match="When you set `replicas` to a range, ensure to specify `scaling`",
):
- parse(test_conf("0..10"))
- assert parse(
+ parse_run_configuration(test_conf("0..10"))
+ assert parse_run_configuration(
test_conf(
"0..10",
{
@@ -41,7 +41,7 @@ def test_conf(replicas: Any, scaling: Optional[Any] = None):
ConfigurationError,
match="When you set `replicas` to a range, ensure to specify `scaling`",
):
- parse(
+ parse_run_configuration(
test_conf(
"0..10",
{
diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py
index 7fe394da3..6a40947ab 100644
--- a/src/tests/_internal/server/routers/test_gateways.py
+++ b/src/tests/_internal/server/routers/test_gateways.py
@@ -73,6 +73,14 @@ async def test_list(self, test_db, session: AsyncSession):
"name": gateway.name,
"region": gateway.region,
"wildcard_domain": gateway.wildcard_domain,
+ "configuration": {
+ "type": "gateway",
+ "name": gateway.name,
+ "backend": backend.type.value,
+ "region": gateway.region,
+ "domain": gateway.wildcard_domain,
+ "default": False,
+ },
}
]
@@ -109,6 +117,14 @@ async def test_get(self, test_db, session: AsyncSession):
"name": gateway.name,
"region": gateway.region,
"wildcard_domain": gateway.wildcard_domain,
+ "configuration": {
+ "type": "gateway",
+ "name": gateway.name,
+ "backend": backend.type.value,
+ "region": gateway.region,
+ "domain": gateway.wildcard_domain,
+ "default": False,
+ },
}
@pytest.mark.asyncio
@@ -180,6 +196,14 @@ async def test_create_gateway(self, test_db, session: AsyncSession):
"wildcard_domain": None,
"default": True,
"created_at": response.json()["created_at"],
+ "configuration": {
+ "type": "gateway",
+ "name": "test",
+ "backend": backend.type.value,
+ "region": "us",
+ "domain": None,
+ "default": True,
+ },
}
@pytest.mark.asyncio
@@ -226,6 +250,14 @@ async def test_create_gateway_without_name(self, test_db, session: AsyncSession)
"wildcard_domain": None,
"default": True,
"created_at": response.json()["created_at"],
+ "configuration": {
+ "type": "gateway",
+ "name": "random-name",
+ "backend": backend.type.value,
+ "region": "us",
+ "domain": None,
+ "default": True,
+ },
}
@pytest.mark.asyncio
@@ -352,6 +384,14 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession):
"name": gateway.name,
"region": gateway.region,
"wildcard_domain": gateway.wildcard_domain,
+ "configuration": {
+ "type": "gateway",
+ "name": gateway.name,
+ "backend": backend.type.value,
+ "region": gateway.region,
+ "domain": gateway.wildcard_domain,
+ "default": True,
+ },
}
@pytest.mark.asyncio
@@ -451,6 +491,14 @@ def get_backend(_, backend_type):
"name": gateway_gcp.name,
"region": gateway_gcp.region,
"wildcard_domain": gateway_gcp.wildcard_domain,
+ "configuration": {
+ "type": "gateway",
+ "name": gateway_gcp.name,
+ "backend": backend_gcp.type.value,
+ "region": gateway_gcp.region,
+ "domain": gateway_gcp.wildcard_domain,
+ "default": False,
+ },
}
]
@@ -502,6 +550,14 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession):
"name": gateway.name,
"region": gateway.region,
"wildcard_domain": "test.com",
+ "configuration": {
+ "type": "gateway",
+ "name": gateway.name,
+ "backend": backend.type.value,
+ "region": gateway.region,
+ "domain": "test.com",
+ "default": False,
+ },
}
@pytest.mark.asyncio