Skip to content

Commit

Permalink
fix: Fix langchain types (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomfrenken authored Sep 23, 2024
1 parent 8798913 commit 8cda2de
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 24 deletions.
5 changes: 5 additions & 0 deletions .changeset/unlucky-rivers-work.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@sap-ai-sdk/langchain': patch
---

[Fixed Issue] Fix LangChain types for proper IDE auto completion.
10 changes: 4 additions & 6 deletions packages/langchain/src/openai/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ export class AzureOpenAiChatClient
modelName: AzureOpenAiChatModel;
modelVersion?: string;
resourceGroup?: string;
temperature?: number;
top_p?: number;
logit_bias?: Record<string, unknown>;
temperature?: number | null;
top_p?: number | null;
logit_bias?: Record<string, any> | null | undefined;
user?: string;
n?: number;
presence_penalty?: number;
frequency_penalty?: number;
stop?: string | string[];
Expand All @@ -41,7 +40,6 @@ export class AzureOpenAiChatClient
this.top_p = fields.top_p;
this.logit_bias = fields.logit_bias;
this.user = fields.user;
this.n = fields.n;
this.stop = fields.stop;
this.presence_penalty = fields.presence_penalty;
this.frequency_penalty = fields.frequency_penalty;
Expand All @@ -63,7 +61,7 @@ export class AzureOpenAiChatClient
},
() =>
this.openAiChatClient.run(
mapLangchainToAiClient(this, options, messages),
mapLangchainToAiClient(this, messages, options),
options.requestConfig
)
);
Expand Down
32 changes: 22 additions & 10 deletions packages/langchain/src/openai/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,25 @@ import type {
import { BaseLLMParams } from '@langchain/core/language_models/llms';
import type {
AzureOpenAiCreateChatCompletionRequest,
AzureOpenAiChatModel
AzureOpenAiChatModel,
AzureOpenAiChatCompletionsRequestCommon
} from '@sap-ai-sdk/foundation-models';
import type { CustomRequestConfig } from '@sap-ai-sdk/core';
import type { ModelConfig, ResourceGroupConfig } from '@sap-ai-sdk/ai-api';

/**
* Input type for {@link AzureOpenAiChatClient} initialization.
*/
export type AzureOpenAiChatModelParams = Omit<
AzureOpenAiCreateChatCompletionRequest,
| 'messages'
| 'response_format'
| 'seed'
| 'functions'
| 'tools'
| 'tool_choice'
export type AzureOpenAiChatModelParams = Pick<
AzureOpenAiChatCompletionsRequestCommon,
| 'temperature'
| 'top_p'
| 'stop'
| 'max_tokens'
| 'presence_penalty'
| 'frequency_penalty'
| 'logit_bias'
| 'user'
> &
BaseChatModelParams &
ModelConfig<AzureOpenAiChatModel> &
Expand All @@ -32,7 +35,16 @@ export type AzureOpenAiChatModelParams = Omit<
export type AzureOpenAiChatCallOptions = BaseChatModelCallOptions &
Pick<
AzureOpenAiCreateChatCompletionRequest,
'response_format' | 'seed' | 'functions' | 'tools' | 'tool_choice'
| 'data_sources'
| 'n'
| 'seed'
| 'logprobs'
| 'top_logprobs'
| 'response_format'
| 'tools'
| 'tool_choice'
| 'functions'
| 'function_call'
> & {
requestConfig?: CustomRequestConfig;
};
Expand Down
4 changes: 2 additions & 2 deletions packages/langchain/src/openai/util.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ describe('Mapping Functions', () => {
const defaultOptions = { signal: undefined, promptIndex: 0 };
const mapping = mapLangchainToAiClient(
client,
defaultOptions,
langchainPrompt
langchainPrompt,
defaultOptions
);
expect(mapping).toMatchObject(request);
});
Expand Down
19 changes: 13 additions & 6 deletions packages/langchain/src/openai/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -230,26 +230,33 @@ function mapToolChoice(
*/
export function mapLangchainToAiClient(
client: AzureOpenAiChatClient,
options: AzureOpenAiChatCallOptions & { promptIndex?: number },
messages: BaseMessage[]
messages: BaseMessage[],
options?: AzureOpenAiChatCallOptions & { promptIndex?: number }
): AzureOpenAiCreateChatCompletionRequest {
return removeUndefinedProperties<AzureOpenAiCreateChatCompletionRequest>({
messages: messages.map(mapBaseMessageToAzureOpenAiChatMessage),
max_tokens: client.max_tokens === -1 ? undefined : client.max_tokens,
presence_penalty: client.presence_penalty,
frequency_penalty: client.frequency_penalty,
temperature: client.temperature,
top_p: client.top_p,
logit_bias: client.logit_bias,
n: client.n,
user: client.user,
data_sources: options?.data_sources,
n: options?.n,
response_format: options?.response_format,
seed: options?.seed,
logprobs: options?.logprobs,
top_logprobs: options?.top_logprobs,
function_call: options?.function_call,
stop: options?.stop ?? client.stop,
functions: isStructuredToolArray(options?.functions)
? options?.functions.map(mapToolToOpenAiFunction)
: options?.functions,
tools: isStructuredToolArray(options?.tools)
? options?.tools.map(mapToolToOpenAiTool)
: options?.tools,
tool_choice: mapToolChoice(options?.tool_choice),
response_format: options?.response_format,
seed: options?.seed
tool_choice: mapToolChoice(options?.tool_choice)
});
}

Expand Down

0 comments on commit 8cda2de

Please sign in to comment.