🪶 feat: Add Support for Azure OpenAI Base URL (danny-avila#1596)
* refactor(extractBaseURL): add handling for all possible Cloudflare AI Gateway endpoints

* chore: added `endpointOption` TODO for updating its type and optimizing handling app-wide

* feat(azureUtils):
- `genAzureChatCompletion`: allow optionally passing a client to update its azure property
- `constructAzureURL`: optionally replace placeholders for instance and deployment names of an azure baseURL
- add tests for module

* refactor(extractBaseURL): return entire input when cloudflare `azure-openai` suffix detected
- also add more tests for both construct and extract URL

* refactor(genAzureChatCompletion): only allow omitting instance name if baseURL is not set

* refactor(initializeClient): determine `reverseProxyUrl` based on endpoint (azure or openai)

* refactor: utilize `constructAzureURL` when `AZURE_OPENAI_BASEURL` is set

* docs: update docs on `AZURE_OPENAI_BASEURL`

* fix(ci): update expected error message for `azureUtils` tests
danny-avila authored Jan 19, 2024
1 parent e6746c4 commit db20283
Showing 8 changed files with 453 additions and 45 deletions.
21 changes: 16 additions & 5 deletions app/clients/OpenAIClient.js
@@ -2,8 +2,13 @@ const OpenAI = require('openai');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { getResponseSender, ImageDetailCost, ImageDetail } = require('librechat-data-provider');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const {
getModelMaxTokens,
genAzureChatCompletion,
extractBaseURL,
constructAzureURL,
} = require('~/utils');
const { encodeAndFormat, validateVisionModel } = require('~/server/services/Files/images');
const { getModelMaxTokens, genAzureChatCompletion, extractBaseURL } = require('~/utils');
const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts');
const { handleOpenAIErrors } = require('./tools/util');
const spendTokens = require('~/models/spendTokens');
@@ -32,6 +37,7 @@ class OpenAIClient extends BaseClient {
? options.contextStrategy.toLowerCase()
: 'discard';
this.shouldSummarize = this.contextStrategy === 'summarize';
/** @type {AzureOptions} */
this.azure = options.azure || false;
this.setOptions(options);
}
@@ -104,10 +110,10 @@
}

if (this.azure && process.env.AZURE_OPENAI_DEFAULT_MODEL) {
this.azureEndpoint = genAzureChatCompletion(this.azure, this.modelOptions.model);
this.azureEndpoint = genAzureChatCompletion(this.azure, this.modelOptions.model, this);
this.modelOptions.model = process.env.AZURE_OPENAI_DEFAULT_MODEL;
} else if (this.azure) {
this.azureEndpoint = genAzureChatCompletion(this.azure, this.modelOptions.model);
this.azureEndpoint = genAzureChatCompletion(this.azure, this.modelOptions.model, this);
}

const { model } = this.modelOptions;
@@ -711,7 +717,7 @@

if (this.azure) {
modelOptions.model = process.env.AZURE_OPENAI_DEFAULT_MODEL ?? modelOptions.model;
this.azureEndpoint = genAzureChatCompletion(this.azure, modelOptions.model);
this.azureEndpoint = genAzureChatCompletion(this.azure, modelOptions.model, this);
}

const instructionsPayload = [
@@ -949,7 +955,12 @@ ${convo}
// Azure does not accept `model` in the body, so we need to remove it.
delete modelOptions.model;

opts.baseURL = this.azureEndpoint.split('/chat')[0];
opts.baseURL = this.langchainProxy
? constructAzureURL({
baseURL: this.langchainProxy,
azure: this.azure,
})
: this.azureEndpoint.split(/\/(chat|completion)/)[0];
opts.defaultQuery = { 'api-version': this.azure.azureOpenAIApiVersion };
opts.defaultHeaders = { ...opts.defaultHeaders, 'api-key': this.apiKey };
}
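For reference, a minimal sketch of how the client derives `baseURL` when no `langchainProxy` is set, assuming the default endpoint format produced by `genAzureChatCompletion`; the instance name, deployment name, and API version below are hypothetical:

```js
// Hypothetical endpoint in the default format produced by genAzureChatCompletion.
const azureEndpoint =
  'https://my-instance.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2023-07-01-preview';

// Without a langchainProxy, everything before "/chat" (or "/completion") becomes the baseURL.
const baseURL = azureEndpoint.split(/\/(chat|completion)/)[0];
console.log(baseURL);
// => https://my-instance.openai.azure.com/openai/deployments/my-deployment
```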
13 changes: 9 additions & 4 deletions app/clients/llm/createLLM.js
@@ -1,6 +1,6 @@
const { ChatOpenAI } = require('langchain/chat_models/openai');
const { sanitizeModelName } = require('../../../utils');
const { isEnabled } = require('../../../server/utils');
const { sanitizeModelName, constructAzureURL } = require('~/utils');
const { isEnabled } = require('~/server/utils');

/**
* Creates a new instance of a language model (LLM) for chat interactions.
@@ -36,6 +36,7 @@ function createLLM({
apiKey: openAIApiKey,
};

/** @type {AzureOptions} */
let azureOptions = {};
if (azure) {
const useModelName = isEnabled(process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME);
@@ -53,8 +54,12 @@
modelOptions.modelName = process.env.AZURE_OPENAI_DEFAULT_MODEL;
}

// console.debug('createLLM: configOptions');
// console.debug(configOptions);
if (azure && configOptions.basePath) {
configOptions.basePath = constructAzureURL({
baseURL: configOptions.basePath,
azure: azureOptions,
});
}

return new ChatOpenAI(
{
35 changes: 25 additions & 10 deletions server/services/Endpoints/gptPlugins/initializeClient.js
@@ -1,7 +1,8 @@
const { PluginsClient } = require('~/app');
const { isEnabled } = require('~/server/utils');
const { getAzureCredentials } = require('~/utils');
const { EModelEndpoint } = require('librechat-data-provider');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const { getAzureCredentials } = require('~/utils');
const { isEnabled } = require('~/server/utils');
const { PluginsClient } = require('~/app');

const initializeClient = async ({ req, res, endpointOption }) => {
const {
@@ -10,26 +11,40 @@ const initializeClient = async ({ req, res, endpointOption }) => {
AZURE_API_KEY,
PLUGINS_USE_AZURE,
OPENAI_REVERSE_PROXY,
AZURE_OPENAI_BASEURL,
OPENAI_SUMMARIZE,
DEBUG_PLUGINS,
} = process.env;

const { key: expiresAt } = req.body;
const contextStrategy = isEnabled(OPENAI_SUMMARIZE) ? 'summarize' : null;

const useAzure = isEnabled(PLUGINS_USE_AZURE);
const endpoint = useAzure ? EModelEndpoint.azureOpenAI : EModelEndpoint.openAI;

const baseURLOptions = {
[EModelEndpoint.openAI]: OPENAI_REVERSE_PROXY,
[EModelEndpoint.azureOpenAI]: AZURE_OPENAI_BASEURL,
};

const reverseProxyUrl = baseURLOptions[endpoint] ?? null;

const clientOptions = {
contextStrategy,
debug: isEnabled(DEBUG_PLUGINS),
reverseProxyUrl: OPENAI_REVERSE_PROXY ?? null,
reverseProxyUrl,
proxy: PROXY ?? null,
req,
res,
...endpointOption,
};

const useAzure = isEnabled(PLUGINS_USE_AZURE);
const credentials = {
[EModelEndpoint.openAI]: OPENAI_API_KEY,
[EModelEndpoint.azureOpenAI]: AZURE_API_KEY,
};

const isUserProvided = useAzure
? AZURE_API_KEY === 'user_provided'
: OPENAI_API_KEY === 'user_provided';
const isUserProvided = credentials[endpoint] === 'user_provided';

let userKey = null;
if (expiresAt && isUserProvided) {
@@ -39,11 +54,11 @@ const initializeClient = async ({ req, res, endpointOption }) => {
);
userKey = await getUserKey({
userId: req.user.id,
name: useAzure ? 'azureOpenAI' : 'openAI',
name: endpoint,
});
}

let apiKey = isUserProvided ? userKey : OPENAI_API_KEY;
let apiKey = isUserProvided ? userKey : credentials[endpoint];

if (useAzure || (apiKey && apiKey.includes('azure') && !clientOptions.azure)) {
clientOptions.azure = isUserProvided ? JSON.parse(userKey) : getAzureCredentials();
24 changes: 17 additions & 7 deletions server/services/Endpoints/openAI/initializeClient.js
@@ -1,32 +1,42 @@
const { OpenAIClient } = require('~/app');
const { isEnabled } = require('~/server/utils');
const { getAzureCredentials } = require('~/utils');
const { EModelEndpoint } = require('librechat-data-provider');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const { getAzureCredentials } = require('~/utils');
const { isEnabled } = require('~/server/utils');
const { OpenAIClient } = require('~/app');

const initializeClient = async ({ req, res, endpointOption }) => {
const {
PROXY,
OPENAI_API_KEY,
AZURE_API_KEY,
OPENAI_REVERSE_PROXY,
AZURE_OPENAI_BASEURL,
OPENAI_SUMMARIZE,
DEBUG_OPENAI,
} = process.env;
const { key: expiresAt, endpoint } = req.body;
const contextStrategy = isEnabled(OPENAI_SUMMARIZE) ? 'summarize' : null;

const baseURLOptions = {
[EModelEndpoint.openAI]: OPENAI_REVERSE_PROXY,
[EModelEndpoint.azureOpenAI]: AZURE_OPENAI_BASEURL,
};

const reverseProxyUrl = baseURLOptions[endpoint] ?? null;

const clientOptions = {
debug: isEnabled(DEBUG_OPENAI),
contextStrategy,
reverseProxyUrl: OPENAI_REVERSE_PROXY ?? null,
reverseProxyUrl,
proxy: PROXY ?? null,
req,
res,
...endpointOption,
};

const credentials = {
openAI: OPENAI_API_KEY,
azureOpenAI: AZURE_API_KEY,
[EModelEndpoint.openAI]: OPENAI_API_KEY,
[EModelEndpoint.azureOpenAI]: AZURE_API_KEY,
};

const isUserProvided = credentials[endpoint] === 'user_provided';
@@ -42,7 +52,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {

let apiKey = isUserProvided ? userKey : credentials[endpoint];

if (endpoint === 'azureOpenAI') {
if (endpoint === EModelEndpoint.azureOpenAI) {
clientOptions.azure = isUserProvided ? JSON.parse(userKey) : getAzureCredentials();
apiKey = clientOptions.azure.azureOpenAIApiKey;
}
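A minimal sketch of the endpoint-based `reverseProxyUrl` selection above, using hypothetical environment values; the object keys mirror `EModelEndpoint.openAI` ('openAI') and `EModelEndpoint.azureOpenAI` ('azureOpenAI') as used in the `credentials` map:

```js
// Hypothetical environment values, for illustration only.
process.env.OPENAI_REVERSE_PROXY = 'https://my-proxy.example.com/v1';
process.env.AZURE_OPENAI_BASEURL =
  'https://gateway.example.com/azure/${INSTANCE_NAME}/${DEPLOYMENT_NAME}';

const baseURLOptions = {
  openAI: process.env.OPENAI_REVERSE_PROXY,
  azureOpenAI: process.env.AZURE_OPENAI_BASEURL,
};

// `endpoint` comes from req.body; endpoints without an entry fall back to null.
console.log(baseURLOptions['azureOpenAI'] ?? null); // the AZURE_OPENAI_BASEURL template
console.log(baseURLOptions['google'] ?? null); // null
```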
45 changes: 34 additions & 11 deletions utils/azureUtils.js
@@ -1,11 +1,3 @@
/**
* @typedef {Object} AzureCredentials
* @property {string} azureOpenAIApiKey - The Azure OpenAI API key.
* @property {string} azureOpenAIApiInstanceName - The Azure OpenAI API instance name.
* @property {string} azureOpenAIApiDeploymentName - The Azure OpenAI API deployment name.
* @property {string} azureOpenAIApiVersion - The Azure OpenAI API version.
*/

const { isEnabled } = require('~/server/utils');

/**
@@ -37,30 +29,37 @@ const genAzureEndpoint = ({ azureOpenAIApiInstanceName, azureOpenAIApiDeployment
* @param {string} [AzureConfig.azureOpenAIApiDeploymentName] - The Azure OpenAI API deployment name (optional).
* @param {string} AzureConfig.azureOpenAIApiVersion - The Azure OpenAI API version.
* @param {string} [modelName] - The model name to be included in the deployment name (optional).
* @param {Object} [client] - The API Client class for optionally setting properties (optional).
* @returns {string} The complete chat completion endpoint URL for the Azure OpenAI API.
* @throws {Error} If neither azureOpenAIApiDeploymentName nor modelName is provided.
*/
const genAzureChatCompletion = (
{ azureOpenAIApiInstanceName, azureOpenAIApiDeploymentName, azureOpenAIApiVersion },
modelName,
client,
) => {
// Determine the deployment segment of the URL based on provided modelName or azureOpenAIApiDeploymentName
let deploymentSegment;
if (isEnabled(process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME) && modelName) {
const sanitizedModelName = sanitizeModelName(modelName);
deploymentSegment = `${sanitizedModelName}`;
client &&
typeof client === 'object' &&
(client.azure.azureOpenAIApiDeploymentName = sanitizedModelName);
} else if (azureOpenAIApiDeploymentName) {
deploymentSegment = azureOpenAIApiDeploymentName;
} else {
throw new Error('Either a model name or a deployment name must be provided.');
} else if (!process.env.AZURE_OPENAI_BASEURL) {
throw new Error(
'Either a model name with the `AZURE_USE_MODEL_AS_DEPLOYMENT_NAME` setting or a deployment name must be provided if `AZURE_OPENAI_BASEURL` is omitted.',
);
}

return `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${deploymentSegment}/chat/completions?api-version=${azureOpenAIApiVersion}`;
};

/**
* Retrieves the Azure OpenAI API credentials from environment variables.
* @returns {AzureCredentials} An object containing the Azure OpenAI API credentials.
* @returns {AzureOptions} An object containing the Azure OpenAI API credentials.
*/
const getAzureCredentials = () => {
return {
@@ -71,9 +70,33 @@
};
};

/**
* Constructs a URL by replacing placeholders in the baseURL with values from the azure object.
* It specifically looks for '${INSTANCE_NAME}' and '${DEPLOYMENT_NAME}' within the baseURL and replaces
* them with 'azureOpenAIApiInstanceName' and 'azureOpenAIApiDeploymentName' from the azure object.
* If the respective azure property is not provided, the placeholder is replaced with an empty string.
*
* @param {Object} params - The parameters object.
* @param {string} params.baseURL - The baseURL to inspect for replacement placeholders.
* @param {AzureOptions} params.azure - The Azure options supplying the instance and deployment names.
* @returns {string} The complete baseURL with credentials injected for the Azure OpenAI API.
*/
function constructAzureURL({ baseURL, azure }) {
let finalURL = baseURL;

// Replace INSTANCE_NAME and DEPLOYMENT_NAME placeholders with actual values if available
if (azure) {
finalURL = finalURL.replace('${INSTANCE_NAME}', azure.azureOpenAIApiInstanceName ?? '');
finalURL = finalURL.replace('${DEPLOYMENT_NAME}', azure.azureOpenAIApiDeploymentName ?? '');
}

return finalURL;
}

module.exports = {
sanitizeModelName,
genAzureEndpoint,
genAzureChatCompletion,
getAzureCredentials,
constructAzureURL,
};
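A usage sketch of the new `constructAzureURL` helper as exported above; the gateway URL, instance name, and deployment name are hypothetical, and the `~/utils` import assumes the repository's module alias:

```js
const { constructAzureURL } = require('~/utils');

// The literal ${INSTANCE_NAME} and ${DEPLOYMENT_NAME} placeholders are swapped for the
// corresponding azure options; a missing value is replaced with an empty string.
const url = constructAzureURL({
  baseURL: 'https://gateway.example.com/azure/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
  azure: {
    azureOpenAIApiInstanceName: 'my-instance',
    azureOpenAIApiDeploymentName: 'my-deployment',
  },
});
console.log(url);
// => https://gateway.example.com/azure/my-instance/my-deployment
```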