Skip to content

Commit

Permalink
Support concurrent usage of fake pager and poller APIs (#995)
Browse files Browse the repository at this point in the history
* Support concurrent usage of fake pager and poller APIs

Create a pager/poller state machine per combination of HTTP verb and URL
path so that API calls for different resources don't overlap.
Bump version number in preparation for release.

* Remove fake poller status path from URL's path

Don't use the HTTP verb when stashing the state machine as it can change
for an LRO depending on the operation.
  • Loading branch information
jhendrixMSFT authored Jul 18, 2023
1 parent 8ca37a3 commit d98f4ae
Show file tree
Hide file tree
Showing 184 changed files with 8,334 additions and 4,676 deletions.
2 changes: 1 addition & 1 deletion packages/autorest.go/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@autorest/go",
"version": "4.0.0-preview.52",
"version": "4.0.0-preview.53",
"description": "AutoRest Go Generator",
"main": "dist/exports.js",
"typings": "dist/exports.d.ts",
Expand Down
79 changes: 72 additions & 7 deletions packages/autorest.go/src/generator/fake/internal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,44 @@

import { Session } from '@autorest/extension-base';
import { CodeModel } from '@autorest/codemodel';
import { values } from '@azure-tools/linq';
import { contentPreamble } from '../helpers';
import { ImportManager } from '../imports';
import { isLROOperation, isPageableOperation } from '../../common/helpers';

export async function generateServerInternal(session: Session<CodeModel>): Promise<string> {
if (session.model.operationGroups.length === 0) {
return '';
}
const text = await contentPreamble(session, 'fake');
return text + content;
const imports = new ImportManager();
imports.add('io');
imports.add('net/http');
imports.add('reflect');
let body = content;
// only generate the tracker content if required
let needsTracker = false;
for (const group of values(session.model.operationGroups)) {
for (const op of values(group.operations)) {
if (isLROOperation(op) || isPageableOperation(op)) {
needsTracker = true;
break;
}
}
if (needsTracker) {
break;
}
}
if (needsTracker) {
imports.add('regexp');
imports.add('strings');
imports.add('sync');
body += tracker;
}
return text + imports.text() + body;
}

const content = `
import (
"io"
"net/http"
"reflect"
)
type nonRetriableError struct {
error
}
Expand Down Expand Up @@ -85,3 +106,47 @@ func contains[T comparable](s []T, v T) bool {
return false
}
`;

const tracker = `
func newTracker[T any]() *tracker[T] {
return &tracker[T]{
items: map[string]*T{},
}
}
type tracker[T any] struct {
items map[string]*T
mu sync.Mutex
}
func (p *tracker[T]) key(req *http.Request) string {
path := req.URL.Path
if match, _ := regexp.Match(\`/page_\\d+$\`, []byte(path)); match {
path = path[:strings.LastIndex(path, "/")]
} else if strings.HasSuffix(path, "/get/fake/status") {
path = path[:len(path)-16]
}
return path
}
func (p *tracker[T]) get(req *http.Request) *T {
p.mu.Lock()
defer p.mu.Unlock()
if item, ok := p.items[p.key(req)]; ok {
return item
}
return nil
}
func (p *tracker[T]) add(req *http.Request, item *T) {
p.mu.Lock()
defer p.mu.Unlock()
p.items[p.key(req)] = item
}
func (p *tracker[T]) remove(req *http.Request) {
p.mu.Lock()
defer p.mu.Unlock()
delete(p.items, p.key(req))
}
`;
61 changes: 46 additions & 15 deletions packages/autorest.go/src/generator/fake/servers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ export async function generateServers(session: Session<CodeModel>): Promise<Arra

// we might remove some operations from the list
const finalOperations = new Array<Operation>();
let countLROs = 0;
let countPagers = 0;

for (const op of values(group.operations)) {
if (isLROOperation(op)) {
Expand Down Expand Up @@ -90,6 +92,11 @@ export async function generateServers(session: Session<CodeModel>): Promise<Arra
}
content += `\t${operationName} func(${getAPIParametersSig(op, imports, clientPkg)}) (${op.language.go!.serverResponse})\n\n`;
finalOperations.push(op);
if (isLROOperation(op)) {
++countLROs;
} else if (isPageableOperation(op)) {
++countPagers;
}
}

content += '}\n\n';
Expand All @@ -102,22 +109,38 @@ export async function generateServers(session: Session<CodeModel>): Promise<Arra
content += `// The returned ${serverTransport} instance is connected to an instance of ${clientPkg}.${group.language.go!.clientName} via the\n`;
content += '// azcore.ClientOptions.Transporter field in the client\'s constructor parameters.\n';
content += `func New${serverTransport}(srv *${serverName}) *${serverTransport} {\n`;
content += `\treturn &${serverTransport}{srv: srv}\n}\n\n`;
if (countLROs === 0 && countPagers === 0) {
content += `\treturn &${serverTransport}{srv: srv}\n}\n\n`;
} else {
content += `\treturn &${serverTransport}{\n\t\tsrv: srv,\n`;
for (const op of values(finalOperations)) {
let respType = `${clientPkg}.${getResponseEnvelopeName(op)}`;
if (isLROOperation(op)) {
if (isPageableOperation(op)) {
respType = `azfake.PagerResponder[${clientPkg}.${getResponseEnvelopeName(op)}]`;
}
content += `\t\t${uncapitalize(fixUpOperationName(op))}: newTracker[azfake.PollerResponder[${respType}]](),\n`;
} else if (isPageableOperation(op)) {
content += `\t\t${uncapitalize(fixUpOperationName(op))}: newTracker[azfake.PagerResponder[${respType}]](),\n`;
}
}
content += '\t}\n}\n\n';
}

content += `// ${serverTransport} connects instances of ${clientPkg}.${group.language.go!.clientName} to instances of ${serverName}.\n`;
content += `// Don't use this type directly, use New${serverTransport} instead.\n`;
content += `type ${serverTransport} struct {\n`;
content += `\tsrv *${serverName}\n`;
for (const op of values(finalOperations)) {
// create state machines for any pager/poller operations
let respType = `${clientPkg}.${getResponseEnvelopeName(op)}`;
if (isLROOperation(op)) {
let respType = `${clientPkg}.${getResponseEnvelopeName(op)}`;
if (isPageableOperation(op)) {
respType = `azfake.PagerResponder[${clientPkg}.${getResponseEnvelopeName(op)}]`;
}
content +=`\t${uncapitalize(fixUpOperationName(op))} *azfake.PollerResponder[${respType}]\n`;
content +=`\t${uncapitalize(fixUpOperationName(op))} *tracker[azfake.PollerResponder[${respType}]]\n`;
} else if (isPageableOperation(op)) {
content += `\t${uncapitalize(fixUpOperationName(op))} *azfake.PagerResponder[${clientPkg}.${getResponseEnvelopeName(op)}]\n`;
content += `\t${uncapitalize(fixUpOperationName(op))} *tracker[azfake.PagerResponder[${respType}]]\n`;
}
}
content += '}\n\n';
Expand Down Expand Up @@ -412,47 +435,55 @@ function dispatchForOperationBody(clientPkg: string, receiverName: string, op: O

function dispatchForLROBody(clientPkg: string, receiverName: string, op: Operation, imports: ImportManager): string {
const operationName = fixUpOperationName(op);
const localVarName = uncapitalize(operationName);
const operationStateMachine = `${receiverName}.${uncapitalize(operationName)}`;
let content = `\tif ${operationStateMachine} == nil {\n`;
let content = `\t${localVarName} := ${operationStateMachine}.get(req)\n`;
content += `\tif ${localVarName} == nil {\n`;
content += dispatchForOperationBody(clientPkg, receiverName, op, imports);
content += `\t\t${operationStateMachine} = &respr\n`;
content += `\t\t${localVarName} = &respr\n`;
content += `\t\t${operationStateMachine}.add(req, ${localVarName})\n`;
content += '\t}\n\n';

content += `\tresp, err := server.PollerResponderNext(${operationStateMachine}, req)\n`;
content += `\tresp, err := server.PollerResponderNext(${localVarName}, req)\n`;
content += '\tif err != nil {\n\t\treturn nil, err\n\t}\n\n';

const formattedStatusCodes = formatStatusCodes(getStatusCodes(op));
content += `\tif !contains([]int{${formattedStatusCodes}}, resp.StatusCode) {\n`;
content += `\t\t${operationStateMachine}.remove(req)\n`;
content += `\t\treturn nil, &nonRetriableError{fmt.Errorf("unexpected status code %d. acceptable values are ${formattedStatusCodes}", resp.StatusCode)}\n\t}\n`;

content += `\tif !server.PollerResponderMore(${operationStateMachine}) {\n`;
content += `\t\t${operationStateMachine} = nil\n\t}\n\n`;
content += `\tif !server.PollerResponderMore(${localVarName}) {\n`;
content += `\t\t${operationStateMachine}.remove(req)\n\t}\n\n`;
content += '\treturn resp, nil\n';
return content;
}

function dispatchForPagerBody(clientPkg: string, receiverName: string, op: Operation, imports: ImportManager): string {
const operationName = fixUpOperationName(op);
const localVarName = uncapitalize(operationName);
const operationStateMachine = `${receiverName}.${uncapitalize(operationName)}`;
let content = `\tif ${operationStateMachine} == nil {\n`;
let content = `\t${localVarName} := ${operationStateMachine}.get(req)\n`;
content += `\tif ${localVarName} == nil {\n`;
content += dispatchForOperationBody(clientPkg, receiverName, op, imports);
content += `\t\t${operationStateMachine} = &resp\n`;
content += `\t\t${localVarName} = &resp\n`;
content += `\t\t${operationStateMachine}.add(req, ${localVarName})\n`;
if (op.language.go!.paging.nextLinkName) {
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/to');
content += `\t\tserver.PagerResponderInjectNextLinks(${operationStateMachine}, req, func(page *${clientPkg}.${getResponseEnvelopeName(op)}, createLink func() string) {\n`;
content += `\t\tserver.PagerResponderInjectNextLinks(${localVarName}, req, func(page *${clientPkg}.${getResponseEnvelopeName(op)}, createLink func() string) {\n`;
content += `\t\t\tpage.${op.language.go!.paging.nextLinkName} = to.Ptr(createLink())\n`;
content += '\t\t})\n';
}
content += '\t}\n'; // end if
content += `\tresp, err := server.PagerResponderNext(${operationStateMachine}, req)\n`;
content += `\tresp, err := server.PagerResponderNext(${localVarName}, req)\n`;
content += '\tif err != nil {\n\t\treturn nil, err\n\t}\n';

const formattedStatusCodes = formatStatusCodes(getStatusCodes(op));
content += `\tif !contains([]int{${formattedStatusCodes}}, resp.StatusCode) {\n`;
content += `\t\t${operationStateMachine}.remove(req)\n`;
content += `\t\treturn nil, &nonRetriableError{fmt.Errorf("unexpected status code %d. acceptable values are ${formattedStatusCodes}", resp.StatusCode)}\n\t}\n`;

content += `\tif !server.PagerResponderMore(${operationStateMachine}) {\n`;
content += `\t\t${operationStateMachine} = nil\n\t}\n`;
content += `\tif !server.PagerResponderMore(${localVarName}) {\n`;
content += `\t\t${operationStateMachine}.remove(req)\n\t}\n`;
content += '\treturn resp, nil\n';
return content;
}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit d98f4ae

Please sign in to comment.