Skip to content

Commit

Permalink
Merge pull request #237 from necusjz/support-final-state-schema
Browse files Browse the repository at this point in the history
support lro final-state-schema
  • Loading branch information
kairu-ms authored Apr 19, 2023
2 parents 29340d3 + 6a30466 commit 93201ac
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/aaz_dev/swagger/model/schema/cmd_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,16 @@

from swagger.utils import exceptions
from .fields import MutabilityEnum
from .response import Response
from .schema import ReferenceSchema
from .x_ms_pageable import XmsPageable
from functools import reduce
from utils.case import to_camel_case
import logging
import re

logger = logging.getLogger("backend")


class CMDBuilder:

Expand Down Expand Up @@ -549,6 +554,17 @@ def classify_responses(schema):
# append 204 No Content response at the end of success response
success_responses.append(success_204_response)

success_codes = reduce(lambda x, y: x | y, [codes for codes, _ in success_responses])
if schema.x_ms_long_running_operation and not success_codes & {200, 201}:
if lro_schema := schema.x_ms_lro_final_state_schema:
lro_response = Response()
lro_response.description = "Response schema for long-running operation."
lro_response.schema = lro_schema

success_responses.append(({200, 201}, lro_response)) # use `final-state-schema` as response
else:
logger.warning(f"No response schema for long-running-operation: {schema.operation_id}.")

# # default response
# if 'default' not in error_responses and len(error_responses) == 1:
# p_resp, p_model = [*error_responses.values()][0]
Expand Down
12 changes: 12 additions & 0 deletions src/aaz_dev/swagger/model/schema/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
FormDataParameter, ParameterBase
from .reference import Reference, Linkable
from .response import Response
from .schema import ReferenceSchema
from .x_ms_long_running_operation import XmsLongRunningOperationField, XmsLongRunningOperationOptionsField
from .x_ms_odata import XmsODataField
from .x_ms_pageable import XmsPageableField
Expand Down Expand Up @@ -60,6 +61,7 @@ class Operation(Model, Linkable):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.x_ms_odata_instance = None
self.x_ms_lro_final_state_schema = None

def link(self, swagger_loader, *traces):
if self.is_linked():
Expand Down Expand Up @@ -119,6 +121,16 @@ def link(self, swagger_loader, *traces):
if isinstance(self.x_ms_odata_instance, Linkable):
self.x_ms_odata_instance.link(swagger_loader, *instance_traces)

if self.x_ms_long_running_operation_options is not None and \
self.x_ms_long_running_operation_options.final_state_schema is not None:
# `final-state-schema` to `$ref`
self.x_ms_lro_final_state_schema = ReferenceSchema()
self.x_ms_lro_final_state_schema.ref = self.x_ms_long_running_operation_options.final_state_schema
self.x_ms_lro_final_state_schema.link(
swagger_loader,
*self.traces, "x_ms_long_running_operation_options", "final_state_schema"
)

def to_cmd(self, builder, parent_parameters, **kwargs):
cmd_op = CMDHttpOperation()
if self.x_ms_long_running_operation:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ class XmsLongRunningOperationOptions(Model):
deserialize_from='final-state-via',
)

final_state_schema = StringType(
serialized_name="final-state-schema",
deserialize_from="final-state-schema",
)


class XmsLongRunningOperationOptionsField(ModelType):

Expand Down
22 changes: 22 additions & 0 deletions src/aaz_dev/swagger/tests/schema_tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,25 @@ def test_XmsParameterizedHost(self):
raise err
parsed += 1
print(f"Parsed: {parsed}")

def test_lro_final_state_schema(self):
from functools import reduce
from swagger.controller.command_generator import CommandGenerator

rp = next(self.get_mgmt_plane_resource_providers(
module_filter=lambda m: m.name == "dnsresolver",
resource_provider_filter=lambda r: r.name == "Microsoft.Network"
))
version = "2022-07-01"
r_id = "/subscriptions/{}/resourcegroups/{}/providers/microsoft.network/dnsresolvers/{}"

resource_map = rp.get_resource_map()
resource = resource_map[r_id][version]

generator = CommandGenerator()
generator.load_resources([resource])
command_group = generator.create_draft_command_group(resource, methods={"put"}) # only modify PUT operation
responses = command_group.commands[0].operations[0].http.responses
status_codes = reduce(lambda x, y: x + y, [r.status_codes for r in responses])

assert set(status_codes).issuperset({200, 201})

0 comments on commit 93201ac

Please sign in to comment.