Skip to content

Commit

Permalink
fix: Fix LangChain client's internal mapping logic (#488)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomfrenken authored Jan 28, 2025
1 parent 5782f0d commit ccfa2eb
Show file tree
Hide file tree
Showing 8 changed files with 259 additions and 133 deletions.
6 changes: 6 additions & 0 deletions .changeset/tender-numbers-sin.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
'@sap-ai-sdk/langchain': minor
'@sap-ai-sdk/sample-code': minor
---

[Fixed Issue] Fixed the internal mapping of LangChain to Azure OpenAI and vice versa.
1 change: 1 addition & 0 deletions packages/langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"@sap-ai-sdk/core": "workspace:^",
"@sap-ai-sdk/foundation-models": "workspace:^",
"@sap-cloud-sdk/connectivity": "^3.25.0",
"uuid": "^11.0.0",
"@langchain/core": "0.3.36",
"zod-to-json-schema": "^3.24.1"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ exports[`Mapping Functions should parse an OpenAI response to a (LangChain) chat
"finish_reason": "stop",
"function_call": undefined,
"index": 0,
"tool_call_id": "",
"tool_calls": undefined,
},
"content": "Hello! I’m here and ready to help. How can I assist you today?",
Expand Down
2 changes: 1 addition & 1 deletion packages/langchain/src/openai/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ export type AzureOpenAiChatCallOptions = BaseChatModelCallOptions &
| 'logprobs'
| 'top_logprobs'
| 'response_format'
| 'tools'
| 'tool_choice'
| 'functions'
| 'function_call'
| 'tools'
> & {
requestConfig?: CustomRequestConfig;
};
Expand Down
274 changes: 147 additions & 127 deletions packages/langchain/src/openai/util.ts
Original file line number Diff line number Diff line change
@@ -1,44 +1,32 @@
import { AIMessage, ToolMessage } from '@langchain/core/messages';
import { AIMessage } from '@langchain/core/messages';
import { zodToJsonSchema } from 'zod-to-json-schema';
import { v4 as uuidv4 } from 'uuid';
import type { ToolCall } from '@langchain/core/messages/tool';
import type {
AzureOpenAiChatCompletionRequestFunctionMessage,
AzureOpenAiChatCompletionRequestToolMessage,
AzureOpenAiChatCompletionRequestSystemMessage,
AzureOpenAiChatCompletionRequestUserMessage,
AzureOpenAiChatCompletionRequestAssistantMessage,
AzureOpenAiChatCompletionTool,
AzureOpenAiChatCompletionRequestMessage,
AzureOpenAiCreateChatCompletionResponse,
AzureOpenAiCreateChatCompletionRequest,
AzureOpenAiFunctionParameters
AzureOpenAiFunctionParameters,
AzureOpenAiChatCompletionMessageToolCalls,
AzureOpenAiChatCompletionRequestToolMessage,
AzureOpenAiChatCompletionRequestFunctionMessage,
AzureOpenAiChatCompletionRequestSystemMessage
} from '@sap-ai-sdk/foundation-models';
import type { BaseMessage } from '@langchain/core/messages';
import type {
BaseMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage
} from '@langchain/core/messages';
import type { ChatResult } from '@langchain/core/outputs';
import type { StructuredTool } from '@langchain/core/tools';
import type { AzureOpenAiChatClient } from './chat.js';
import type { AzureOpenAiChatCallOptions } from './types.js';

type ToolChoice =
| 'none'
| 'auto'
| {
/**
* The type of the tool.
*/
type: 'function';
/**
* Use to force the model to call a specific function.
*/
function: {
/**
* The name of the function to call.
*/
name: string;
};
};

type LangChainToolChoice = string | Record<string, any> | 'auto' | 'any';

/**
* Maps a LangChain {@link StructuredTool} to {@link AzureOpenAiChatCompletionFunctions}.
* @param tool - Base class for tools that accept input of any shape defined by a Zod schema.
Expand Down Expand Up @@ -71,30 +59,21 @@ function mapToolToOpenAiTool(
}

/**
* Maps a {@link BaseMessage} to{@link AzureOpenAiChatMessage} message role.
* @param message - The {@link BaseMessage} to map.
* @returns The {@link AzureOpenAiChatMessage} message Role.
* Maps {@link AzureOpenAiChatCompletionMessageToolCalls} to LangChain's {@link ToolCall}.
* @param toolCalls - The {@link AzureOpenAiChatCompletionMessageToolCalls} response.
* @returns The LangChain {@link ToolCall}.
*/
function mapBaseMessageToRole(
message: BaseMessage
): AzureOpenAiChatCompletionRequestMessage['role'] {
const messageTypeToRoleMap = new Map<
string,
AzureOpenAiChatCompletionRequestMessage['role']
>([
['human', 'user'],
['ai', 'assistant'],
['system', 'system'],
['function', 'function'],
['tool', 'tool']
]);

const messageType = message._getType();
const role = messageTypeToRoleMap.get(messageType);
if (!role) {
throw new Error(`Unsupported message type: ${messageType}`);
function mapAzureOpenAiToLangchainToolCall(
toolCalls?: AzureOpenAiChatCompletionMessageToolCalls
): ToolCall[] | undefined {
if (toolCalls) {
return toolCalls.map(toolCall => ({
id: toolCall.id,
name: toolCall.function.name,
args: JSON.parse(toolCall.function.arguments),
type: 'tool_call'
}));
}
return role;
}

/**
Expand All @@ -107,113 +86,154 @@ export function mapOutputToChatResult(
completionResponse: AzureOpenAiCreateChatCompletionResponse
): ChatResult {
return {
generations: completionResponse.choices.map(
(choice: (typeof completionResponse)['choices'][0]) => ({
text: choice.message?.content || '',
message: new AIMessage({
content: choice.message?.content || '',
additional_kwargs: {
finish_reason: choice.finish_reason,
index: choice.index,
function_call: choice.message?.function_call,
tool_calls: choice.message?.tool_calls,
tool_call_id: ''
}
}),
generationInfo: {
generations: completionResponse.choices.map(choice => ({
text: choice.message.content ?? '',
message: new AIMessage({
content: choice.message.content ?? '',
tool_calls: mapAzureOpenAiToLangchainToolCall(
choice.message.tool_calls
),
additional_kwargs: {
finish_reason: choice.finish_reason,
index: choice.index,
function_call: choice.message?.function_call,
tool_calls: choice.message?.tool_calls
function_call: choice.message.function_call,
tool_calls: choice.message.tool_calls
}
})
),
}),
generationInfo: {
finish_reason: choice.finish_reason,
index: choice.index,
function_call: choice.message.function_call,
tool_calls: choice.message.tool_calls
}
})),
llmOutput: {
created: completionResponse.created,
id: completionResponse.id,
model: completionResponse.model,
object: completionResponse.object,
tokenUsage: {
completionTokens: completionResponse.usage?.completion_tokens || 0,
promptTokens: completionResponse.usage?.prompt_tokens || 0,
totalTokens: completionResponse.usage?.total_tokens || 0
completionTokens: completionResponse.usage?.completion_tokens ?? 0,
promptTokens: completionResponse.usage?.prompt_tokens ?? 0,
totalTokens: completionResponse.usage?.total_tokens ?? 0
}
}
};
}

/**
* Maps {@link BaseMessage} to {@link AzureOpenAiChatMessage}.
* @param message - The message to map.
* @returns The {@link AzureOpenAiChatMessage}.
* Maps LangChain's {@link ToolCall} to Azure OpenAI's {@link AzureOpenAiChatCompletionMessageToolCalls}.
* @param toolCalls - The {@link ToolCall} to map.
* @returns The Azure OpenAI {@link AzureOpenAiChatCompletionMessageToolCalls}.
*/
function mapBaseMessageToAzureOpenAiChatMessage(
message: BaseMessage
): AzureOpenAiChatCompletionRequestMessage {
return removeUndefinedProperties<AzureOpenAiChatCompletionRequestMessage>({
name: message.name ?? '',
...mapRoleAndContent(message),
function_call: message.additional_kwargs.function_call,
tool_calls: message.additional_kwargs.tool_calls,
tool_call_id: mapToolCallId(message)
});
function mapLangchainToolCallToAzureOpenAiToolCall(
toolCalls?: ToolCall[]
): AzureOpenAiChatCompletionMessageToolCalls | undefined {
if (toolCalls) {
return toolCalls.map(toolCall => ({
id: toolCall.id || uuidv4(),
type: 'function',
function: {
name: toolCall.name,
arguments: JSON.stringify(toolCall.args)
}
}));
}
}

// The following types are used to match a role to its specific content, otherwise TypeScript would not be able to infer the content type.

type Role = 'system' | 'user' | 'assistant' | 'tool' | 'function';
/**
* Maps LangChain's {@link AIMessage} to Azure OpenAI's {@link AzureOpenAiChatCompletionRequestAssistantMessage}.
* @param message - The {@link AIMessage} to map.
* @returns The Azure OpenAI {@link AzureOpenAiChatCompletionRequestAssistantMessage}.
*/
function mapAiMessageToAzureOpenAiAssistantMessage(
message: AIMessage
): AzureOpenAiChatCompletionRequestAssistantMessage {
return {
name: message.name,
tool_calls:
mapLangchainToolCallToAzureOpenAiToolCall(message.tool_calls) ??
message.additional_kwargs.tool_calls,
function_call: message.additional_kwargs.function_call,
content:
message.content as AzureOpenAiChatCompletionRequestAssistantMessage['content'],
role: 'assistant'
};
}

type ContentType<T extends Role> = T extends 'system'
? AzureOpenAiChatCompletionRequestSystemMessage['content']
: T extends 'user'
? AzureOpenAiChatCompletionRequestUserMessage['content']
: T extends 'assistant'
? AzureOpenAiChatCompletionRequestAssistantMessage['content']
: T extends 'tool'
? AzureOpenAiChatCompletionRequestToolMessage['content']
: T extends 'function'
? AzureOpenAiChatCompletionRequestFunctionMessage['content']
: never;
function mapHumanMessageToAzureOpenAiUserMessage(
message: HumanMessage
): AzureOpenAiChatCompletionRequestUserMessage {
return {
role: 'user',
content:
message.content as AzureOpenAiChatCompletionRequestUserMessage['content'],
name: message.name
};
}

type RoleAndContent = {
[T in Role]: { role: T; content: ContentType<T> };
}[Role];
function mapToolMessageToAzureOpenAiToolMessage(
message: ToolMessage
): AzureOpenAiChatCompletionRequestToolMessage {
return {
role: 'tool',
content:
message.content as AzureOpenAiChatCompletionRequestToolMessage['content'],
tool_call_id: message.tool_call_id
};
}

function mapRoleAndContent(baseMessage: BaseMessage): RoleAndContent {
const role = mapBaseMessageToRole(baseMessage);
function mapFunctionMessageToAzureOpenAiFunctionMessage(
message: FunctionMessage
): AzureOpenAiChatCompletionRequestFunctionMessage {
return {
role,
content: baseMessage.content as ContentType<typeof role>
} as RoleAndContent;
role: 'function',
content:
message.content as AzureOpenAiChatCompletionRequestFunctionMessage['content'],
name: message.name || 'default'
};
}

function isStructuredToolArray(tools?: unknown[]): tools is StructuredTool[] {
return !!tools?.every(tool =>
Array.isArray((tool as StructuredTool).lc_namespace)
);
function mapSystemMessageToAzureOpenAiSystemMessage(
message: SystemMessage
): AzureOpenAiChatCompletionRequestSystemMessage {
return {
role: 'system',
content:
message.content as AzureOpenAiChatCompletionRequestSystemMessage['content'],
name: message.name
};
}

/**
* Has to return an empty string to match one of the types of {@link AzureOpenAiChatCompletionRequestMessage}.
* @internal
* Maps {@link BaseMessage} to {@link AzureOpenAiChatMessage}.
* @param message - The message to map.
* @returns The {@link AzureOpenAiChatMessage}.
*/
function mapToolCallId(message: BaseMessage): string {
return ToolMessage.isInstance(message) ? message.tool_call_id : '';
}

function mapToolChoice(
toolChoice?: LangChainToolChoice
): ToolChoice | undefined {
if (toolChoice === 'auto' || toolChoice === 'none') {
return toolChoice;
// TODO: Add mapping of refusal property, once LangChain base class supports it natively.
function mapBaseMessageToAzureOpenAiChatMessage(
message: BaseMessage
): AzureOpenAiChatCompletionRequestMessage {
switch (message.getType()) {
case 'ai':
return mapAiMessageToAzureOpenAiAssistantMessage(message);
case 'human':
return mapHumanMessageToAzureOpenAiUserMessage(message);
case 'system':
return mapSystemMessageToAzureOpenAiSystemMessage(message);
case 'function':
return mapFunctionMessageToAzureOpenAiFunctionMessage(message);
case 'tool':
return mapToolMessageToAzureOpenAiToolMessage(message as ToolMessage);
default:
throw new Error(`Unsupported message type: ${message.getType()}`);
}
}

if (typeof toolChoice === 'string') {
return {
type: 'function',
function: { name: toolChoice }
};
}
function isStructuredToolArray(tools?: unknown[]): tools is StructuredTool[] {
return !!tools?.every(tool =>
Array.isArray((tool as StructuredTool).lc_namespace)
);
}

/**
Expand Down Expand Up @@ -252,7 +272,7 @@ export function mapLangchainToAiClient(
tools: isStructuredToolArray(options?.tools)
? options?.tools.map(mapToolToOpenAiTool)
: options?.tools,
tool_choice: mapToolChoice(options?.tool_choice)
tool_choice: options?.tool_choice
});
}

Expand Down
Loading

0 comments on commit ccfa2eb

Please sign in to comment.