From 8ab002612377a405bcee44091023bc3cbd712e9c Mon Sep 17 00:00:00 2001
From: Arpit Gupta <162559421+arpitg-1@users.noreply.github.com>
Date: Tue, 13 Aug 2024 15:05:04 +0530
Subject: [PATCH] W-16295334: Added changes required for certification (#24)
* W-16295334: Added changes for duration parameterisation
* W-16295334: Fixed failing munit test case
* W-16295334: Updated munit
* W-16295334: Removed model and llm default values
* W-16295334: Removed static initialisations from the code
* W-16295334: Bug fixes
* W-16295334: Added one streaming operation for chat
---
.../main/mule/mulechain-ai-connector-demo.xml | 2 +-
.../config/LangchainLLMConfiguration.java | 74 ++++++----------
.../util/LangchainLLMInitializerUtil.java | 20 +++--
.../internal/error/MuleChainErrorType.java | 2 +-
.../exception/StreamingException.java | 14 +++
.../llm/LangchainLLMModelNameProvider.java | 2 +-
.../internal/llm/config/ConfigType.java | 18 +++-
.../internal/llm/type/LangchainLLMType.java | 38 +++++++-
.../LangchainStreamingOperations.java | 87 +++++++++++++++++++
.../langchain-llm-operation-testing-suite.xml | 3 +-
10 files changed, 197 insertions(+), 63 deletions(-)
create mode 100644 src/main/java/org/mule/extension/mulechain/internal/exception/StreamingException.java
create mode 100644 src/main/java/org/mule/extension/mulechain/internal/operation/LangchainStreamingOperations.java
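The central behavioural change in this patch is the duration parameterisation: the fixed `durationInSeconds` parameter is replaced by an `llmTimeout` value paired with a `java.util.concurrent.TimeUnit`, which the model initializers convert back to seconds. A minimal sketch of that conversion, using illustrative values only (nothing below is taken from the connector itself):

    import java.time.Duration;
    import java.util.concurrent.TimeUnit;

    public class TimeoutConversionSketch {

      public static void main(String[] args) {
        // Hypothetical user input: 2 MINUTES instead of the old "Duration in sec".
        int llmTimeout = 2;
        TimeUnit llmTimeoutUnit = TimeUnit.MINUTES;

        // The same conversion the patch applies before handing the value to langchain4j.
        long durationInSec = llmTimeoutUnit.toSeconds(llmTimeout);
        System.out.println(Duration.ofSeconds(durationInSec)); // prints PT2M (120 seconds)
      }
    }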
diff --git a/demo/mulechain-ai-connector-demo/src/main/mule/mulechain-ai-connector-demo.xml b/demo/mulechain-ai-connector-demo/src/main/mule/mulechain-ai-connector-demo.xml
index 8dc65d7..777cb0d 100644
--- a/demo/mulechain-ai-connector-demo/src/main/mule/mulechain-ai-connector-demo.xml
+++ b/demo/mulechain-ai-connector-demo/src/main/mule/mulechain-ai-connector-demo.xml
@@ -10,7 +10,7 @@ http://www.mulesoft.org/schema/mule/ee/core http://www.mulesoft.org/schema/mule/
-
+
diff --git a/src/main/java/org/mule/extension/mulechain/internal/config/LangchainLLMConfiguration.java b/src/main/java/org/mule/extension/mulechain/internal/config/LangchainLLMConfiguration.java
index f9fc82a..d33233e 100644
--- a/src/main/java/org/mule/extension/mulechain/internal/config/LangchainLLMConfiguration.java
+++ b/src/main/java/org/mule/extension/mulechain/internal/config/LangchainLLMConfiguration.java
@@ -4,19 +4,15 @@
package org.mule.extension.mulechain.internal.config;
import dev.langchain4j.model.chat.ChatLanguageModel;
-import org.mule.extension.mulechain.internal.exception.config.ConfigValidationException;
import org.mule.extension.mulechain.internal.operation.LangchainEmbeddingStoresOperations;
import org.mule.extension.mulechain.internal.operation.LangchainImageModelsOperations;
import org.mule.extension.mulechain.internal.llm.type.LangchainLLMType;
import org.mule.extension.mulechain.internal.llm.ConfigTypeProvider;
-import org.mule.extension.mulechain.internal.config.util.LangchainLLMInitializerUtil;
import org.mule.extension.mulechain.internal.operation.LangchainLLMOperations;
import org.mule.extension.mulechain.internal.llm.LangchainLLMModelNameProvider;
import org.mule.extension.mulechain.internal.llm.LangchainLLMTypeProvider;
import org.mule.extension.mulechain.internal.llm.config.ConfigExtractor;
import org.mule.extension.mulechain.internal.llm.config.ConfigType;
-import org.mule.extension.mulechain.internal.llm.config.EnvConfigExtractor;
-import org.mule.extension.mulechain.internal.llm.config.FileConfigExtractor;
import org.mule.runtime.api.lifecycle.Initialisable;
import org.mule.runtime.api.lifecycle.InitialisationException;
import org.mule.runtime.api.meta.ExpressionSupport;
@@ -27,12 +23,10 @@
import org.mule.runtime.extension.api.annotation.param.Parameter;
import org.mule.runtime.extension.api.annotation.param.display.DisplayName;
import org.mule.runtime.extension.api.annotation.param.display.Placement;
+import org.mule.runtime.extension.api.annotation.param.display.Summary;
import org.mule.runtime.extension.api.annotation.values.OfValues;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.function.BiFunction;
-import java.util.function.Function;
+import java.util.concurrent.TimeUnit;
/**
* This class represents an extension configuration, values set in this class are commonly used across multiple
@@ -42,25 +36,9 @@
@Operations({LangchainLLMOperations.class, LangchainEmbeddingStoresOperations.class, LangchainImageModelsOperations.class})
public class LangchainLLMConfiguration implements Initialisable {
- private static final Map<LangchainLLMType, BiFunction<ConfigExtractor, LangchainLLMConfiguration, ChatLanguageModel>> llmMap;
- private static final Map<ConfigType, Function<LangchainLLMConfiguration, ConfigExtractor>> configExtractorMap;
-
- static {
- configExtractorMap = new HashMap<>();
- configExtractorMap.put(ConfigType.ENV_VARIABLE, (configuration) -> new EnvConfigExtractor());
- configExtractorMap.put(ConfigType.CONFIG_JSON, FileConfigExtractor::new);
-
- llmMap = new HashMap<>();
- llmMap.put(LangchainLLMType.OPENAI, (LangchainLLMInitializerUtil::createOpenAiChatModel));
- llmMap.put(LangchainLLMType.GROQAI_OPENAI, (LangchainLLMInitializerUtil::createGroqOpenAiChatModel));
- llmMap.put(LangchainLLMType.MISTRAL_AI, (LangchainLLMInitializerUtil::createMistralAiChatModel));
- llmMap.put(LangchainLLMType.OLLAMA, (LangchainLLMInitializerUtil::createOllamaChatModel));
- llmMap.put(LangchainLLMType.ANTHROPIC, (LangchainLLMInitializerUtil::createAnthropicChatModel));
- llmMap.put(LangchainLLMType.AZURE_OPENAI, (LangchainLLMInitializerUtil::createAzureOpenAiChatModel));
- }
-
@Parameter
@Placement(order = 1, tab = Placement.DEFAULT_TAB)
+ @DisplayName("LLM type")
@OfValues(LangchainLLMTypeProvider.class)
private String llmType;
@@ -71,28 +49,35 @@ public class LangchainLLMConfiguration implements Initialisable {
@Parameter
@Placement(order = 3, tab = Placement.DEFAULT_TAB)
+ @Optional(defaultValue = "#[-]")
private String filePath;
@Parameter
@Expression(ExpressionSupport.SUPPORTED)
@OfValues(LangchainLLMModelNameProvider.class)
- @Optional(defaultValue = "gpt-3.5-turbo")
- @Placement(order = 4)
- private String modelName = "gpt-3.5-turbo";
+ @Placement(order = 4, tab = Placement.DEFAULT_TAB)
+ private String modelName;
@Parameter
- @Placement(order = 5)
+ @Placement(order = 5, tab = Placement.DEFAULT_TAB)
@Optional(defaultValue = "0.7")
private double temperature = 0.7;
@Parameter
- @Placement(order = 6)
+ @Placement(order = 6, tab = Placement.DEFAULT_TAB)
@Optional(defaultValue = "60")
- @DisplayName("Duration in sec")
- private long durationInSeconds = 60;
+ @DisplayName("LLM timeout")
+ private int llmTimeout = 60;
+
+ @Parameter
+ @Optional(defaultValue = "SECONDS")
+ @Placement(order = 7, tab = Placement.DEFAULT_TAB)
+ @DisplayName("LLM timeout unit")
+ @Summary("Time unit to be used in the LLM Timeout")
+ private TimeUnit llmTimeoutUnit = TimeUnit.SECONDS;
@Parameter
- @Placement(order = 7)
+ @Placement(order = 8, tab = Placement.DEFAULT_TAB)
@Expression(ExpressionSupport.SUPPORTED)
@Optional(defaultValue = "500")
private int maxTokens = 500;
@@ -121,8 +106,12 @@ public double getTemperature() {
return temperature;
}
- public long getDurationInSeconds() {
- return durationInSeconds;
+ public int getLlmTimeout() {
+ return llmTimeout;
+ }
+
+ public TimeUnit getLlmTimeoutUnit() {
+ return llmTimeoutUnit;
}
public int getMaxTokens() {
@@ -138,21 +127,14 @@ public ChatLanguageModel getModel() {
}
private ChatLanguageModel createModel(ConfigExtractor configExtractor) {
- LangchainLLMType type = LangchainLLMType.valueOf(llmType);
- if (llmMap.containsKey(type)) {
- return llmMap.get(type).apply(configExtractor, this);
- }
- throw new ConfigValidationException("LLM Type not supported: " + llmType);
+ LangchainLLMType type = LangchainLLMType.fromValue(llmType);
+ return type.getConfigBiFunction().apply(configExtractor, this);
}
@Override
public void initialise() throws InitialisationException {
ConfigType config = ConfigType.fromValue(configType);
- if (configExtractorMap.containsKey(config)) {
- configExtractor = configExtractorMap.get(config).apply(this);
- model = createModel(configExtractor);
- } else {
- throw new ConfigValidationException("Config Type not supported: " + configType);
- }
+ configExtractor = config.getConfigExtractorFunction().apply(this);
+ model = createModel(configExtractor);
}
}
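The refactor above drops the static `llmMap`/`configExtractorMap` lookups; `initialise()` now dispatches through the enums themselves. A minimal sketch of the resulting flow, assuming a `configuration` instance is in scope (the method names are the ones introduced by this patch):

    // Unsupported values now fail inside fromValue() with a ConfigValidationException,
    // replacing the containsKey() checks against the old static maps.
    ConfigExtractor configExtractor = ConfigType.fromValue("Environment Variables")
        .getConfigExtractorFunction()
        .apply(configuration);

    ChatLanguageModel model = LangchainLLMType.fromValue("OPENAI")
        .getConfigBiFunction()
        .apply(configExtractor, configuration);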
diff --git a/src/main/java/org/mule/extension/mulechain/internal/config/util/LangchainLLMInitializerUtil.java b/src/main/java/org/mule/extension/mulechain/internal/config/util/LangchainLLMInitializerUtil.java
index 6fefda6..d20e786 100644
--- a/src/main/java/org/mule/extension/mulechain/internal/config/util/LangchainLLMInitializerUtil.java
+++ b/src/main/java/org/mule/extension/mulechain/internal/config/util/LangchainLLMInitializerUtil.java
@@ -11,6 +11,8 @@
import org.mule.extension.mulechain.internal.config.LangchainLLMConfiguration;
import org.mule.extension.mulechain.internal.llm.config.ConfigExtractor;
+import java.time.Duration;
+
import static java.time.Duration.ofSeconds;
public final class LangchainLLMInitializerUtil {
@@ -19,12 +21,13 @@ private LangchainLLMInitializerUtil() {}
public static OpenAiChatModel createOpenAiChatModel(ConfigExtractor configExtractor, LangchainLLMConfiguration configuration) {
String openaiApiKey = configExtractor.extractValue("OPENAI_API_KEY");
+ long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return OpenAiChatModel.builder()
.apiKey(openaiApiKey)
.modelName(configuration.getModelName())
.maxTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature())
- .timeout(ofSeconds(configuration.getDurationInSeconds()))
+ .timeout(ofSeconds(durationInSec))
.logRequests(true)
.logResponses(true)
.build();
@@ -34,13 +37,14 @@ public static OpenAiChatModel createOpenAiChatModel(ConfigExtractor configExtractor, LangchainLLMConfiguration configuration) {
public static OpenAiChatModel createGroqOpenAiChatModel(ConfigExtractor configExtractor,
LangchainLLMConfiguration configuration) {
String groqApiKey = configExtractor.extractValue("GROQ_API_KEY");
+ long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return OpenAiChatModel.builder()
.baseUrl("https://api.groq.com/openai/v1")
.apiKey(groqApiKey)
.modelName(configuration.getModelName())
.maxTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature())
- .timeout(ofSeconds(configuration.getDurationInSeconds()))
+ .timeout(ofSeconds(durationInSec))
.logRequests(true)
.logResponses(true)
.build();
@@ -51,13 +55,14 @@ public static OpenAiChatModel createGroqOpenAiChatModel(ConfigExtractor configExtractor,
public static MistralAiChatModel createMistralAiChatModel(ConfigExtractor configExtractor,
LangchainLLMConfiguration configuration) {
String mistralAiApiKey = configExtractor.extractValue("MISTRAL_AI_API_KEY");
+ long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return MistralAiChatModel.builder()
//.apiKey(configuration.getLlmApiKey())
.apiKey(mistralAiApiKey)
.modelName(configuration.getModelName())
.maxTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature())
- .timeout(ofSeconds(configuration.getDurationInSeconds()))
+ .timeout(ofSeconds(durationInSec))
.logRequests(true)
.logResponses(true)
.build();
@@ -65,12 +70,13 @@ public static MistralAiChatModel createMistralAiChatModel(ConfigExtractor configExtractor,
public static OllamaChatModel createOllamaChatModel(ConfigExtractor configExtractor, LangchainLLMConfiguration configuration) {
String ollamaBaseUrl = configExtractor.extractValue("OLLAMA_BASE_URL");
+ long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return OllamaChatModel.builder()
//.baseUrl(configuration.getLlmApiKey())
.baseUrl(ollamaBaseUrl)
.modelName(configuration.getModelName())
.temperature(configuration.getTemperature())
- .timeout(ofSeconds(configuration.getDurationInSeconds()))
+ .timeout(ofSeconds(durationInSec))
.build();
}
@@ -78,13 +84,14 @@ public static OllamaChatModel createOllamaChatModel(ConfigExtractor configExtractor, LangchainLLMConfiguration configuration) {
public static AnthropicChatModel createAnthropicChatModel(ConfigExtractor configExtractor,
LangchainLLMConfiguration configuration) {
String anthropicApiKey = configExtractor.extractValue("ANTHROPIC_API_KEY");
+ long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return AnthropicChatModel.builder()
//.apiKey(configuration.getLlmApiKey())
.apiKey(anthropicApiKey)
.modelName(configuration.getModelName())
.maxTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature())
- .timeout(ofSeconds(configuration.getDurationInSeconds()))
+ .timeout(ofSeconds(durationInSec))
.logRequests(true)
.logResponses(true)
.build();
@@ -96,13 +103,14 @@ public static AzureOpenAiChatModel createAzureOpenAiChatModel(ConfigExtractor configExtractor,
String azureOpenaiKey = configExtractor.extractValue("AZURE_OPENAI_KEY");
String azureOpenaiEndpoint = configExtractor.extractValue("AZURE_OPENAI_ENDPOINT");
String azureOpenaiDeploymentName = configExtractor.extractValue("AZURE_OPENAI_DEPLOYMENT_NAME");
+ long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return AzureOpenAiChatModel.builder()
.apiKey(azureOpenaiKey)
.endpoint(azureOpenaiEndpoint)
.deploymentName(azureOpenaiDeploymentName)
.maxTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature())
- .timeout(ofSeconds(configuration.getDurationInSeconds()))
+ .timeout(ofSeconds(durationInSec))
.logRequestsAndResponses(true)
.build();
}
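One property of the conversion repeated in each initializer above is worth noting: `TimeUnit.toSeconds` truncates toward zero, so a sub-second timeout collapses to zero seconds. A small self-contained sketch with illustrative values:

    import java.util.concurrent.TimeUnit;

    public class TruncationSketch {

      public static void main(String[] args) {
        System.out.println(TimeUnit.MINUTES.toSeconds(2));        // 120
        System.out.println(TimeUnit.MILLISECONDS.toSeconds(500)); // 0 -> a zero timeout
      }
    }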
diff --git a/src/main/java/org/mule/extension/mulechain/internal/error/MuleChainErrorType.java b/src/main/java/org/mule/extension/mulechain/internal/error/MuleChainErrorType.java
index 716f206..9be753f 100644
--- a/src/main/java/org/mule/extension/mulechain/internal/error/MuleChainErrorType.java
+++ b/src/main/java/org/mule/extension/mulechain/internal/error/MuleChainErrorType.java
@@ -7,5 +7,5 @@
public enum MuleChainErrorType implements ErrorTypeDefinition<MuleChainErrorType> {
- AI_SERVICES_FAILURE, IMAGE_ANALYSIS_FAILURE, IMAGE_GENERATION_FAILURE, IMAGE_PROCESSING_FAILURE, FILE_HANDLING_FAILURE, RAG_FAILURE, EMBEDDING_OPERATIONS_FAILURE, TOOLS_OPERATION_FAILURE, VALIDATION_FAILURE
+ AI_SERVICES_FAILURE, IMAGE_ANALYSIS_FAILURE, IMAGE_GENERATION_FAILURE, IMAGE_PROCESSING_FAILURE, FILE_HANDLING_FAILURE, RAG_FAILURE, EMBEDDING_OPERATIONS_FAILURE, TOOLS_OPERATION_FAILURE, VALIDATION_FAILURE, STREAMING_FAILURE
}
diff --git a/src/main/java/org/mule/extension/mulechain/internal/exception/StreamingException.java b/src/main/java/org/mule/extension/mulechain/internal/exception/StreamingException.java
new file mode 100644
index 0000000..533b5fa
--- /dev/null
+++ b/src/main/java/org/mule/extension/mulechain/internal/exception/StreamingException.java
@@ -0,0 +1,14 @@
+/**
+ * (c) 2003-2024 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file.
+ */
+package org.mule.extension.mulechain.internal.exception;
+
+import org.mule.extension.mulechain.internal.error.MuleChainErrorType;
+import org.mule.runtime.extension.api.exception.ModuleException;
+
+public class StreamingException extends ModuleException {
+
+ public StreamingException(String message, Throwable throwable) {
+ super(message, MuleChainErrorType.STREAMING_FAILURE, throwable);
+ }
+}
diff --git a/src/main/java/org/mule/extension/mulechain/internal/llm/LangchainLLMModelNameProvider.java b/src/main/java/org/mule/extension/mulechain/internal/llm/LangchainLLMModelNameProvider.java
index 262675b..1a57a07 100644
--- a/src/main/java/org/mule/extension/mulechain/internal/llm/LangchainLLMModelNameProvider.java
+++ b/src/main/java/org/mule/extension/mulechain/internal/llm/LangchainLLMModelNameProvider.java
@@ -21,7 +21,7 @@ public class LangchainLLMModelNameProvider implements ValueProvider {
@Override
public Set<Value> resolve() throws ValueResolvingException {
- return ValueBuilder.getValuesFor(LangchainLLMType.valueOf(llmType).getModelNameStream());
+ return ValueBuilder.getValuesFor(LangchainLLMType.fromValue(llmType).getModelNameStream());
}
}
diff --git a/src/main/java/org/mule/extension/mulechain/internal/llm/config/ConfigType.java b/src/main/java/org/mule/extension/mulechain/internal/llm/config/ConfigType.java
index b454be8..0689466 100644
--- a/src/main/java/org/mule/extension/mulechain/internal/llm/config/ConfigType.java
+++ b/src/main/java/org/mule/extension/mulechain/internal/llm/config/ConfigType.java
@@ -3,26 +3,38 @@
*/
package org.mule.extension.mulechain.internal.llm.config;
+import org.mule.extension.mulechain.internal.config.LangchainLLMConfiguration;
+import org.mule.extension.mulechain.internal.exception.config.ConfigValidationException;
+
import java.util.Arrays;
+import java.util.function.Function;
public enum ConfigType {
- ENV_VARIABLE("Environment Variables"), CONFIG_JSON("Configuration Json");
+ ENV_VARIABLE("Environment Variables", (configuration) -> new EnvConfigExtractor()), CONFIG_JSON("Configuration Json",
+ FileConfigExtractor::new);
private final String value;
- ConfigType(String value) {
+ private final Function<LangchainLLMConfiguration, ConfigExtractor> configExtractorFunction;
+
+ ConfigType(String value, Function<LangchainLLMConfiguration, ConfigExtractor> configExtractorFunction) {
this.value = value;
+ this.configExtractorFunction = configExtractorFunction;
}
public static ConfigType fromValue(String value) {
return Arrays.stream(ConfigType.values())
.filter(configType -> configType.value.equals(value))
.findFirst()
- .orElseThrow(() -> new IllegalArgumentException("Unsupported Config Type: " + value));
+ .orElseThrow(() -> new ConfigValidationException("Unsupported Config Type: " + value));
}
public String getValue() {
return value;
}
+
+ public Function<LangchainLLMConfiguration, ConfigExtractor> getConfigExtractorFunction() {
+ return configExtractorFunction;
+ }
}
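Note that `fromValue` matches on the display value each constant carries, not on the enum name. A minimal sketch of the distinction:

    ConfigType byValue = ConfigType.fromValue("Environment Variables"); // resolves ENV_VARIABLE
    // ConfigType.fromValue("ENV_VARIABLE") would instead throw
    // ConfigValidationException("Unsupported Config Type: ENV_VARIABLE"),
    // since matching is done on the display value rather than on name().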
diff --git a/src/main/java/org/mule/extension/mulechain/internal/llm/type/LangchainLLMType.java b/src/main/java/org/mule/extension/mulechain/internal/llm/type/LangchainLLMType.java
index b34d9ea..37ebb7d 100644
--- a/src/main/java/org/mule/extension/mulechain/internal/llm/type/LangchainLLMType.java
+++ b/src/main/java/org/mule/extension/mulechain/internal/llm/type/LangchainLLMType.java
@@ -4,28 +4,51 @@
package org.mule.extension.mulechain.internal.llm.type;
import dev.langchain4j.model.anthropic.AnthropicChatModelName;
+import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.mistralai.MistralAiChatModelName;
import dev.langchain4j.model.openai.OpenAiChatModelName;
import dev.langchain4j.model.openai.OpenAiImageModelName;
+import org.mule.extension.mulechain.internal.config.LangchainLLMConfiguration;
+import org.mule.extension.mulechain.internal.config.util.LangchainLLMInitializerUtil;
+import org.mule.extension.mulechain.internal.exception.config.ConfigValidationException;
+import org.mule.extension.mulechain.internal.llm.config.ConfigExtractor;
import java.util.Arrays;
+import java.util.function.BiFunction;
import java.util.stream.Stream;
public enum LangchainLLMType {
- OPENAI(getOpenAIModelNameStream()), GROQAI_OPENAI(OPENAI.getModelNameStream()), MISTRAL_AI(
- getMistralAIModelNameStream()), OLLAMA(
- getOllamaModelNameStream()), ANTHROPIC(getAnthropicModelNameStream()), AZURE_OPENAI(OPENAI.getModelNameStream());
+ OPENAI("OPENAI", getOpenAIModelNameStream(), LangchainLLMInitializerUtil::createOpenAiChatModel), GROQAI_OPENAI("GROQAI_OPENAI",
+ OPENAI.getModelNameStream(), LangchainLLMInitializerUtil::createGroqOpenAiChatModel), MISTRAL_AI("MISTRAL_AI",
+ getMistralAIModelNameStream(), LangchainLLMInitializerUtil::createMistralAiChatModel), OLLAMA("OLLAMA",
+ getOllamaModelNameStream(), LangchainLLMInitializerUtil::createOllamaChatModel), ANTHROPIC("ANTHROPIC",
+ getAnthropicModelNameStream(), LangchainLLMInitializerUtil::createAnthropicChatModel), AZURE_OPENAI(
+ "AZURE_OPENAI", OPENAI.getModelNameStream(), LangchainLLMInitializerUtil::createAzureOpenAiChatModel);
+ private final String value;
private final Stream<String> modelNameStream;
- LangchainLLMType(Stream<String> modelNameStream) {
+ private final BiFunction<ConfigExtractor, LangchainLLMConfiguration, ChatLanguageModel> configBiFunction;
+
+ LangchainLLMType(String value, Stream<String> modelNameStream,
+ BiFunction<ConfigExtractor, LangchainLLMConfiguration, ChatLanguageModel> configBiFunction) {
+ this.value = value;
this.modelNameStream = modelNameStream;
+ this.configBiFunction = configBiFunction;
+ }
+
+ public String getValue() {
+ return value;
}
public Stream<String> getModelNameStream() {
return modelNameStream;
}
+ public BiFunction<ConfigExtractor, LangchainLLMConfiguration, ChatLanguageModel> getConfigBiFunction() {
+ return configBiFunction;
+ }
+
private static Stream<String> getOpenAIModelNameStream() {
return Stream.concat(Arrays.stream(OpenAiChatModelName.values()), Arrays.stream(OpenAiImageModelName.values()))
.map(String::valueOf);
@@ -43,6 +66,13 @@ private static Stream<String> getAnthropicModelNameStream() {
return Arrays.stream(AnthropicChatModelName.values()).map(String::valueOf);
}
+ public static LangchainLLMType fromValue(String value) {
+ return Arrays.stream(LangchainLLMType.values())
+ .filter(langchainLLMType -> langchainLLMType.value.equals(value))
+ .findFirst()
+ .orElseThrow(() -> new ConfigValidationException("Unsupported LLM Type: " + value));
+ }
+
enum OllamaModelName {
MISTRAL("mistral"), PHI3("phi3"), ORCA_MINI("orca-mini"), LLAMA2("llama2"), CODE_LLAMA("codellama"), TINY_LLAMA("tinyllama");
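For this enum the explicit `value` strings mirror the constant names, so `fromValue` behaves like `valueOf` except that an unsupported value surfaces as a `ConfigValidationException`. A minimal dispatch sketch, assuming `configExtractor` and `configuration` instances are in scope:

    LangchainLLMType type = LangchainLLMType.fromValue("MISTRAL_AI");
    ChatLanguageModel model = type.getConfigBiFunction()
        .apply(configExtractor, configuration); // delegates to createMistralAiChatModel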
diff --git a/src/main/java/org/mule/extension/mulechain/internal/operation/LangchainStreamingOperations.java b/src/main/java/org/mule/extension/mulechain/internal/operation/LangchainStreamingOperations.java
new file mode 100644
index 0000000..720be5a
--- /dev/null
+++ b/src/main/java/org/mule/extension/mulechain/internal/operation/LangchainStreamingOperations.java
@@ -0,0 +1,87 @@
+/**
+ * (c) 2003-2024 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file.
+ */
+package org.mule.extension.mulechain.internal.operation;
+
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
+import dev.langchain4j.service.AiServices;
+import dev.langchain4j.service.TokenStream;
+import org.mule.extension.mulechain.internal.config.LangchainLLMConfiguration;
+import org.mule.extension.mulechain.internal.error.provider.AiServiceErrorTypeProvider;
+import org.mule.extension.mulechain.internal.exception.ChatException;
+import org.mule.extension.mulechain.internal.exception.StreamingException;
+import org.mule.runtime.extension.api.annotation.Alias;
+import org.mule.runtime.extension.api.annotation.error.Throws;
+import org.mule.runtime.extension.api.annotation.param.Config;
+import org.mule.runtime.extension.api.annotation.param.MediaType;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.PipedInputStream;
+import java.io.PipedOutputStream;
+import java.nio.charset.StandardCharsets;
+
+import static java.time.Duration.ofSeconds;
+import static org.mule.runtime.extension.api.annotation.param.MediaType.ANY;
+
+/**
+ * This class is a container for operations; every public method in this class will be taken as an extension operation.
+ */
+public class LangchainStreamingOperations {
+
+ interface Assistant {
+
+ TokenStream chat(String userMessage);
+ }
+
+ /**
+ * Implements a simple chat agent that streams the LLM response as it is generated.
+ */
+ @MediaType(value = ANY, strict = false)
+ @Alias("CHAT-answer-prompt-w-stream")
+ @Throws(AiServiceErrorTypeProvider.class)
+ public InputStream answerPromptByModelNameStream(@Config LangchainLLMConfiguration configuration, String prompt) {
+ String openaiApiKey = configuration.getConfigExtractor().extractValue("OPENAI_API_KEY");
+ long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
+ try {
+ StreamingChatLanguageModel model = OpenAiStreamingChatModel.builder()
+ .apiKey(openaiApiKey)
+ .modelName(configuration.getModelName())
+ .maxTokens(configuration.getMaxTokens())
+ .temperature(configuration.getTemperature())
+ .timeout(ofSeconds(durationInSec))
+ .build();
+ Assistant assistant = AiServices.create(Assistant.class, model);
+ TokenStream tokenStream = assistant.chat(prompt);
+
+ PipedOutputStream pipedOutputStream = new PipedOutputStream();
+ PipedInputStream pipedInputStream = new PipedInputStream(pipedOutputStream);
+
+ tokenStream.onNext(value -> {
+ try {
+ pipedOutputStream.write(value.getBytes(StandardCharsets.UTF_8));
+ } catch (IOException e) {
+ throw new StreamingException("Error occurred while streaming output", e);
+ }
+ })
+ .onComplete(response -> {
+ try {
+ pipedOutputStream.close();
+ } catch (IOException e) {
+ throw new StreamingException("Error occurred while closing the stream", e);
+ }
+ })
+ .onError(throwable -> {
+ throw new StreamingException("Exception occurred onError()", throwable);
+ })
+ .start();
+ return pipedInputStream;
+ } catch (Exception e) {
+ throw new ChatException("Unable to respond with the chat provided", e);
+ }
+ }
+
+
+
+}
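The operation feeds the token callbacks into a `PipedOutputStream` and returns the matching `PipedInputStream`, so the caller can read tokens as they arrive; `PipedInputStream` expects producer and consumer on different threads, which holds here because the callbacks run on langchain4j's response thread. A minimal consumption sketch (only the returned stream comes from the operation above; the reader side is illustrative):

    import java.io.BufferedReader;
    import java.io.IOException;
    import java.io.InputStream;
    import java.io.InputStreamReader;
    import java.nio.charset.StandardCharsets;

    // 'stream' is the InputStream returned by answerPromptByModelNameStream(...).
    static void drain(InputStream stream) throws IOException {
      try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) {
        int ch;
        while ((ch = reader.read()) != -1) { // returns -1 once onComplete() closes the pipe
          System.out.print((char) ch);       // tokens appear as the model streams them
        }
      }
    }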
diff --git a/src/test/munit/langchain-llm-operation-testing-suite.xml b/src/test/munit/langchain-llm-operation-testing-suite.xml
index f0181e2..4aced6b 100644
--- a/src/test/munit/langchain-llm-operation-testing-suite.xml
+++ b/src/test/munit/langchain-llm-operation-testing-suite.xml
@@ -12,6 +12,7 @@
@@ -23,7 +24,7 @@
config-ref="OPENAI" prompt="#[payload.data]">
-
+