Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

W-16295334: Added changes required for certification #24

Merged
merged 9 commits into from
Aug 13, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ http://www.mulesoft.org/schema/mule/ee/core http://www.mulesoft.org/schema/mule/
<http:listener-connection host="0.0.0.0" port="8081" />
</http:listener-config>
<mulechain:llm-configuration-config name="MISTRAL_AI" llmType="MISTRAL_AI" configType="Configuration Json" doc:name="MuleChain AI Llm configuration" doc:id="bf1ef7ec-4aa1-41c8-a184-a13ca165c925" filePath='#[mule.home ++ "/apps/" ++ app.name ++ "/envVars.json"]' modelName="mistral-large-latest" temperature="0.1"/>
<mulechain:llm-configuration-config name="OPENAI" llmType="OPENAI" configType="Configuration Json" doc:name="MuleChain AI Llm configuration" doc:id="edb0d5a6-97c5-4d93-8098-4e197e563827" filePath='#[mule.home ++ "/apps/" ++ app.name ++ "/envVars.json"]' temperature="0.1"/>
<mulechain:llm-configuration-config name="OPENAI" llmType="OPENAI" configType="Configuration Json" doc:name="MuleChain AI Llm configuration" doc:id="edb0d5a6-97c5-4d93-8098-4e197e563827" filePath='#[mule.home ++ "/apps/" ++ app.name ++ "/envVars.json"]' temperature="0.1" modelName="gpt-3.5-turbo"/>
<mulechain:llm-configuration-config name="OPENAI-GPT4-TURBO" llmType="OPENAI" configType="Configuration Json" doc:name="MuleChain AI Llm configuration" doc:id="74b85066-1569-4f10-a06b-e49e854eeef2" filePath='#[mule.home ++ "/apps/" ++ app.name ++ "/envVars.json"]' modelName="gpt-4o" />
<flow name="PromptTemplate" doc:id="cff3a8ed-3799-424a-becf-9d7387729bd0" >
<http:listener doc:name="Listener" doc:id="dd18126e-81f5-48ef-8f35-9dd19afdfaf0" config-ref="HTTP_Listener_config" path="/agent"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,15 @@
package org.mule.extension.mulechain.internal.config;

import dev.langchain4j.model.chat.ChatLanguageModel;
import org.mule.extension.mulechain.internal.exception.config.ConfigValidationException;
import org.mule.extension.mulechain.internal.operation.LangchainEmbeddingStoresOperations;
import org.mule.extension.mulechain.internal.operation.LangchainImageModelsOperations;
import org.mule.extension.mulechain.internal.llm.type.LangchainLLMType;
import org.mule.extension.mulechain.internal.llm.ConfigTypeProvider;
import org.mule.extension.mulechain.internal.config.util.LangchainLLMInitializerUtil;
import org.mule.extension.mulechain.internal.operation.LangchainLLMOperations;
import org.mule.extension.mulechain.internal.llm.LangchainLLMModelNameProvider;
import org.mule.extension.mulechain.internal.llm.LangchainLLMTypeProvider;
import org.mule.extension.mulechain.internal.llm.config.ConfigExtractor;
import org.mule.extension.mulechain.internal.llm.config.ConfigType;
import org.mule.extension.mulechain.internal.llm.config.EnvConfigExtractor;
import org.mule.extension.mulechain.internal.llm.config.FileConfigExtractor;
import org.mule.runtime.api.lifecycle.Initialisable;
import org.mule.runtime.api.lifecycle.InitialisationException;
import org.mule.runtime.api.meta.ExpressionSupport;
Expand All @@ -27,12 +23,10 @@
import org.mule.runtime.extension.api.annotation.param.Parameter;
import org.mule.runtime.extension.api.annotation.param.display.DisplayName;
import org.mule.runtime.extension.api.annotation.param.display.Placement;
import org.mule.runtime.extension.api.annotation.param.display.Summary;
import org.mule.runtime.extension.api.annotation.values.OfValues;

import java.util.HashMap;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.concurrent.TimeUnit;

/**
* This class represents an extension configuration, values set in this class are commonly used across multiple
Expand All @@ -42,25 +36,9 @@
@Operations({LangchainLLMOperations.class, LangchainEmbeddingStoresOperations.class, LangchainImageModelsOperations.class})
public class LangchainLLMConfiguration implements Initialisable {

private static final Map<LangchainLLMType, BiFunction<ConfigExtractor, LangchainLLMConfiguration, ChatLanguageModel>> llmMap;
private static final Map<ConfigType, Function<LangchainLLMConfiguration, ConfigExtractor>> configExtractorMap;

static {
configExtractorMap = new HashMap<>();
configExtractorMap.put(ConfigType.ENV_VARIABLE, (configuration) -> new EnvConfigExtractor());
configExtractorMap.put(ConfigType.CONFIG_JSON, FileConfigExtractor::new);

llmMap = new HashMap<>();
llmMap.put(LangchainLLMType.OPENAI, (LangchainLLMInitializerUtil::createOpenAiChatModel));
llmMap.put(LangchainLLMType.GROQAI_OPENAI, (LangchainLLMInitializerUtil::createGroqOpenAiChatModel));
llmMap.put(LangchainLLMType.MISTRAL_AI, (LangchainLLMInitializerUtil::createMistralAiChatModel));
llmMap.put(LangchainLLMType.OLLAMA, (LangchainLLMInitializerUtil::createOllamaChatModel));
llmMap.put(LangchainLLMType.ANTHROPIC, (LangchainLLMInitializerUtil::createAnthropicChatModel));
llmMap.put(LangchainLLMType.AZURE_OPENAI, (LangchainLLMInitializerUtil::createAzureOpenAiChatModel));
}

@Parameter
@Placement(order = 1, tab = Placement.DEFAULT_TAB)
@DisplayName("LLM type")
@OfValues(LangchainLLMTypeProvider.class)
private String llmType;

Expand All @@ -71,28 +49,35 @@ public class LangchainLLMConfiguration implements Initialisable {

@Parameter
@Placement(order = 3, tab = Placement.DEFAULT_TAB)
@Optional(defaultValue = "#[-]")
private String filePath;

@Parameter
@Expression(ExpressionSupport.SUPPORTED)
@OfValues(LangchainLLMModelNameProvider.class)
@Optional(defaultValue = "gpt-3.5-turbo")
@Placement(order = 4)
private String modelName = "gpt-3.5-turbo";
@Placement(order = 4, tab = Placement.DEFAULT_TAB)
private String modelName;

@Parameter
@Placement(order = 5)
@Placement(order = 5, tab = Placement.DEFAULT_TAB)
@Optional(defaultValue = "0.7")
private double temperature = 0.7;

@Parameter
@Placement(order = 6)
@Placement(order = 6, tab = Placement.DEFAULT_TAB)
@Optional(defaultValue = "60")
@DisplayName("Duration in sec")
private long durationInSeconds = 60;
@DisplayName("LLM timeout")
private int llmTimeout = 60;

@Parameter
@Optional(defaultValue = "SECONDS")
@Placement(order = 7, tab = Placement.DEFAULT_TAB)
@DisplayName("LLM timeout unit")
@Summary("Time unit to be used in the LLM Timeout")
private TimeUnit llmTimeoutUnit = TimeUnit.SECONDS;

@Parameter
@Placement(order = 7)
@Placement(order = 8, tab = Placement.DEFAULT_TAB)
@Expression(ExpressionSupport.SUPPORTED)
@Optional(defaultValue = "500")
private int maxTokens = 500;
Expand Down Expand Up @@ -121,8 +106,12 @@ public double getTemperature() {
return temperature;
}

public long getDurationInSeconds() {
return durationInSeconds;
public int getLlmTimeout() {
return llmTimeout;
}

public TimeUnit getLlmTimeoutUnit() {
return llmTimeoutUnit;
}

public int getMaxTokens() {
Expand All @@ -138,21 +127,14 @@ public ChatLanguageModel getModel() {
}

private ChatLanguageModel createModel(ConfigExtractor configExtractor) {
LangchainLLMType type = LangchainLLMType.valueOf(llmType);
if (llmMap.containsKey(type)) {
return llmMap.get(type).apply(configExtractor, this);
}
throw new ConfigValidationException("LLM Type not supported: " + llmType);
LangchainLLMType type = LangchainLLMType.fromValue(llmType);
return type.getConfigBiFunction().apply(configExtractor, this);
}

@Override
public void initialise() throws InitialisationException {
ConfigType config = ConfigType.fromValue(configType);
if (configExtractorMap.containsKey(config)) {
configExtractor = configExtractorMap.get(config).apply(this);
model = createModel(configExtractor);
} else {
throw new ConfigValidationException("Config Type not supported: " + configType);
}
configExtractor = config.getConfigExtractorFunction().apply(this);
model = createModel(configExtractor);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import org.mule.extension.mulechain.internal.config.LangchainLLMConfiguration;
import org.mule.extension.mulechain.internal.llm.config.ConfigExtractor;

import java.time.Duration;

import static java.time.Duration.ofSeconds;

public final class LangchainLLMInitializerUtil {
Expand All @@ -19,12 +21,13 @@ private LangchainLLMInitializerUtil() {}

public static OpenAiChatModel createOpenAiChatModel(ConfigExtractor configExtractor, LangchainLLMConfiguration configuration) {
String openaiApiKey = configExtractor.extractValue("OPENAI_API_KEY");
long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return OpenAiChatModel.builder()
.apiKey(openaiApiKey)
.modelName(configuration.getModelName())
.maxTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature())
.timeout(ofSeconds(configuration.getDurationInSeconds()))
.timeout(ofSeconds(durationInSec))
.logRequests(true)
.logResponses(true)
.build();
Expand All @@ -34,13 +37,14 @@ public static OpenAiChatModel createOpenAiChatModel(ConfigExtractor configExtrac
public static OpenAiChatModel createGroqOpenAiChatModel(ConfigExtractor configExtractor,
LangchainLLMConfiguration configuration) {
String groqApiKey = configExtractor.extractValue("GROQ_API_KEY");
long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return OpenAiChatModel.builder()
.baseUrl("https://api.groq.com/openai/v1")
.apiKey(groqApiKey)
.modelName(configuration.getModelName())
.maxTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature())
.timeout(ofSeconds(configuration.getDurationInSeconds()))
.timeout(ofSeconds(durationInSec))
.logRequests(true)
.logResponses(true)
.build();
Expand All @@ -51,40 +55,43 @@ public static OpenAiChatModel createGroqOpenAiChatModel(ConfigExtractor configEx
public static MistralAiChatModel createMistralAiChatModel(ConfigExtractor configExtractor,
LangchainLLMConfiguration configuration) {
String mistralAiApiKey = configExtractor.extractValue("MISTRAL_AI_API_KEY");
long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return MistralAiChatModel.builder()
//.apiKey(configuration.getLlmApiKey())
.apiKey(mistralAiApiKey)
.modelName(configuration.getModelName())
.maxTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature())
.timeout(ofSeconds(configuration.getDurationInSeconds()))
.timeout(ofSeconds(durationInSec))
.logRequests(true)
.logResponses(true)
.build();
}

public static OllamaChatModel createOllamaChatModel(ConfigExtractor configExtractor, LangchainLLMConfiguration configuration) {
String ollamaBaseUrl = configExtractor.extractValue("OLLAMA_BASE_URL");
long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return OllamaChatModel.builder()
//.baseUrl(configuration.getLlmApiKey())
.baseUrl(ollamaBaseUrl)
.modelName(configuration.getModelName())
.temperature(configuration.getTemperature())
.timeout(ofSeconds(configuration.getDurationInSeconds()))
.timeout(ofSeconds(durationInSec))
.build();
}


public static AnthropicChatModel createAnthropicChatModel(ConfigExtractor configExtractor,
LangchainLLMConfiguration configuration) {
String anthropicApiKey = configExtractor.extractValue("ANTHROPIC_API_KEY");
long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return AnthropicChatModel.builder()
//.apiKey(configuration.getLlmApiKey())
.apiKey(anthropicApiKey)
.modelName(configuration.getModelName())
.maxTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature())
.timeout(ofSeconds(configuration.getDurationInSeconds()))
.timeout(ofSeconds(durationInSec))
.logRequests(true)
.logResponses(true)
.build();
Expand All @@ -96,13 +103,14 @@ public static AzureOpenAiChatModel createAzureOpenAiChatModel(ConfigExtractor co
String azureOpenaiKey = configExtractor.extractValue("AZURE_OPENAI_KEY");
String azureOpenaiEndpoint = configExtractor.extractValue("AZURE_OPENAI_ENDPOINT");
String azureOpenaiDeploymentName = configExtractor.extractValue("AZURE_OPENAI_DEPLOYMENT_NAME");
long durationInSec = configuration.getLlmTimeoutUnit().toSeconds(configuration.getLlmTimeout());
return AzureOpenAiChatModel.builder()
.apiKey(azureOpenaiKey)
.endpoint(azureOpenaiEndpoint)
.deploymentName(azureOpenaiDeploymentName)
.maxTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature())
.timeout(ofSeconds(configuration.getDurationInSeconds()))
.timeout(ofSeconds(durationInSec))
.logRequestsAndResponses(true)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@

public enum MuleChainErrorType implements ErrorTypeDefinition<MuleChainErrorType> {

AI_SERVICES_FAILURE, IMAGE_ANALYSIS_FAILURE, IMAGE_GENERATION_FAILURE, IMAGE_PROCESSING_FAILURE, FILE_HANDLING_FAILURE, RAG_FAILURE, EMBEDDING_OPERATIONS_FAILURE, TOOLS_OPERATION_FAILURE, VALIDATION_FAILURE
AI_SERVICES_FAILURE, IMAGE_ANALYSIS_FAILURE, IMAGE_GENERATION_FAILURE, IMAGE_PROCESSING_FAILURE, FILE_HANDLING_FAILURE, RAG_FAILURE, EMBEDDING_OPERATIONS_FAILURE, TOOLS_OPERATION_FAILURE, VALIDATION_FAILURE, STREAMING_FAILURE
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/**
* (c) 2003-2024 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file.
*/
package org.mule.extension.mulechain.internal.exception;

import org.mule.extension.mulechain.internal.error.MuleChainErrorType;
import org.mule.runtime.extension.api.exception.ModuleException;

public class StreamingException extends ModuleException {

public StreamingException(String message, Throwable throwable) {
super(message, MuleChainErrorType.STREAMING_FAILURE, throwable);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public class LangchainLLMModelNameProvider implements ValueProvider {

@Override
public Set<Value> resolve() throws ValueResolvingException {
return ValueBuilder.getValuesFor(LangchainLLMType.valueOf(llmType).getModelNameStream());
return ValueBuilder.getValuesFor(LangchainLLMType.fromValue(llmType).getModelNameStream());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,38 @@
*/
package org.mule.extension.mulechain.internal.llm.config;

import org.mule.extension.mulechain.internal.config.LangchainLLMConfiguration;
import org.mule.extension.mulechain.internal.exception.config.ConfigValidationException;

import java.util.Arrays;
import java.util.function.Function;

public enum ConfigType {

ENV_VARIABLE("Environment Variables"), CONFIG_JSON("Configuration Json");
ENV_VARIABLE("Environment Variables", (configuration) -> new EnvConfigExtractor()), CONFIG_JSON("Configuration Json",
FileConfigExtractor::new);

private final String value;

ConfigType(String value) {
private final Function<LangchainLLMConfiguration, ConfigExtractor> configExtractorFunction;

ConfigType(String value, Function<LangchainLLMConfiguration, ConfigExtractor> configExtractorFunction) {
this.value = value;
this.configExtractorFunction = configExtractorFunction;
}

public static ConfigType fromValue(String value) {
return Arrays.stream(ConfigType.values())
.filter(configType -> configType.value.equals(value))
.findFirst()
.orElseThrow(() -> new IllegalArgumentException("Unsupported Config Type: " + value));
.orElseThrow(() -> new ConfigValidationException("Unsupported Config Type: " + value));
}

public String getValue() {
return value;
}

public Function<LangchainLLMConfiguration, ConfigExtractor> getConfigExtractorFunction() {
return configExtractorFunction;
}
}
Loading