From 8ab002612377a405bcee44091023bc3cbd712e9c Mon Sep 17 00:00:00 2001 From: Arpit Gupta <162559421+arpitg-1@users.noreply.github.com> Date: Tue, 13 Aug 2024 15:05:04 +0530 Subject: [PATCH] W-16295334: Added changes required for certification (#24) * W-16295334: Added changes for duration parameterisation * W-16295334: Fixed failing munit test case * W-16295334: Updated munit * W-16295334: Removed model and llm default values * W-16295334: Removed static initialisations from the code * W-16295334: Bug fixes * W-16295334: Added one streaming operation for chat --- .../main/mule/mulechain-ai-connector-demo.xml | 2 +- .../config/LangchainLLMConfiguration.java | 74 ++++++---------- .../util/LangchainLLMInitializerUtil.java | 20 +++-- .../internal/error/MuleChainErrorType.java | 2 +- .../exception/StreamingException.java | 14 +++ .../llm/LangchainLLMModelNameProvider.java | 2 +- .../internal/llm/config/ConfigType.java | 18 +++- .../internal/llm/type/LangchainLLMType.java | 38 +++++++- .../LangchainStreamingOperations.java | 87 +++++++++++++++++++ .../langchain-llm-operation-testing-suite.xml | 3 +- 10 files changed, 197 insertions(+), 63 deletions(-) create mode 100644 src/main/java/org/mule/extension/mulechain/internal/exception/StreamingException.java create mode 100644 src/main/java/org/mule/extension/mulechain/internal/operation/LangchainStreamingOperations.java diff --git a/demo/mulechain-ai-connector-demo/src/main/mule/mulechain-ai-connector-demo.xml b/demo/mulechain-ai-connector-demo/src/main/mule/mulechain-ai-connector-demo.xml index 8dc65d7..777cb0d 100644 --- a/demo/mulechain-ai-connector-demo/src/main/mule/mulechain-ai-connector-demo.xml +++ b/demo/mulechain-ai-connector-demo/src/main/mule/mulechain-ai-connector-demo.xml @@ -10,7 +10,7 @@ http://www.mulesoft.org/schema/mule/ee/core http://www.mulesoft.org/schema/mule/ - + diff --git a/src/main/java/org/mule/extension/mulechain/internal/config/LangchainLLMConfiguration.java b/src/main/java/org/mule/extension/mulechain/internal/config/LangchainLLMConfiguration.java index f9fc82a..d33233e 100644 --- a/src/main/java/org/mule/extension/mulechain/internal/config/LangchainLLMConfiguration.java +++ b/src/main/java/org/mule/extension/mulechain/internal/config/LangchainLLMConfiguration.java @@ -4,19 +4,15 @@ package org.mule.extension.mulechain.internal.config; import dev.langchain4j.model.chat.ChatLanguageModel; -import org.mule.extension.mulechain.internal.exception.config.ConfigValidationException; import org.mule.extension.mulechain.internal.operation.LangchainEmbeddingStoresOperations; import org.mule.extension.mulechain.internal.operation.LangchainImageModelsOperations; import org.mule.extension.mulechain.internal.llm.type.LangchainLLMType; import org.mule.extension.mulechain.internal.llm.ConfigTypeProvider; -import org.mule.extension.mulechain.internal.config.util.LangchainLLMInitializerUtil; import org.mule.extension.mulechain.internal.operation.LangchainLLMOperations; import org.mule.extension.mulechain.internal.llm.LangchainLLMModelNameProvider; import org.mule.extension.mulechain.internal.llm.LangchainLLMTypeProvider; import org.mule.extension.mulechain.internal.llm.config.ConfigExtractor; import org.mule.extension.mulechain.internal.llm.config.ConfigType; -import org.mule.extension.mulechain.internal.llm.config.EnvConfigExtractor; -import org.mule.extension.mulechain.internal.llm.config.FileConfigExtractor; import org.mule.runtime.api.lifecycle.Initialisable; import org.mule.runtime.api.lifecycle.InitialisationException; import org.mule.runtime.api.meta.ExpressionSupport; @@ -27,12 +23,10 @@ import org.mule.runtime.extension.api.annotation.param.Parameter; import org.mule.runtime.extension.api.annotation.param.display.DisplayName; import org.mule.runtime.extension.api.annotation.param.display.Placement; +import org.mule.runtime.extension.api.annotation.param.display.Summary; import org.mule.runtime.extension.api.annotation.values.OfValues; -import java.util.HashMap; -import java.util.Map; -import java.util.function.BiFunction; -import java.util.function.Function; +import java.util.concurrent.TimeUnit; /** * This class represents an extension configuration, values set in this class are commonly used across multiple @@ -42,25 +36,9 @@ @Operations({LangchainLLMOperations.class, LangchainEmbeddingStoresOperations.class, LangchainImageModelsOperations.class}) public class LangchainLLMConfiguration implements Initialisable { - private static final Map> llmMap; - private static final Map> configExtractorMap; - - static { - configExtractorMap = new HashMap<>(); - configExtractorMap.put(ConfigType.ENV_VARIABLE, (configuration) -> new EnvConfigExtractor()); - configExtractorMap.put(ConfigType.CONFIG_JSON, FileConfigExtractor::new); - - llmMap = new HashMap<>(); - llmMap.put(LangchainLLMType.OPENAI, (LangchainLLMInitializerUtil::createOpenAiChatModel)); - llmMap.put(LangchainLLMType.GROQAI_OPENAI, (LangchainLLMInitializerUtil::createGroqOpenAiChatModel)); - llmMap.put(LangchainLLMType.MISTRAL_AI, (LangchainLLMInitializerUtil::createMistralAiChatModel)); - llmMap.put(LangchainLLMType.OLLAMA, (LangchainLLMInitializerUtil::createOllamaChatModel)); - llmMap.put(LangchainLLMType.ANTHROPIC, (LangchainLLMInitializerUtil::createAnthropicChatModel)); - llmMap.put(LangchainLLMType.AZURE_OPENAI, (LangchainLLMInitializerUtil::createAzureOpenAiChatModel)); - } - @Parameter @Placement(order = 1, tab = Placement.DEFAULT_TAB) + @DisplayName("LLM type") @OfValues(LangchainLLMTypeProvider.class) private String llmType; @@ -71,28 +49,35 @@ public class LangchainLLMConfiguration implements Initialisable { @Parameter @Placement(order = 3, tab = Placement.DEFAULT_TAB) + @Optional(defaultValue = "#[-]") private String filePath; @Parameter @Expression(ExpressionSupport.SUPPORTED) @OfValues(LangchainLLMModelNameProvider.class) - @Optional(defaultValue = "gpt-3.5-turbo") - @Placement(order = 4) - private String modelName = "gpt-3.5-turbo"; + @Placement(order = 4, tab = Placement.DEFAULT_TAB) + private String modelName; @Parameter - @Placement(order = 5) + @Placement(order = 5, tab = Placement.DEFAULT_TAB) @Optional(defaultValue = "0.7") private double temperature = 0.7; @Parameter - @Placement(order = 6) + @Placement(order = 6, tab = Placement.DEFAULT_TAB) @Optional(defaultValue = "60") - @DisplayName("Duration in sec") - private long durationInSeconds = 60; + @DisplayName("LLM timeout") + private int llmTimeout = 60; + + @Parameter + @Optional(defaultValue = "SECONDS") + @Placement(order = 7, tab = Placement.DEFAULT_TAB) + @DisplayName("LLM timeout unit") + @Summary("Time unit to be used in the LLM Timeout") + private TimeUnit llmTimeoutUnit = TimeUnit.SECONDS; @Parameter - @Placement(order = 7) + @Placement(order = 8, tab = Placement.DEFAULT_TAB) @Expression(ExpressionSupport.SUPPORTED) @Optional(defaultValue = "500") private int maxTokens = 500; @@ -121,8 +106,12 @@ public double getTemperature() { return temperature; } - public long getDurationInSeconds() { - return durationInSeconds; + public int getLlmTimeout() { + return llmTimeout; + } + + public TimeUnit getLlmTimeoutUnit() { + return llmTimeoutUnit; } public int getMaxTokens() { @@ -138,21 +127,14 @@ public ChatLanguageModel getModel() { } private ChatLanguageModel createModel(ConfigExtractor configExtractor) { - LangchainLLMType type = LangchainLLMType.valueOf(llmType); - if (llmMap.containsKey(type)) { - return llmMap.get(type).apply(configExtractor, this); - } - throw new ConfigValidationException("LLM Type not supported: " + llmType); + LangchainLLMType type = LangchainLLMType.fromValue(llmType); + return type.getConfigBiFunction().apply(configExtractor, this); } @Override public void initialise() throws InitialisationException { ConfigType config = ConfigType.fromValue(configType); - if (configExtractorMap.containsKey(config)) { - configExtractor = configExtractorMap.get(config).apply(this); - model = createModel(configExtractor); - } else { - throw new ConfigValidationException("Config Type not supported: " + configType); - } + configExtractor = config.getConfigExtractorFunction().apply(this); + model = createModel(configExtractor); } } diff --git a/src/main/java/org/mule/extension/mulechain/internal/config/util/LangchainLLMInitializerUtil.java b/src/main/java/org/mule/extension/mulechain/internal/config/util/LangchainLLMInitializerUtil.java index 6fefda6..d20e786 100644 --- a/src/main/java/org/mule/extension/mulechain/internal/config/util/LangchainLLMInitializerUtil.java +++ b/src/main/java/org/mule/extension/mulechain/internal/config/util/LangchainLLMInitializerUtil.java @@ -11,6 +11,8 @@ import org.mule.extension.mulechain.internal.config.LangchainLLMConfiguration; import org.mule.extension.mulechain.internal.llm.config.ConfigExtractor; +import java.time.Duration; + import static java.time.Duration.ofSeconds; public final class LangchainLLMInitializerUtil { @@ -19,12 +21,13 @@ private LangchainLLMInitializerUtil() {} public static OpenAiChatModel createOpenAiChatModel(ConfigExtractor configExtractor, LangchainLLMConfiguration configuration) { String openaiApiKey = configExtractor.extractValue("OPENAI_API_KEY"); + long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout()); return OpenAiChatModel.builder() .apiKey(openaiApiKey) .modelName(configuration.getModelName()) .maxTokens(configuration.getMaxTokens()) .temperature(configuration.getTemperature()) - .timeout(ofSeconds(configuration.getDurationInSeconds())) + .timeout(ofSeconds(durationInSec)) .logRequests(true) .logResponses(true) .build(); @@ -34,13 +37,14 @@ public static OpenAiChatModel createOpenAiChatModel(ConfigExtractor configExtrac public static OpenAiChatModel createGroqOpenAiChatModel(ConfigExtractor configExtractor, LangchainLLMConfiguration configuration) { String groqApiKey = configExtractor.extractValue("GROQ_API_KEY"); + long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout()); return OpenAiChatModel.builder() .baseUrl("https://api.groq.com/openai/v1") .apiKey(groqApiKey) .modelName(configuration.getModelName()) .maxTokens(configuration.getMaxTokens()) .temperature(configuration.getTemperature()) - .timeout(ofSeconds(configuration.getDurationInSeconds())) + .timeout(ofSeconds(durationInSec)) .logRequests(true) .logResponses(true) .build(); @@ -51,13 +55,14 @@ public static OpenAiChatModel createGroqOpenAiChatModel(ConfigExtractor configEx public static MistralAiChatModel createMistralAiChatModel(ConfigExtractor configExtractor, LangchainLLMConfiguration configuration) { String mistralAiApiKey = configExtractor.extractValue("MISTRAL_AI_API_KEY"); + long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout()); return MistralAiChatModel.builder() //.apiKey(configuration.getLlmApiKey()) .apiKey(mistralAiApiKey) .modelName(configuration.getModelName()) .maxTokens(configuration.getMaxTokens()) .temperature(configuration.getTemperature()) - .timeout(ofSeconds(configuration.getDurationInSeconds())) + .timeout(ofSeconds(durationInSec)) .logRequests(true) .logResponses(true) .build(); @@ -65,12 +70,13 @@ public static MistralAiChatModel createMistralAiChatModel(ConfigExtractor config public static OllamaChatModel createOllamaChatModel(ConfigExtractor configExtractor, LangchainLLMConfiguration configuration) { String ollamaBaseUrl = configExtractor.extractValue("OLLAMA_BASE_URL"); + long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout()); return OllamaChatModel.builder() //.baseUrl(configuration.getLlmApiKey()) .baseUrl(ollamaBaseUrl) .modelName(configuration.getModelName()) .temperature(configuration.getTemperature()) - .timeout(ofSeconds(configuration.getDurationInSeconds())) + .timeout(ofSeconds(durationInSec)) .build(); } @@ -78,13 +84,14 @@ public static OllamaChatModel createOllamaChatModel(ConfigExtractor configExtrac public static AnthropicChatModel createAnthropicChatModel(ConfigExtractor configExtractor, LangchainLLMConfiguration configuration) { String anthropicApiKey = configExtractor.extractValue("ANTHROPIC_API_KEY"); + long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout()); return AnthropicChatModel.builder() //.apiKey(configuration.getLlmApiKey()) .apiKey(anthropicApiKey) .modelName(configuration.getModelName()) .maxTokens(configuration.getMaxTokens()) .temperature(configuration.getTemperature()) - .timeout(ofSeconds(configuration.getDurationInSeconds())) + .timeout(ofSeconds(durationInSec)) .logRequests(true) .logResponses(true) .build(); @@ -96,13 +103,14 @@ public static AzureOpenAiChatModel createAzureOpenAiChatModel(ConfigExtractor co String azureOpenaiKey = configExtractor.extractValue("AZURE_OPENAI_KEY"); String azureOpenaiEndpoint = configExtractor.extractValue("AZURE_OPENAI_ENDPOINT"); String azureOpenaiDeploymentName = configExtractor.extractValue("AZURE_OPENAI_DEPLOYMENT_NAME"); + long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout()); return AzureOpenAiChatModel.builder() .apiKey(azureOpenaiKey) .endpoint(azureOpenaiEndpoint) .deploymentName(azureOpenaiDeploymentName) .maxTokens(configuration.getMaxTokens()) .temperature(configuration.getTemperature()) - .timeout(ofSeconds(configuration.getDurationInSeconds())) + .timeout(ofSeconds(durationInSec)) .logRequestsAndResponses(true) .build(); } diff --git a/src/main/java/org/mule/extension/mulechain/internal/error/MuleChainErrorType.java b/src/main/java/org/mule/extension/mulechain/internal/error/MuleChainErrorType.java index 716f206..9be753f 100644 --- a/src/main/java/org/mule/extension/mulechain/internal/error/MuleChainErrorType.java +++ b/src/main/java/org/mule/extension/mulechain/internal/error/MuleChainErrorType.java @@ -7,5 +7,5 @@ public enum MuleChainErrorType implements ErrorTypeDefinition { - AI_SERVICES_FAILURE, IMAGE_ANALYSIS_FAILURE, IMAGE_GENERATION_FAILURE, IMAGE_PROCESSING_FAILURE, FILE_HANDLING_FAILURE, RAG_FAILURE, EMBEDDING_OPERATIONS_FAILURE, TOOLS_OPERATION_FAILURE, VALIDATION_FAILURE + AI_SERVICES_FAILURE, IMAGE_ANALYSIS_FAILURE, IMAGE_GENERATION_FAILURE, IMAGE_PROCESSING_FAILURE, FILE_HANDLING_FAILURE, RAG_FAILURE, EMBEDDING_OPERATIONS_FAILURE, TOOLS_OPERATION_FAILURE, VALIDATION_FAILURE, STREAMING_FAILURE } diff --git a/src/main/java/org/mule/extension/mulechain/internal/exception/StreamingException.java b/src/main/java/org/mule/extension/mulechain/internal/exception/StreamingException.java new file mode 100644 index 0000000..533b5fa --- /dev/null +++ b/src/main/java/org/mule/extension/mulechain/internal/exception/StreamingException.java @@ -0,0 +1,14 @@ +/** + * (c) 2003-2024 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file. + */ +package org.mule.extension.mulechain.internal.exception; + +import org.mule.extension.mulechain.internal.error.MuleChainErrorType; +import org.mule.runtime.extension.api.exception.ModuleException; + +public class StreamingException extends ModuleException { + + public StreamingException(String message, Throwable throwable) { + super(message, MuleChainErrorType.STREAMING_FAILURE, throwable); + } +} diff --git a/src/main/java/org/mule/extension/mulechain/internal/llm/LangchainLLMModelNameProvider.java b/src/main/java/org/mule/extension/mulechain/internal/llm/LangchainLLMModelNameProvider.java index 262675b..1a57a07 100644 --- a/src/main/java/org/mule/extension/mulechain/internal/llm/LangchainLLMModelNameProvider.java +++ b/src/main/java/org/mule/extension/mulechain/internal/llm/LangchainLLMModelNameProvider.java @@ -21,7 +21,7 @@ public class LangchainLLMModelNameProvider implements ValueProvider { @Override public Set resolve() throws ValueResolvingException { - return ValueBuilder.getValuesFor(LangchainLLMType.valueOf(llmType).getModelNameStream()); + return ValueBuilder.getValuesFor(LangchainLLMType.fromValue(llmType).getModelNameStream()); } } diff --git a/src/main/java/org/mule/extension/mulechain/internal/llm/config/ConfigType.java b/src/main/java/org/mule/extension/mulechain/internal/llm/config/ConfigType.java index b454be8..0689466 100644 --- a/src/main/java/org/mule/extension/mulechain/internal/llm/config/ConfigType.java +++ b/src/main/java/org/mule/extension/mulechain/internal/llm/config/ConfigType.java @@ -3,26 +3,38 @@ */ package org.mule.extension.mulechain.internal.llm.config; +import org.mule.extension.mulechain.internal.config.LangchainLLMConfiguration; +import org.mule.extension.mulechain.internal.exception.config.ConfigValidationException; + import java.util.Arrays; +import java.util.function.Function; public enum ConfigType { - ENV_VARIABLE("Environment Variables"), CONFIG_JSON("Configuration Json"); + ENV_VARIABLE("Environment Variables", (configuration) -> new EnvConfigExtractor()), CONFIG_JSON("Configuration Json", + FileConfigExtractor::new); private final String value; - ConfigType(String value) { + private final Function configExtractorFunction; + + ConfigType(String value, Function configExtractorFunction) { this.value = value; + this.configExtractorFunction = configExtractorFunction; } public static ConfigType fromValue(String value) { return Arrays.stream(ConfigType.values()) .filter(configType -> configType.value.equals(value)) .findFirst() - .orElseThrow(() -> new IllegalArgumentException("Unsupported Config Type: " + value)); + .orElseThrow(() -> new ConfigValidationException("Unsupported Config Type: " + value)); } public String getValue() { return value; } + + public Function getConfigExtractorFunction() { + return configExtractorFunction; + } } diff --git a/src/main/java/org/mule/extension/mulechain/internal/llm/type/LangchainLLMType.java b/src/main/java/org/mule/extension/mulechain/internal/llm/type/LangchainLLMType.java index b34d9ea..37ebb7d 100644 --- a/src/main/java/org/mule/extension/mulechain/internal/llm/type/LangchainLLMType.java +++ b/src/main/java/org/mule/extension/mulechain/internal/llm/type/LangchainLLMType.java @@ -4,28 +4,51 @@ package org.mule.extension.mulechain.internal.llm.type; import dev.langchain4j.model.anthropic.AnthropicChatModelName; +import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.mistralai.MistralAiChatModelName; import dev.langchain4j.model.openai.OpenAiChatModelName; import dev.langchain4j.model.openai.OpenAiImageModelName; +import org.mule.extension.mulechain.internal.config.LangchainLLMConfiguration; +import org.mule.extension.mulechain.internal.config.util.LangchainLLMInitializerUtil; +import org.mule.extension.mulechain.internal.exception.config.ConfigValidationException; +import org.mule.extension.mulechain.internal.llm.config.ConfigExtractor; import java.util.Arrays; +import java.util.function.BiFunction; import java.util.stream.Stream; public enum LangchainLLMType { - OPENAI(getOpenAIModelNameStream()), GROQAI_OPENAI(OPENAI.getModelNameStream()), MISTRAL_AI( - getMistralAIModelNameStream()), OLLAMA( - getOllamaModelNameStream()), ANTHROPIC(getAnthropicModelNameStream()), AZURE_OPENAI(OPENAI.getModelNameStream()); + OPENAI("OPENAI", getOpenAIModelNameStream(), LangchainLLMInitializerUtil::createOpenAiChatModel), GROQAI_OPENAI("GROQAI_OPENAI", + OPENAI.getModelNameStream(), LangchainLLMInitializerUtil::createGroqOpenAiChatModel), MISTRAL_AI("MISTRAL_AI", + getMistralAIModelNameStream(), LangchainLLMInitializerUtil::createMistralAiChatModel), OLLAMA("OLLAMA", + getOllamaModelNameStream(), LangchainLLMInitializerUtil::createOllamaChatModel), ANTHROPIC("ANTHROPIC", + getAnthropicModelNameStream(), LangchainLLMInitializerUtil::createAnthropicChatModel), AZURE_OPENAI( + "AZURE_OPENAI", OPENAI.getModelNameStream(), LangchainLLMInitializerUtil::createAzureOpenAiChatModel); + private final String value; private final Stream modelNameStream; - LangchainLLMType(Stream modelNameStream) { + private final BiFunction configBiFunction; + + LangchainLLMType(String value, Stream modelNameStream, + BiFunction configBiFunction) { + this.value = value; this.modelNameStream = modelNameStream; + this.configBiFunction = configBiFunction; + } + + public String getValue() { + return value; } public Stream getModelNameStream() { return modelNameStream; } + public BiFunction getConfigBiFunction() { + return configBiFunction; + } + private static Stream getOpenAIModelNameStream() { return Stream.concat(Arrays.stream(OpenAiChatModelName.values()), Arrays.stream(OpenAiImageModelName.values())) .map(String::valueOf); @@ -43,6 +66,13 @@ private static Stream getAnthropicModelNameStream() { return Arrays.stream(AnthropicChatModelName.values()).map(String::valueOf); } + public static LangchainLLMType fromValue(String value) { + return Arrays.stream(LangchainLLMType.values()) + .filter(langchainLLMType -> langchainLLMType.value.equals(value)) + .findFirst() + .orElseThrow(() -> new ConfigValidationException("Unsupported LLM Type: " + value)); + } + enum OllamaModelName { MISTRAL("mistral"), PHI3("phi3"), ORCA_MINI("orca-mini"), LLAMA2("llama2"), CODE_LLAMA("codellama"), TINY_LLAMA("tinyllama"); diff --git a/src/main/java/org/mule/extension/mulechain/internal/operation/LangchainStreamingOperations.java b/src/main/java/org/mule/extension/mulechain/internal/operation/LangchainStreamingOperations.java new file mode 100644 index 0000000..720be5a --- /dev/null +++ b/src/main/java/org/mule/extension/mulechain/internal/operation/LangchainStreamingOperations.java @@ -0,0 +1,87 @@ +/** + * (c) 2003-2024 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file. + */ +package org.mule.extension.mulechain.internal.operation; + +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.openai.OpenAiStreamingChatModel; +import dev.langchain4j.service.AiServices; +import dev.langchain4j.service.TokenStream; +import org.mule.extension.mulechain.internal.config.LangchainLLMConfiguration; +import org.mule.extension.mulechain.internal.error.provider.AiServiceErrorTypeProvider; +import org.mule.extension.mulechain.internal.exception.ChatException; +import org.mule.extension.mulechain.internal.exception.StreamingException; +import org.mule.runtime.extension.api.annotation.Alias; +import org.mule.runtime.extension.api.annotation.error.Throws; +import org.mule.runtime.extension.api.annotation.param.Config; +import org.mule.runtime.extension.api.annotation.param.MediaType; + +import java.io.IOException; +import java.io.InputStream; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; +import java.nio.charset.StandardCharsets; + +import static java.time.Duration.ofSeconds; +import static org.mule.runtime.extension.api.annotation.param.MediaType.ANY; + +/** + * This class is a container for operations, every public method in this class will be taken as an extension operation. + */ +public class LangchainStreamingOperations { + + interface Assistant { + + TokenStream chat(String userMessage); + } + + /** + * Implements a simple Chat agent + */ + @MediaType(value = ANY, strict = false) + @Alias("CHAT-answer-prompt-w-stream") + @Throws(AiServiceErrorTypeProvider.class) + public InputStream answerPromptByModelNameStream(@Config LangchainLLMConfiguration configuration, String prompt) { + String openaiApiKey = configuration.getConfigExtractor().extractValue("OPENAI_API_KEY"); + long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout()); + try { + StreamingChatLanguageModel model = OpenAiStreamingChatModel.builder() + .apiKey(openaiApiKey) + .modelName(configuration.getModelName()) + .maxTokens(configuration.getMaxTokens()) + .temperature(configuration.getTemperature()) + .timeout(ofSeconds(durationInSec)) + .build(); + Assistant assistant = AiServices.create(Assistant.class, model); + TokenStream tokenStream = assistant.chat(prompt); + + PipedOutputStream pipedOutputStream = new PipedOutputStream(); + PipedInputStream pipedInputStream = new PipedInputStream(pipedOutputStream); + + tokenStream.onNext(value -> { + try { + pipedOutputStream.write(value.getBytes(StandardCharsets.UTF_8)); + } catch (IOException e) { + throw new StreamingException("Error occurred while streaming output", e); + } + }) + .onComplete(response -> { + try { + pipedOutputStream.close(); + } catch (IOException e) { + throw new StreamingException("Error occurred while closing the stream", e); + } + }) + .onError(throwable -> { + throw new StreamingException("Exception occurred onError()", throwable); + }) + .start(); + return pipedInputStream; + } catch (Exception e) { + throw new ChatException("Unable to respond with the chat provided", e); + } + } + + + +} diff --git a/src/test/munit/langchain-llm-operation-testing-suite.xml b/src/test/munit/langchain-llm-operation-testing-suite.xml index f0181e2..4aced6b 100644 --- a/src/test/munit/langchain-llm-operation-testing-suite.xml +++ b/src/test/munit/langchain-llm-operation-testing-suite.xml @@ -12,6 +12,7 @@ @@ -23,7 +24,7 @@ config-ref="OPENAI" prompt="#[payload.data]"> - +