Skip to content

Commit

Permalink
Wanl/fix fist preview version bump (#8671)
Browse files Browse the repository at this point in the history
# Issue
#8522 

# Problem
The bug is introduced by new codegen [pr](Azure/autorest.typescript@b0853b2). The way to define default api version is changed.
extracting the api version should also consider version of codegen.

e.g. new api-version at client level is defined [here](https://github.com/Azure/azure-sdk-for-js/blob/main/sdk/mongocluster/arm-mongocluster/src/rest/documentDBClient.ts#L23)
while old one is:
- [here](https://github.com/Azure/azure-sdk-for-js/blob/06716722818f5838cc10a9e5644b7ba9f32089d5/sdk/face/ai-vision-face-rest/src/faceClient.ts#L26)
- [another](https://github.com/Azure/azure-sdk-for-js/blob/06716722818f5838cc10a9e5644b7ba9f32089d5/sdk/openai/openai-rest/src/openAIClient.ts#L22)

# Solution
1. try to detect in new client, return api-version if found
2. fallback to detect old client, return whatever found
  • Loading branch information
wanlwanl authored Jul 22, 2024
1 parent b32d70e commit 8db91de
Show file tree
Hide file tree
Showing 15 changed files with 323 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -1,81 +1,135 @@
import { SourceFile, SyntaxKind } from "ts-morph";
import shell from 'shelljs';
import path from 'path';
import { SourceFile, SyntaxKind } from "ts-morph";
import shell from "shelljs";
import path from "path";
import * as ts from "typescript";

import { ApiVersionType } from "../../common/types"
import { ApiVersionType } from "../../common/types";
import { IApiVersionTypeExtractor } from "../../common/interfaces";
import { getTsSourceFile } from "../../common/utils";
import { readFileSync } from "fs";

const findRestClientPath = (packageRoot: string): string => {
const restPath = path.join(packageRoot, 'src/rest/');
const restPath = path.join(packageRoot, "src/rest/");
const fileNames = shell.ls(restPath);
const clientFiles = fileNames.filter(f => f.endsWith("Client.ts"));
if (clientFiles.length !== 1) throw new Error(`Single client is supported, but found ${clientFiles}`);
const clientFiles = fileNames.filter((f) => f.endsWith("Client.ts"));
if (clientFiles.length !== 1)
throw new Error(`Single client is supported, but found "${clientFiles}" in ${restPath}`);

const clientPath = path.join(restPath, clientFiles[0]);
return clientPath;
};

const matchPattern = (text: string, pattern: RegExp): string | undefined => {
const match = text.match(pattern);
const found = match != null && match.length === 2;
return found ? match?.at(1) : undefined;
}

const findApiVersionInRestClient = (clientPath: string): string | undefined => {
const findApiVersionInRestClientV1 = (
clientPath: string
): string | undefined => {
const sourceFile = getTsSourceFile(clientPath);
const createClientFunction = sourceFile?.getFunction("createClient");
if (!createClientFunction) throw new Error("Function 'createClient' not found.");
if (!createClientFunction)
throw new Error("Function 'createClient' not found.");

const apiVersionStatements = createClientFunction.getStatements()
.filter(s =>
s.getKind() === SyntaxKind.ExpressionStatement &&
s.getText().indexOf("options.apiVersion") > -1);
if (apiVersionStatements.length === 0) return undefined;
const apiVersionStatements = createClientFunction
.getStatements()
.filter((s) => s.getText().includes("options.apiVersion"));
if (apiVersionStatements.length === 0) {
return undefined;
}
const text =
apiVersionStatements[apiVersionStatements.length - 1].getText();
return extractApiVersionFromText(text);
};

const text = apiVersionStatements[apiVersionStatements.length - 1].getText();
const pattern = /(\d{4}-\d{2}-\d{2}(?:-preview)?)/;
const apiVersion = matchPattern(text, pattern);
const extractApiVersionFromText = (text: string): string | undefined => {
const begin = text.indexOf('"');
const end = text.lastIndexOf('"');
return text.substring(begin + 1, end);
};

// new ways in @autorest/typespec-ts emitter to set up api-version
const findApiVersionInRestClientV2 = (clientPath: string): string | undefined => {
const sourceCode= readFileSync(clientPath, {encoding: 'utf-8'})
const sourceFile = ts.createSourceFile("example.ts", sourceCode, ts.ScriptTarget.Latest, true);
const createClientFunction = sourceFile.statements.filter(s => (s as ts.FunctionDeclaration)?.name?.escapedText === 'createClient').map(s => (s as ts.FunctionDeclaration))[0];
let apiVersion: string | undefined = undefined;
createClientFunction.parameters.forEach(p => {
const isBindingPattern = node => node && typeof node === "object" && "elements" in node && "parent" in node && "kind" in node;
if (!isBindingPattern(p.name)) {
return;
}
const binding = p.name as ts.ObjectBindingPattern;
const apiVersionTexts = binding.elements?.filter(e => (e.name as ts.Identifier)?.escapedText === "apiVersion").map(e => e.initializer?.getText());
// apiVersionTexts.length must be 0 or 1, otherwise the binding pattern contains the same keys, which causes a ts error
if (apiVersionTexts.length === 1 && apiVersionTexts[0]) {
apiVersion = extractApiVersionFromText(apiVersionTexts[0]);
}
});
return apiVersion;
};

const getApiVersionTypeFromRestClient: IApiVersionTypeExtractor = (packageRoot: string): ApiVersionType => {
// workaround for createClient function changes it's way to setup api-version
export const findApiVersionInRestClient = (clientPath: string): string | undefined => {
const version2 = findApiVersionInRestClientV2(clientPath);
if (version2) {
return version2;
}
const version1 = findApiVersionInRestClientV1(clientPath);
return version1;
};

const getApiVersionTypeFromRestClient: IApiVersionTypeExtractor = (
packageRoot: string
): ApiVersionType => {
const clientPath = findRestClientPath(packageRoot);
const apiVersion = findApiVersionInRestClient(clientPath);
if (apiVersion && apiVersion.indexOf("-preview") >= 0) return ApiVersionType.Preview;
if (apiVersion && apiVersion.indexOf("-preview") < 0) return ApiVersionType.Stable;
if (apiVersion && apiVersion.indexOf("-preview") >= 0)
return ApiVersionType.Preview;
if (apiVersion && apiVersion.indexOf("-preview") < 0)
return ApiVersionType.Stable;
return ApiVersionType.None;
};

const findApiVersionsInOperations = (sourceFile: SourceFile | undefined): Array<string> | undefined => {
const findApiVersionsInOperations = (
sourceFile: SourceFile | undefined
): Array<string> | undefined => {
const interfaces = sourceFile?.getInterfaces();
const interfacesWithApiVersion = interfaces?.filter(itf => itf.getProperty('"api-version"'));
const apiVersions = interfacesWithApiVersion?.map(itf => {
const property = itf.getMembers()
.filter(m => {
const defaultValue = m.getChildrenOfKind(SyntaxKind.StringLiteral)[0];
return defaultValue && defaultValue.getText() === '"api-version"';
})[0];
const apiVersion = property.getChildrenOfKind(SyntaxKind.LiteralType)[0].getText();
const interfacesWithApiVersion = interfaces?.filter((itf) =>
itf.getProperty('"api-version"')
);
const apiVersions = interfacesWithApiVersion?.map((itf) => {
const property = itf.getMembers().filter((m) => {
const defaultValue = m.getChildrenOfKind(
SyntaxKind.StringLiteral
)[0];
return defaultValue && defaultValue.getText() === '"api-version"';
})[0];
const apiVersion = property
.getChildrenOfKind(SyntaxKind.LiteralType)[0]
.getText();
return apiVersion;
});
return apiVersions;
}
};

const getApiVersionTypeFromOperations: IApiVersionTypeExtractor = (packageRoot: string): ApiVersionType => {
const paraPath = path.join(packageRoot, 'src/rest/parameters.ts');
const getApiVersionTypeFromOperations: IApiVersionTypeExtractor = (
packageRoot: string
): ApiVersionType => {
const paraPath = path.join(packageRoot, "src/rest/parameters.ts");
const sourceFile = getTsSourceFile(paraPath);
const apiVersions = findApiVersionsInOperations(sourceFile);
if (!apiVersions) return ApiVersionType.None;
const previewVersions = apiVersions.filter(v => v.indexOf("-preview") >= 0);
return previewVersions.length > 0 ? ApiVersionType.Preview : ApiVersionType.Stable;
const previewVersions = apiVersions.filter(
(v) => v.indexOf("-preview") >= 0
);
return previewVersions.length > 0
? ApiVersionType.Preview
: ApiVersionType.Stable;
};

// TODO: add unit test
export const getApiVersionType: IApiVersionTypeExtractor = (packageRoot: string): ApiVersionType => {
export const getApiVersionType: IApiVersionTypeExtractor = (
packageRoot: string
): ApiVersionType => {
const typeFromClient = getApiVersionTypeFromRestClient(packageRoot);
if (typeFromClient !== ApiVersionType.None) return typeFromClient;
const typeFromOperations = getApiVersionTypeFromOperations(packageRoot);
if (typeFromOperations !== ApiVersionType.None) return typeFromOperations;
return ApiVersionType.Stable;
}
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import { expect, test } from "vitest";
import { findApiVersionInRestClient, getApiVersionType } from "../../mlc/apiVersion/apiVersionTypeExtractor";
import { join } from "path";
import { ApiVersionType } from "../../common/types";

test("MLC api-version Extractor: new createClient function", async () => {
const clientPath = join(__dirname, 'testCases/new/src/rest/newClient.ts');
const version = findApiVersionInRestClient(clientPath);
expect(version).toBe('2024-03-01-preview');
});

test("MLC api-version Extractor: get api version type from new createClient function", async () => {
const root = join(__dirname, 'testCases/new/');
const version = getApiVersionType(root);
expect(version).toBe(ApiVersionType.Preview);
});

test("MLC api-version Extractor: old createClient function 1", async () => {
const clientPath = join(__dirname, 'testCases/old1/src/rest/oldClient.ts');
const version = findApiVersionInRestClient(clientPath);
expect(version).toBe('v1.1-preview.1');
});

test("MLC api-version Extractor: get api version type from old createClient function 1", async () => {
const root = join(__dirname, 'testCases/old1/');
const version = getApiVersionType(root);
expect(version).toBe(ApiVersionType.Preview);
});

test("MLC api-version Extractor: old createClient function 2", async () => {
const clientPath = join(__dirname, 'testCases/old2/src/rest/oldClient.ts');
const version = findApiVersionInRestClient(clientPath);
expect(version).toBe('2024-03-01');
});

test("MLC api-version Extractor: get api version type from old createClient function 2", async () => {
const root = join(__dirname, 'testCases/old2/');
const version = getApiVersionType(root);
expect(version).toBe(ApiVersionType.Stable);
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

import { getClient, ClientOptions } from "@azure-rest/core-client";
import { logger } from "../logger.js";
import { TokenCredential } from "@azure/core-auth";
import { DocumentDBContext } from "./clientDefinitions.js";

/** The optional parameters for the client */
export interface DocumentDBContextOptions extends ClientOptions {
/** The api version option of the client */
apiVersion?: string;
}

/**
* Initialize a new instance of `DocumentDBContext`
* @param credentials - uniquely identify client credential
* @param options - the parameter for all optional parameters
*/
export default function createClient(
credentials: TokenCredential,
{
apiVersion = "2024-03-01-preview",
...options
}: DocumentDBContextOptions = {},
): DocumentDBContext {
const endpointUrl =
options.endpoint ?? options.baseUrl ?? `https://management.azure.com`;
const userAgentInfo = `azsdk-js-arm-mongocluster/1.0.0-beta.1`;
const userAgentPrefix =
options.userAgentOptions && options.userAgentOptions.userAgentPrefix
? `${options.userAgentOptions.userAgentPrefix} ${userAgentInfo}`
: `${userAgentInfo}`;
options = {
...options,
userAgentOptions: {
userAgentPrefix,
},
loggingOptions: {
logger: options.loggingOptions?.logger ?? logger.info,
},
credentials: {
scopes: options.credentials?.scopes ?? [`${endpointUrl}/.default`],
},
};
const client = getClient(
endpointUrl,
credentials,
options,
) as DocumentDBContext;

client.pipeline.removePolicy({ name: "ApiVersionPolicy" });
client.pipeline.addPolicy({
name: "ClientApiVersionPolicy",
sendRequest: (req, next) => {
// Use the apiVersion defined in request url directly
// Append one if there is no apiVersion and we have one at client options
const url = new URL(req.url);
if (!url.searchParams.get("api-version") && apiVersion) {
req.url = `${req.url}${Array.from(url.searchParams.keys()).length > 0 ? "&" : "?"
}api-version=${apiVersion}`;
}

return next(req);
},
});
return client;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

import { getClient, ClientOptions } from "@azure-rest/core-client";
import { logger } from "./logger.js";
import { TokenCredential, KeyCredential } from "@azure/core-auth";
import { FaceClient } from "./clientDefinitions.js";
import { Versions } from "./models.js";

export interface FaceClientOptions extends ClientOptions {
apiVersion?: Versions;
}

/**
* Initialize a new instance of `FaceClient`
* @param endpointParam - Supported Cognitive Services endpoints (protocol and hostname, for example:
* https://{resource-name}.cognitiveservices.azure.com).
* @param credentials - uniquely identify client credential
* @param options - the parameter for all optional parameters
*/
export default function createClient(
endpointParam: string,
credentials: TokenCredential | KeyCredential,
options: FaceClientOptions = {},
): FaceClient {
const apiVersion = options.apiVersion ?? "v1.1-preview.1";
const endpointUrl = options.endpoint ?? options.baseUrl ?? `${endpointParam}/face/${apiVersion}`;

const userAgentInfo = `azsdk-js-ai-vision-face-rest/1.0.0-beta.1`;
const userAgentPrefix =
options.userAgentOptions && options.userAgentOptions.userAgentPrefix
? `${options.userAgentOptions.userAgentPrefix} ${userAgentInfo}`
: `${userAgentInfo}`;
options = {
...options,
userAgentOptions: {
userAgentPrefix,
},
loggingOptions: {
logger: options.loggingOptions?.logger ?? logger.info,
},
credentials: {
scopes: options.credentials?.scopes ?? ["https://cognitiveservices.azure.com/.default"],
apiKeyHeaderName: options.credentials?.apiKeyHeaderName ?? "Ocp-Apim-Subscription-Key",
},
};

const client = getClient(endpointUrl, credentials, options) as FaceClient;

client.pipeline.removePolicy({ name: "ApiVersionPolicy" });

client.pipeline.addPolicy({
name: "VerifyImageFilenamePolicy",
sendRequest: (request, next) => {
for (const part of request.multipartBody?.parts ?? []) {
const contentDisposition = part.headers.get("content-disposition");
if (
contentDisposition &&
contentDisposition.includes(`name="VerifyImage"`) &&
!contentDisposition.includes("filename=")
) {
part.headers.set("content-disposition", `form-data; name="VerifyImage"; filename="blob"`);
}
}
return next(request);
},
});

return client;
}
Loading

0 comments on commit 8db91de

Please sign in to comment.