Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support concurrent usage of fake pager and poller APIs #995

Merged
merged 2 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/autorest.go/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@autorest/go",
"version": "4.0.0-preview.52",
"version": "4.0.0-preview.53",
"description": "AutoRest Go Generator",
"main": "dist/exports.js",
"typings": "dist/exports.d.ts",
Expand Down
79 changes: 72 additions & 7 deletions packages/autorest.go/src/generator/fake/internal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,44 @@

import { Session } from '@autorest/extension-base';
import { CodeModel } from '@autorest/codemodel';
import { values } from '@azure-tools/linq';
import { contentPreamble } from '../helpers';
import { ImportManager } from '../imports';
import { isLROOperation, isPageableOperation } from '../../common/helpers';

export async function generateServerInternal(session: Session<CodeModel>): Promise<string> {
if (session.model.operationGroups.length === 0) {
return '';
}
const text = await contentPreamble(session, 'fake');
return text + content;
const imports = new ImportManager();
imports.add('io');
imports.add('net/http');
imports.add('reflect');
let body = content;
// only generate the tracker content if required
let needsTracker = false;
for (const group of values(session.model.operationGroups)) {
for (const op of values(group.operations)) {
if (isLROOperation(op) || isPageableOperation(op)) {
needsTracker = true;
break;
}
}
if (needsTracker) {
break;
}
}
if (needsTracker) {
imports.add('regexp');
imports.add('strings');
imports.add('sync');
body += tracker;
}
return text + imports.text() + body;
}

const content = `
import (
"io"
"net/http"
"reflect"
)

type nonRetriableError struct {
error
}
Expand Down Expand Up @@ -85,3 +106,47 @@ func contains[T comparable](s []T, v T) bool {
return false
}
`;

const tracker = `
func newTracker[T any]() *tracker[T] {
return &tracker[T]{
items: map[string]*T{},
}
}

type tracker[T any] struct {
items map[string]*T
mu sync.Mutex
}

func (p *tracker[T]) key(req *http.Request) string {
path := req.URL.Path
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes that path will be unique per invocation. Probably always true for ARM, but is it universally true? I'm thinking perhaps a POST for the same path but with differing query parameters.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately we don't propagate query parameters when creating the fake poller or injecting next links for paged operations. Will fix this in the next azcore beta.

if match, _ := regexp.Match(\`/page_\\d+$\`, []byte(path)); match {
path = path[:strings.LastIndex(path, "/")]
} else if strings.HasSuffix(path, "/get/fake/status") {
path = path[:len(path)-16]
}
return path
}

func (p *tracker[T]) get(req *http.Request) *T {
p.mu.Lock()
defer p.mu.Unlock()
if item, ok := p.items[p.key(req)]; ok {
return item
}
return nil
}

func (p *tracker[T]) add(req *http.Request, item *T) {
p.mu.Lock()
defer p.mu.Unlock()
p.items[p.key(req)] = item
}

func (p *tracker[T]) remove(req *http.Request) {
p.mu.Lock()
defer p.mu.Unlock()
delete(p.items, p.key(req))
}
`;
61 changes: 46 additions & 15 deletions packages/autorest.go/src/generator/fake/servers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ export async function generateServers(session: Session<CodeModel>): Promise<Arra

// we might remove some operations from the list
const finalOperations = new Array<Operation>();
let countLROs = 0;
let countPagers = 0;

for (const op of values(group.operations)) {
if (isLROOperation(op)) {
Expand Down Expand Up @@ -90,6 +92,11 @@ export async function generateServers(session: Session<CodeModel>): Promise<Arra
}
content += `\t${operationName} func(${getAPIParametersSig(op, imports, clientPkg)}) (${op.language.go!.serverResponse})\n\n`;
finalOperations.push(op);
if (isLROOperation(op)) {
++countLROs;
} else if (isPageableOperation(op)) {
++countPagers;
}
}

content += '}\n\n';
Expand All @@ -102,22 +109,38 @@ export async function generateServers(session: Session<CodeModel>): Promise<Arra
content += `// The returned ${serverTransport} instance is connected to an instance of ${clientPkg}.${group.language.go!.clientName} via the\n`;
content += '// azcore.ClientOptions.Transporter field in the client\'s constructor parameters.\n';
content += `func New${serverTransport}(srv *${serverName}) *${serverTransport} {\n`;
content += `\treturn &${serverTransport}{srv: srv}\n}\n\n`;
if (countLROs === 0 && countPagers === 0) {
content += `\treturn &${serverTransport}{srv: srv}\n}\n\n`;
} else {
content += `\treturn &${serverTransport}{\n\t\tsrv: srv,\n`;
for (const op of values(finalOperations)) {
let respType = `${clientPkg}.${getResponseEnvelopeName(op)}`;
if (isLROOperation(op)) {
if (isPageableOperation(op)) {
respType = `azfake.PagerResponder[${clientPkg}.${getResponseEnvelopeName(op)}]`;
}
content += `\t\t${uncapitalize(fixUpOperationName(op))}: newTracker[azfake.PollerResponder[${respType}]](),\n`;
} else if (isPageableOperation(op)) {
content += `\t\t${uncapitalize(fixUpOperationName(op))}: newTracker[azfake.PagerResponder[${respType}]](),\n`;
}
}
content += '\t}\n}\n\n';
}

content += `// ${serverTransport} connects instances of ${clientPkg}.${group.language.go!.clientName} to instances of ${serverName}.\n`;
content += `// Don't use this type directly, use New${serverTransport} instead.\n`;
content += `type ${serverTransport} struct {\n`;
content += `\tsrv *${serverName}\n`;
for (const op of values(finalOperations)) {
// create state machines for any pager/poller operations
let respType = `${clientPkg}.${getResponseEnvelopeName(op)}`;
if (isLROOperation(op)) {
let respType = `${clientPkg}.${getResponseEnvelopeName(op)}`;
if (isPageableOperation(op)) {
respType = `azfake.PagerResponder[${clientPkg}.${getResponseEnvelopeName(op)}]`;
}
content +=`\t${uncapitalize(fixUpOperationName(op))} *azfake.PollerResponder[${respType}]\n`;
content +=`\t${uncapitalize(fixUpOperationName(op))} *tracker[azfake.PollerResponder[${respType}]]\n`;
} else if (isPageableOperation(op)) {
content += `\t${uncapitalize(fixUpOperationName(op))} *azfake.PagerResponder[${clientPkg}.${getResponseEnvelopeName(op)}]\n`;
content += `\t${uncapitalize(fixUpOperationName(op))} *tracker[azfake.PagerResponder[${respType}]]\n`;
}
}
content += '}\n\n';
Expand Down Expand Up @@ -412,47 +435,55 @@ function dispatchForOperationBody(clientPkg: string, receiverName: string, op: O

function dispatchForLROBody(clientPkg: string, receiverName: string, op: Operation, imports: ImportManager): string {
const operationName = fixUpOperationName(op);
const localVarName = uncapitalize(operationName);
const operationStateMachine = `${receiverName}.${uncapitalize(operationName)}`;
let content = `\tif ${operationStateMachine} == nil {\n`;
let content = `\t${localVarName} := ${operationStateMachine}.get(req)\n`;
content += `\tif ${localVarName} == nil {\n`;
content += dispatchForOperationBody(clientPkg, receiverName, op, imports);
content += `\t\t${operationStateMachine} = &respr\n`;
content += `\t\t${localVarName} = &respr\n`;
content += `\t\t${operationStateMachine}.add(req, ${localVarName})\n`;
content += '\t}\n\n';

content += `\tresp, err := server.PollerResponderNext(${operationStateMachine}, req)\n`;
content += `\tresp, err := server.PollerResponderNext(${localVarName}, req)\n`;
content += '\tif err != nil {\n\t\treturn nil, err\n\t}\n\n';

const formattedStatusCodes = formatStatusCodes(getStatusCodes(op));
content += `\tif !contains([]int{${formattedStatusCodes}}, resp.StatusCode) {\n`;
content += `\t\t${operationStateMachine}.remove(req)\n`;
content += `\t\treturn nil, &nonRetriableError{fmt.Errorf("unexpected status code %d. acceptable values are ${formattedStatusCodes}", resp.StatusCode)}\n\t}\n`;

content += `\tif !server.PollerResponderMore(${operationStateMachine}) {\n`;
content += `\t\t${operationStateMachine} = nil\n\t}\n\n`;
content += `\tif !server.PollerResponderMore(${localVarName}) {\n`;
content += `\t\t${operationStateMachine}.remove(req)\n\t}\n\n`;
content += '\treturn resp, nil\n';
return content;
}

function dispatchForPagerBody(clientPkg: string, receiverName: string, op: Operation, imports: ImportManager): string {
const operationName = fixUpOperationName(op);
const localVarName = uncapitalize(operationName);
const operationStateMachine = `${receiverName}.${uncapitalize(operationName)}`;
let content = `\tif ${operationStateMachine} == nil {\n`;
let content = `\t${localVarName} := ${operationStateMachine}.get(req)\n`;
content += `\tif ${localVarName} == nil {\n`;
content += dispatchForOperationBody(clientPkg, receiverName, op, imports);
content += `\t\t${operationStateMachine} = &resp\n`;
content += `\t\t${localVarName} = &resp\n`;
content += `\t\t${operationStateMachine}.add(req, ${localVarName})\n`;
if (op.language.go!.paging.nextLinkName) {
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/to');
content += `\t\tserver.PagerResponderInjectNextLinks(${operationStateMachine}, req, func(page *${clientPkg}.${getResponseEnvelopeName(op)}, createLink func() string) {\n`;
content += `\t\tserver.PagerResponderInjectNextLinks(${localVarName}, req, func(page *${clientPkg}.${getResponseEnvelopeName(op)}, createLink func() string) {\n`;
content += `\t\t\tpage.${op.language.go!.paging.nextLinkName} = to.Ptr(createLink())\n`;
content += '\t\t})\n';
}
content += '\t}\n'; // end if
content += `\tresp, err := server.PagerResponderNext(${operationStateMachine}, req)\n`;
content += `\tresp, err := server.PagerResponderNext(${localVarName}, req)\n`;
content += '\tif err != nil {\n\t\treturn nil, err\n\t}\n';

const formattedStatusCodes = formatStatusCodes(getStatusCodes(op));
content += `\tif !contains([]int{${formattedStatusCodes}}, resp.StatusCode) {\n`;
content += `\t\t${operationStateMachine}.remove(req)\n`;
content += `\t\treturn nil, &nonRetriableError{fmt.Errorf("unexpected status code %d. acceptable values are ${formattedStatusCodes}", resp.StatusCode)}\n\t}\n`;

content += `\tif !server.PagerResponderMore(${operationStateMachine}) {\n`;
content += `\t\t${operationStateMachine} = nil\n\t}\n`;
content += `\tif !server.PagerResponderMore(${localVarName}) {\n`;
content += `\t\t${operationStateMachine}.remove(req)\n\t}\n`;
content += '\treturn resp, nil\n';
return content;
}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading