
Commit

W-16295334: Added changes required for certification (#24)
* W-16295334: Added changes for duration parameterisation

* W-16295334: Fixed failing munit test case

* W-16295334: Updated munit

* W-16295334: Removed model and llm default values

* W-16295334: Removed static initialisations from the code

* W-16295334: Bug fixes

* W-16295334: Added one streaming operation for chat
arpitg-1 authored Aug 13, 2024
1 parent 9f53ff7 commit 8ab0026
Showing 10 changed files with 197 additions and 63 deletions.
File: Mule application configuration XML (file name not preserved in this view)
@@ -10,7 +10,7 @@ http://www.mulesoft.org/schema/mule/ee/core http://www.mulesoft.org/schema/mule/
<http:listener-connection host="0.0.0.0" port="8081" />
</http:listener-config>
<mulechain:llm-configuration-config name="MISTRAL_AI" llmType="MISTRAL_AI" configType="Configuration Json" doc:name="MuleChain AI Llm configuration" doc:id="bf1ef7ec-4aa1-41c8-a184-a13ca165c925" filePath='#[mule.home ++ "/apps/" ++ app.name ++ "/envVars.json"]' modelName="mistral-large-latest" temperature="0.1"/>
<mulechain:llm-configuration-config name="OPENAI" llmType="OPENAI" configType="Configuration Json" doc:name="MuleChain AI Llm configuration" doc:id="edb0d5a6-97c5-4d93-8098-4e197e563827" filePath='#[mule.home ++ "/apps/" ++ app.name ++ "/envVars.json"]' temperature="0.1"/>
<mulechain:llm-configuration-config name="OPENAI" llmType="OPENAI" configType="Configuration Json" doc:name="MuleChain AI Llm configuration" doc:id="edb0d5a6-97c5-4d93-8098-4e197e563827" filePath='#[mule.home ++ "/apps/" ++ app.name ++ "/envVars.json"]' temperature="0.1" modelName="gpt-3.5-turbo"/>
<mulechain:llm-configuration-config name="OPENAI-GPT4-TURBO" llmType="OPENAI" configType="Configuration Json" doc:name="MuleChain AI Llm configuration" doc:id="74b85066-1569-4f10-a06b-e49e854eeef2" filePath='#[mule.home ++ "/apps/" ++ app.name ++ "/envVars.json"]' modelName="gpt-4o" />
<flow name="PromptTemplate" doc:id="cff3a8ed-3799-424a-becf-9d7387729bd0" >
<http:listener doc:name="Listener" doc:id="dd18126e-81f5-48ef-8f35-9dd19afdfaf0" config-ref="HTTP_Listener_config" path="/agent"/>
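Note: the modelName attribute is now set explicitly on the OPENAI config because this commit removes the connector-side "gpt-3.5-turbo" default from LangchainLLMConfiguration (next file), making the model name a required input.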
File: org/mule/extension/mulechain/internal/config/LangchainLLMConfiguration.java
@@ -4,19 +4,15 @@
package org.mule.extension.mulechain.internal.config;

import dev.langchain4j.model.chat.ChatLanguageModel;
- import org.mule.extension.mulechain.internal.exception.config.ConfigValidationException;
import org.mule.extension.mulechain.internal.operation.LangchainEmbeddingStoresOperations;
import org.mule.extension.mulechain.internal.operation.LangchainImageModelsOperations;
import org.mule.extension.mulechain.internal.llm.type.LangchainLLMType;
import org.mule.extension.mulechain.internal.llm.ConfigTypeProvider;
- import org.mule.extension.mulechain.internal.config.util.LangchainLLMInitializerUtil;
import org.mule.extension.mulechain.internal.operation.LangchainLLMOperations;
import org.mule.extension.mulechain.internal.llm.LangchainLLMModelNameProvider;
import org.mule.extension.mulechain.internal.llm.LangchainLLMTypeProvider;
import org.mule.extension.mulechain.internal.llm.config.ConfigExtractor;
import org.mule.extension.mulechain.internal.llm.config.ConfigType;
- import org.mule.extension.mulechain.internal.llm.config.EnvConfigExtractor;
- import org.mule.extension.mulechain.internal.llm.config.FileConfigExtractor;
import org.mule.runtime.api.lifecycle.Initialisable;
import org.mule.runtime.api.lifecycle.InitialisationException;
import org.mule.runtime.api.meta.ExpressionSupport;
@@ -27,12 +23,10 @@
import org.mule.runtime.extension.api.annotation.param.Parameter;
import org.mule.runtime.extension.api.annotation.param.display.DisplayName;
import org.mule.runtime.extension.api.annotation.param.display.Placement;
+ import org.mule.runtime.extension.api.annotation.param.display.Summary;
import org.mule.runtime.extension.api.annotation.values.OfValues;

- import java.util.HashMap;
- import java.util.Map;
- import java.util.function.BiFunction;
- import java.util.function.Function;
+ import java.util.concurrent.TimeUnit;

/**
* This class represents an extension configuration, values set in this class are commonly used across multiple
@@ -42,25 +36,9 @@
@Operations({LangchainLLMOperations.class, LangchainEmbeddingStoresOperations.class, LangchainImageModelsOperations.class})
public class LangchainLLMConfiguration implements Initialisable {

- private static final Map<LangchainLLMType, BiFunction<ConfigExtractor, LangchainLLMConfiguration, ChatLanguageModel>> llmMap;
- private static final Map<ConfigType, Function<LangchainLLMConfiguration, ConfigExtractor>> configExtractorMap;
-
- static {
-   configExtractorMap = new HashMap<>();
-   configExtractorMap.put(ConfigType.ENV_VARIABLE, (configuration) -> new EnvConfigExtractor());
-   configExtractorMap.put(ConfigType.CONFIG_JSON, FileConfigExtractor::new);
-
-   llmMap = new HashMap<>();
-   llmMap.put(LangchainLLMType.OPENAI, (LangchainLLMInitializerUtil::createOpenAiChatModel));
-   llmMap.put(LangchainLLMType.GROQAI_OPENAI, (LangchainLLMInitializerUtil::createGroqOpenAiChatModel));
-   llmMap.put(LangchainLLMType.MISTRAL_AI, (LangchainLLMInitializerUtil::createMistralAiChatModel));
-   llmMap.put(LangchainLLMType.OLLAMA, (LangchainLLMInitializerUtil::createOllamaChatModel));
-   llmMap.put(LangchainLLMType.ANTHROPIC, (LangchainLLMInitializerUtil::createAnthropicChatModel));
-   llmMap.put(LangchainLLMType.AZURE_OPENAI, (LangchainLLMInitializerUtil::createAzureOpenAiChatModel));
- }

@Parameter
@Placement(order = 1, tab = Placement.DEFAULT_TAB)
@DisplayName("LLM type")
@OfValues(LangchainLLMTypeProvider.class)
private String llmType;

@@ -71,28 +49,35 @@ public class LangchainLLMConfiguration implements Initialisable {

@Parameter
@Placement(order = 3, tab = Placement.DEFAULT_TAB)
@Optional(defaultValue = "#[-]")
private String filePath;

@Parameter
@Expression(ExpressionSupport.SUPPORTED)
@OfValues(LangchainLLMModelNameProvider.class)
@Optional(defaultValue = "gpt-3.5-turbo")
@Placement(order = 4)
private String modelName = "gpt-3.5-turbo";
@Placement(order = 4, tab = Placement.DEFAULT_TAB)
private String modelName;

@Parameter
- @Placement(order = 5)
+ @Placement(order = 5, tab = Placement.DEFAULT_TAB)
@Optional(defaultValue = "0.7")
private double temperature = 0.7;

@Parameter
- @Placement(order = 6)
+ @Placement(order = 6, tab = Placement.DEFAULT_TAB)
@Optional(defaultValue = "60")
@DisplayName("Duration in sec")
private long durationInSeconds = 60;
@DisplayName("LLM timeout")
private int llmTimeout = 60;

+ @Parameter
+ @Optional(defaultValue = "SECONDS")
+ @Placement(order = 7, tab = Placement.DEFAULT_TAB)
+ @DisplayName("LLM timeout unit")
+ @Summary("Time unit to be used in the LLM Timeout")
+ private TimeUnit llmTimeoutUnit = TimeUnit.SECONDS;

@Parameter
- @Placement(order = 7)
+ @Placement(order = 8, tab = Placement.DEFAULT_TAB)
@Expression(ExpressionSupport.SUPPORTED)
@Optional(defaultValue = "500")
private int maxTokens = 500;
@@ -121,8 +106,12 @@ public double getTemperature() {
return temperature;
}

- public long getDurationInSeconds() {
-   return durationInSeconds;
+ public int getLlmTimeout() {
+   return llmTimeout;
}

+ public TimeUnit getLlmTimeoutUnit() {
+   return llmTimeoutUnit;
+ }

public int getMaxTokens() {
@@ -138,21 +127,14 @@ public ChatLanguageModel getModel() {
}

private ChatLanguageModel createModel(ConfigExtractor configExtractor) {
-   LangchainLLMType type = LangchainLLMType.valueOf(llmType);
-   if (llmMap.containsKey(type)) {
-     return llmMap.get(type).apply(configExtractor, this);
-   }
-   throw new ConfigValidationException("LLM Type not supported: " + llmType);
+   LangchainLLMType type = LangchainLLMType.fromValue(llmType);
+   return type.getConfigBiFunction().apply(configExtractor, this);
}

@Override
public void initialise() throws InitialisationException {
ConfigType config = ConfigType.fromValue(configType);
-   if (configExtractorMap.containsKey(config)) {
-     configExtractor = configExtractorMap.get(config).apply(this);
-     model = createModel(configExtractor);
-   } else {
-     throw new ConfigValidationException("Config Type not supported: " + configType);
-   }
+   configExtractor = config.getConfigExtractorFunction().apply(this);
+   model = createModel(configExtractor);
}
}
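For context, a minimal standalone sketch (not part of the commit; class and variable names here are illustrative) of how the new llmTimeout / llmTimeoutUnit pair collapses into the single seconds value the model builders consume:

  import java.util.concurrent.TimeUnit;

  public class TimeoutResolutionSketch {

    public static void main(String[] args) {
      // Values a user would set on the config: "LLM timeout" = 2, unit = MINUTES
      int llmTimeout = 2;
      TimeUnit llmTimeoutUnit = TimeUnit.MINUTES;

      // The same conversion the factory methods in the next file perform
      // before calling .timeout(ofSeconds(durationInSec)) on the builders.
      long durationInSec = llmTimeoutUnit.toSeconds(llmTimeout);
      System.out.println(durationInSec); // prints 120
    }
  }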
File: org/mule/extension/mulechain/internal/config/util/LangchainLLMInitializerUtil.java
@@ -11,6 +11,8 @@
import org.mule.extension.mulechain.internal.config.LangchainLLMConfiguration;
import org.mule.extension.mulechain.internal.llm.config.ConfigExtractor;

+ import java.time.Duration;

import static java.time.Duration.ofSeconds;

public final class LangchainLLMInitializerUtil {
@@ -19,12 +21,13 @@ private LangchainLLMInitializerUtil() {}

public static OpenAiChatModel createOpenAiChatModel(ConfigExtractor configExtractor, LangchainLLMConfiguration configuration) {
String openaiApiKey = configExtractor.extractValue("OPENAI_API_KEY");
+   long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return OpenAiChatModel.builder()
.apiKey(openaiApiKey)
.modelName(configuration.getModelName())
.maxTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature())
-       .timeout(ofSeconds(configuration.getDurationInSeconds()))
+       .timeout(ofSeconds(durationInSec))
.logRequests(true)
.logResponses(true)
.build();
@@ -34,13 +37,14 @@ public static OpenAiChatModel createOpenAiChatModel(ConfigExtractor configExtrac
public static OpenAiChatModel createGroqOpenAiChatModel(ConfigExtractor configExtractor,
LangchainLLMConfiguration configuration) {
String groqApiKey = configExtractor.extractValue("GROQ_API_KEY");
+   long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return OpenAiChatModel.builder()
.baseUrl("https://api.groq.com/openai/v1")
.apiKey(groqApiKey)
.modelName(configuration.getModelName())
.maxTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature())
-       .timeout(ofSeconds(configuration.getDurationInSeconds()))
+       .timeout(ofSeconds(durationInSec))
.logRequests(true)
.logResponses(true)
.build();
@@ -51,40 +55,43 @@ public static OpenAiChatModel createGroqOpenAiChatModel(ConfigExtractor configEx
public static MistralAiChatModel createMistralAiChatModel(ConfigExtractor configExtractor,
LangchainLLMConfiguration configuration) {
String mistralAiApiKey = configExtractor.extractValue("MISTRAL_AI_API_KEY");
+   long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return MistralAiChatModel.builder()
//.apiKey(configuration.getLlmApiKey())
.apiKey(mistralAiApiKey)
.modelName(configuration.getModelName())
.maxTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature())
-       .timeout(ofSeconds(configuration.getDurationInSeconds()))
+       .timeout(ofSeconds(durationInSec))
.logRequests(true)
.logResponses(true)
.build();
}

public static OllamaChatModel createOllamaChatModel(ConfigExtractor configExtractor, LangchainLLMConfiguration configuration) {
String ollamaBaseUrl = configExtractor.extractValue("OLLAMA_BASE_URL");
+   long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return OllamaChatModel.builder()
//.baseUrl(configuration.getLlmApiKey())
.baseUrl(ollamaBaseUrl)
.modelName(configuration.getModelName())
.temperature(configuration.getTemperature())
-       .timeout(ofSeconds(configuration.getDurationInSeconds()))
+       .timeout(ofSeconds(durationInSec))
.build();
}


public static AnthropicChatModel createAnthropicChatModel(ConfigExtractor configExtractor,
LangchainLLMConfiguration configuration) {
String anthropicApiKey = configExtractor.extractValue("ANTHROPIC_API_KEY");
+   long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return AnthropicChatModel.builder()
//.apiKey(configuration.getLlmApiKey())
.apiKey(anthropicApiKey)
.modelName(configuration.getModelName())
.maxTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature())
-       .timeout(ofSeconds(configuration.getDurationInSeconds()))
+       .timeout(ofSeconds(durationInSec))
.logRequests(true)
.logResponses(true)
.build();
@@ -96,13 +103,14 @@ public static AzureOpenAiChatModel createAzureOpenAiChatModel(ConfigExtractor co
String azureOpenaiKey = configExtractor.extractValue("AZURE_OPENAI_KEY");
String azureOpenaiEndpoint = configExtractor.extractValue("AZURE_OPENAI_ENDPOINT");
String azureOpenaiDeploymentName = configExtractor.extractValue("AZURE_OPENAI_DEPLOYMENT_NAME");
+   long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return AzureOpenAiChatModel.builder()
.apiKey(azureOpenaiKey)
.endpoint(azureOpenaiEndpoint)
.deploymentName(azureOpenaiDeploymentName)
.maxTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature())
-       .timeout(ofSeconds(configuration.getDurationInSeconds()))
+       .timeout(ofSeconds(durationInSec))
.logRequestsAndResponses(true)
.build();
}
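Each factory above repeats the same unit-to-seconds conversion line. One possible consolidation (an assumption, not something this commit does) would be a small helper so callers pass a java.time.Duration straight to .timeout(...):

  import java.time.Duration;
  import java.util.concurrent.TimeUnit;

  // Hypothetical helper (not in this commit): centralizes the repeated
  // getLlmTimeoutUnit().toSeconds(getLlmTimeout()) conversion.
  public final class LlmTimeoutSupport {

    private LlmTimeoutSupport() {}

    public static Duration toDuration(int amount, TimeUnit unit) {
      // The LangChain4j builders' .timeout(...) accepts a Duration directly.
      return Duration.ofSeconds(unit.toSeconds(amount));
    }
  }

Each builder call would then read .timeout(LlmTimeoutSupport.toDuration(configuration.getLlmTimeout(), configuration.getLlmTimeoutUnit())).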
File: org/mule/extension/mulechain/internal/error/MuleChainErrorType.java
@@ -7,5 +7,5 @@

public enum MuleChainErrorType implements ErrorTypeDefinition<MuleChainErrorType> {

- AI_SERVICES_FAILURE, IMAGE_ANALYSIS_FAILURE, IMAGE_GENERATION_FAILURE, IMAGE_PROCESSING_FAILURE, FILE_HANDLING_FAILURE, RAG_FAILURE, EMBEDDING_OPERATIONS_FAILURE, TOOLS_OPERATION_FAILURE, VALIDATION_FAILURE
+ AI_SERVICES_FAILURE, IMAGE_ANALYSIS_FAILURE, IMAGE_GENERATION_FAILURE, IMAGE_PROCESSING_FAILURE, FILE_HANDLING_FAILURE, RAG_FAILURE, EMBEDDING_OPERATIONS_FAILURE, TOOLS_OPERATION_FAILURE, VALIDATION_FAILURE, STREAMING_FAILURE
}
File: org/mule/extension/mulechain/internal/exception/StreamingException.java (new file)
@@ -0,0 +1,14 @@
+ /**
+  * (c) 2003-2024 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file.
+  */
+ package org.mule.extension.mulechain.internal.exception;
+
+ import org.mule.extension.mulechain.internal.error.MuleChainErrorType;
+ import org.mule.runtime.extension.api.exception.ModuleException;
+
+ public class StreamingException extends ModuleException {
+
+   public StreamingException(String message, Throwable throwable) {
+     super(message, MuleChainErrorType.STREAMING_FAILURE, throwable);
+   }
+ }
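The streaming chat operation named in the commit message sits in one of the files not loaded on this page, so how StreamingException actually gets thrown is not visible here. A hedged sketch of typical usage, with the streaming call itself hypothetical:

  import org.mule.extension.mulechain.internal.exception.StreamingException;

  // Hypothetical caller (the real streaming operation is not in this view);
  // the constructor signature matches the class above.
  public class StreamingChatSketch {

    public void streamChatAnswer(String prompt) {
      try {
        openStreamAndRelay(prompt); // hypothetical streaming pipeline call
      } catch (Exception e) {
        // Surfaces in Mule with the new STREAMING_FAILURE error type.
        throw new StreamingException("Error while streaming the chat response", e);
      }
    }

    private void openStreamAndRelay(String prompt) {
      // Placeholder: the actual LangChain4j streaming integration lives
      // in a file not loaded on this page.
      throw new UnsupportedOperationException("sketch only");
    }
  }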
File: org/mule/extension/mulechain/internal/llm/LangchainLLMModelNameProvider.java
@@ -21,7 +21,7 @@ public class LangchainLLMModelNameProvider implements ValueProvider {

@Override
public Set<Value> resolve() throws ValueResolvingException {
-   return ValueBuilder.getValuesFor(LangchainLLMType.valueOf(llmType).getModelNameStream());
+   return ValueBuilder.getValuesFor(LangchainLLMType.fromValue(llmType).getModelNameStream());
}

}
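LangchainLLMType.fromValue itself is not in the loaded portion of this diff. Judging from the identical change in ConfigType below, it presumably replaces Enum.valueOf with a display-value lookup that throws the module's ConfigValidationException; a sketch under that assumption (it also assumes a getValue() accessor on the enum):

  // Assumed shape of LangchainLLMType.fromValue — not shown in this diff;
  // modeled on ConfigType.fromValue in the next file.
  public static LangchainLLMType fromValue(String value) {
    return Arrays.stream(LangchainLLMType.values())
        .filter(type -> type.getValue().equals(value))
        .findFirst()
        .orElseThrow(() -> new ConfigValidationException("Unsupported LLM Type: " + value));
  }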
File: org/mule/extension/mulechain/internal/llm/config/ConfigType.java
@@ -3,26 +3,38 @@
*/
package org.mule.extension.mulechain.internal.llm.config;

+ import org.mule.extension.mulechain.internal.config.LangchainLLMConfiguration;
+ import org.mule.extension.mulechain.internal.exception.config.ConfigValidationException;

import java.util.Arrays;
+ import java.util.function.Function;

public enum ConfigType {

ENV_VARIABLE("Environment Variables"), CONFIG_JSON("Configuration Json");
ENV_VARIABLE("Environment Variables", (configuration) -> new EnvConfigExtractor()), CONFIG_JSON("Configuration Json",
FileConfigExtractor::new);

private final String value;

- ConfigType(String value) {
+ private final Function<LangchainLLMConfiguration, ConfigExtractor> configExtractorFunction;
+
+ ConfigType(String value, Function<LangchainLLMConfiguration, ConfigExtractor> configExtractorFunction) {
this.value = value;
+   this.configExtractorFunction = configExtractorFunction;
}

public static ConfigType fromValue(String value) {
return Arrays.stream(ConfigType.values())
.filter(configType -> configType.value.equals(value))
.findFirst()
-       .orElseThrow(() -> new IllegalArgumentException("Unsupported Config Type: " + value));
+       .orElseThrow(() -> new ConfigValidationException("Unsupported Config Type: " + value));
}

public String getValue() {
return value;
}

+ public Function<LangchainLLMConfiguration, ConfigExtractor> getConfigExtractorFunction() {
+   return configExtractorFunction;
+ }
}
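Taken together with the LangchainLLMConfiguration changes above, the refactor moves lookup logic from static HashMaps into the enum constants themselves. A minimal sketch of the new resolution path (the wrapper method here is illustrative, not part of the commit):

  // Illustrative wrapper around the new lookup path: the enum constant carries
  // its own extractor factory, and fromValue() already rejects unknown values
  // with ConfigValidationException, so no containsKey()/else-throw is needed.
  static ConfigExtractor resolveExtractor(String configType, LangchainLLMConfiguration config) {
    return ConfigType.fromValue(configType)
        .getConfigExtractorFunction()
        .apply(config);
  }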
(3 more changed files not shown in this view)
