From effd76edfa912ce5d889beae3b6b9609eed25bb0 Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Mon, 21 Oct 2024 17:19:40 -0700 Subject: [PATCH 01/15] Initial readme and sematic caching implementation --- .../Part_8-semantic_caching/README.md | 104 ++++++++++ .../artifacts/semantic_caching.py | 177 ++++++++++++++++++ 2 files changed, 281 insertions(+) create mode 100644 Conceptual_Guide/Part_8-semantic_caching/README.md create mode 100644 Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py diff --git a/Conceptual_Guide/Part_8-semantic_caching/README.md b/Conceptual_Guide/Part_8-semantic_caching/README.md new file mode 100644 index 00000000..d24c3c79 --- /dev/null +++ b/Conceptual_Guide/Part_8-semantic_caching/README.md @@ -0,0 +1,104 @@ + + +# Semantic caching + +When deploying large language models (LLMs) or LLM-based workflows +there are two key factors to consider: the performance and cost-efficiency +of your application. Generating language model outputs requires significant +computational resources, for example GPU time, memory usage, and other +infrastructure costs. These resource-intensive requirements create a +pressing need for optimization strategies that can maintain +high-quality outputs while minimizing operational expenses. + +Semantic caching emerges as a powerful solution to reduce computational costs +for LLM-based applications. Unlike traditional caching, it considers +the content and context of incoming requests. + +## Definition and Main Benefits + +**_Semantic caching_** is a caching mechanism that takes into account +the semantics of the incoming request, rather than just the raw data itself. +It goes beyond simple key-value pairs and considers the content or +context of the data. + +This approach offers several benefits including, but not limited to: + ++ Cost Optimization + +Semantic caching can substantially reduce operational expenses associated +with LLM deployments. By storing and reusing responses for semantically +similar queries, it minimizes the number of actual LLM calls required. + ++ Reduced Latency + +One of the primary benefits of semantic caching is its ability to significantly +improve response times. By retrieving cached responses for similar queries, +the system can bypass the need for full model inference, +resulting in the reduced latency. + ++ Increased Throughput + +Semantic caching allows for more efficient utilization of computational +resources. By serving cached responses for similar queries, it reduces the load +on infrastructure components. This efficiency enables the system to handle +a higher volume of requests with the same hardware, effectively increasing +throughput. + ++ Scalability + +The improved resource efficiency and reduced computational demands allows +applications to serve more users without a proportional increase in +infrastructure costs. + ++ Consistency in Responses + +For certain applications, maintaining consistency in responses to +similar queries can be beneficial. Semantic caching ensures that analogous +questions receive uniform answers, which can be particularly useful +in scenarios like customer service or educational applications. + +## Sample Reference Implementation + +In this tutorial we provide a reference implementation for Semantic Cache in +[semantic_caching.py](tutorials/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py) + +## Further optimisations + +## Interested in This Feature? + +While this reference implementation provides a glimpse into the potential +of semantic caching, it's important to note that it's not an officially +supported feature in Triton Inference Server. + +We value your input! If you're interested in seeing semantic caching as a +supported feature in future releases, we encourage you to [FILL IN] + +Provide details about why you think semantic caching would be valuable for +your use case. Your feedback helps shape our product roadmap, +and we appreciate your contributions to making our software better for everyone. \ No newline at end of file diff --git a/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py b/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py new file mode 100644 index 00000000..e5779f8e --- /dev/null +++ b/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py @@ -0,0 +1,177 @@ +import itertools +from dataclasses import dataclass +from typing import Any, Dict, Hashable, Optional + +import faiss +import numpy as np +from sentence_transformers import SentenceTransformer +from theine import Cache + + +class KeyMapper: + """ + A class to manage bidirectional mapping between hashable keys and integer IDs. + """ + + def __init__(self): + self.hk_map: Dict[Hashable, int] = {} + self.kh_map: Dict[int, Hashable] = {} + self.counter = itertools.count() + + def add_key(self, key: Hashable): + """ + Add a new key to the mapper and return its assigned ID. + + Args: + key (Hashable): The key to be added. + + Returns: + int: The assigned ID for the key. + """ + if key in self.hk_map.keys(): + return None + id = next(self.counter) + self.hk_map[key] = id + self.kh_map[id] = key + return id + + def remove_key(self, key: Hashable): + """ + Remove key from the mapper and return its ID. + + Args: + key (Hashable): The key to be removed. + + Returns: + int: The ID for the removed key. + """ + id = self.hk_map.pop(key, None) + if id is not None: + self.kh_map.pop(id, None) + return id + return None + + def get_key(self, id: int): + """ + Retrieve the key associated with the given ID. + + Args: + id (int): The ID to look up. + + Returns: + Optional[Hashable]: The associated key, or None if not found. + """ + return self.kh_map.get(id) + + def get_id(self, key: Hashable): + """ + Retrieve the ID associated with the given key. + + Args: + key (Hashable): The key to look up. + + Returns: + Optional[int]: The associated ID, or None if not found. + """ + return self.hk_map.get(key) + + +@dataclass +class SemanticCPUCacheConfig: + """ + Configuration class for SemanticCPUCache. + + Attributes: + cache (Any): The cache object to use. + encoder (Any): The encoder object for embedding queries. + index (Any): The index object for similarity search. + threshold (float): The similarity threshold for considering a match. + key_mapper (Any): The key mapper object for managing key-ID mappings. + """ + + cache: Any = Cache(policy="lru", size=1000) + encoder: Any = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") + index: Any = faiss.IndexIDMap(faiss.IndexFlatL2(384)) + threshold: float = 0.25 + key_mapper: Any = KeyMapper() + + +class SemanticCPUCache: + """ + Semantic cache implementation. + """ + + def __init__(self, config: SemanticCPUCacheConfig): + """ + Initialize the SemanticCPUCache with the given configuration. + + Args: + config (SemanticCPUCacheConfig): The configuration object. + """ + self.encoder = config.encoder + self.index = config.index + self.cache = config.cache + self.key_map = config.key_mapper + self.threshold = config.threshold + + def get(self, key: Hashable, default: Any = None) -> Any: + """ + Retrieve a value from the cache based on the given key. + + First, a similarity search is performed. If a similar key is found + within the threshold, its associated value is returned. + Otherwise, the default value is returned. + + Args: + key (Hashable): The key to look up. + default (Any, optional): The default value to return if no match is found. Defaults to None. + + Returns: + Any: The retrieved value or the default value. + """ + if self.index.ntotal < 1: + return default + + key_search = np.asarray([self.encoder.encode(key)]) + dist, ind = self.index.search(key_search, 1) + + if dist[0][0] > self.threshold: + return default + + key_str = self.key_map.get_key(ind[0][0]) + + return self.cache.get(key=key_str, default=default) + + def set(self, key: Hashable, value: Any) -> Optional[str]: + """ + Set a key-value pair in the cache. + + This method adds the key to the key mapper, encodes the key, + adds the encoded key to the index, and sets the value in the cache. + + Args: + key (Hashable): The key to set. + value (Any): The value to associate with the key. + + Returns: + Optional[str]: The result of setting the value in the cache. + + Raises: + AssertionError: If the key could not be added to the key mapper. + """ + id = self.key_map.add_key(key) + assert id is not None, "Adding key to the key map failed, returned id is None." + self.index.add_with_ids( + np.expand_dims(self.encoder.encode(key), axis=0), np.asarray([id]) + ) + + evicted_key = self.cache.set(key, value) + self._handle_evicted_key(evicted_key=evicted_key) + + return None + + def _handle_evicted_key(self, evicted_key: Optional[Hashable]) -> None: + if evicted_key: + evicted_id = self.key_map.remove_key(evicted_key) + self.index.remove_ids(np.array([evicted_id])) + return None From df692db1bb3a79890b8c94157926f8460694eca4 Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Mon, 21 Oct 2024 17:52:03 -0700 Subject: [PATCH 02/15] Formatting --- .../Part_8-semantic_caching/README.md | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/Conceptual_Guide/Part_8-semantic_caching/README.md b/Conceptual_Guide/Part_8-semantic_caching/README.md index d24c3c79..03a99d4d 100644 --- a/Conceptual_Guide/Part_8-semantic_caching/README.md +++ b/Conceptual_Guide/Part_8-semantic_caching/README.md @@ -49,39 +49,39 @@ context of the data. This approach offers several benefits including, but not limited to: -+ Cost Optimization ++ **Cost Optimization** -Semantic caching can substantially reduce operational expenses associated -with LLM deployments. By storing and reusing responses for semantically -similar queries, it minimizes the number of actual LLM calls required. + - Semantic caching can substantially reduce operational expenses associated + with LLM deployments. By storing and reusing responses for semantically + similar queries, it minimizes the number of actual LLM calls required. -+ Reduced Latency ++ **Reduced Latency** -One of the primary benefits of semantic caching is its ability to significantly -improve response times. By retrieving cached responses for similar queries, -the system can bypass the need for full model inference, -resulting in the reduced latency. + - One of the primary benefits of semantic caching is its ability to + significantly improve response times. By retrieving cached responses for + similar queries, the system can bypass the need for full model inference, + resulting in the reduced latency. -+ Increased Throughput ++ **Increased Throughput** -Semantic caching allows for more efficient utilization of computational -resources. By serving cached responses for similar queries, it reduces the load -on infrastructure components. This efficiency enables the system to handle -a higher volume of requests with the same hardware, effectively increasing -throughput. + - Semantic caching allows for more efficient utilization of computational + resources. By serving cached responses for similar queries, it reduces the + load on infrastructure components. This efficiency enables the system + to handle a higher volume of requests with the same hardware, effectively + increasing throughput. -+ Scalability ++ **Scalability** -The improved resource efficiency and reduced computational demands allows -applications to serve more users without a proportional increase in -infrastructure costs. + - The improved resource efficiency and reduced computational demands allows + applications to serve more users without a proportional increase in + infrastructure costs. -+ Consistency in Responses ++ **Consistency in Responses** -For certain applications, maintaining consistency in responses to -similar queries can be beneficial. Semantic caching ensures that analogous -questions receive uniform answers, which can be particularly useful -in scenarios like customer service or educational applications. + - For certain applications, maintaining consistency in responses to + similar queries can be beneficial. Semantic caching ensures that analogous + questions receive uniform answers, which can be particularly useful + in scenarios like customer service or educational applications. ## Sample Reference Implementation From 8e7f0f3c757839634b67abdab4bc19b1fe918308 Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Mon, 21 Oct 2024 17:52:55 -0700 Subject: [PATCH 03/15] Fixed links --- Conceptual_Guide/Part_8-semantic_caching/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Conceptual_Guide/Part_8-semantic_caching/README.md b/Conceptual_Guide/Part_8-semantic_caching/README.md index 03a99d4d..2d41618b 100644 --- a/Conceptual_Guide/Part_8-semantic_caching/README.md +++ b/Conceptual_Guide/Part_8-semantic_caching/README.md @@ -86,7 +86,7 @@ This approach offers several benefits including, but not limited to: ## Sample Reference Implementation In this tutorial we provide a reference implementation for Semantic Cache in -[semantic_caching.py](tutorials/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py) +[semantic_caching.py](./artifacts/semantic_caching.py) ## Further optimisations From 22c412e6d0d1b7dfbb823f3062019bf080fb396f Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Mon, 21 Oct 2024 17:56:36 -0700 Subject: [PATCH 04/15] Added missing annotation --- .../artifacts/semantic_caching.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py b/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py index e5779f8e..dbe550d8 100644 --- a/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py +++ b/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py @@ -171,6 +171,17 @@ def set(self, key: Hashable, value: Any) -> Optional[str]: return None def _handle_evicted_key(self, evicted_key: Optional[Hashable]) -> None: + """ + Handle the eviction of a key from the cache. + + This method is called when a key is evicted from the cache. It removes + the evicted key from the key_map and its corresponding + vector embedding from the index. + + Args: + evicted_key (Optional[Hashable]): The key that was evicted from the + cache. + """ if evicted_key: evicted_id = self.key_map.remove_key(evicted_key) self.index.remove_ids(np.array([evicted_id])) From 0fb90d685b916b5256ac542cefe2dc78530c5b68 Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Tue, 22 Oct 2024 16:23:51 -0700 Subject: [PATCH 05/15] Follow ups --- .../Part_8-semantic_caching/README.md | 228 +++++++++++++++++- .../artifacts/semantic_caching.py | 22 +- Conceptual_Guide/README.md | 1 + 3 files changed, 247 insertions(+), 4 deletions(-) diff --git a/Conceptual_Guide/Part_8-semantic_caching/README.md b/Conceptual_Guide/Part_8-semantic_caching/README.md index 2d41618b..9f13bb80 100644 --- a/Conceptual_Guide/Part_8-semantic_caching/README.md +++ b/Conceptual_Guide/Part_8-semantic_caching/README.md @@ -86,9 +86,233 @@ This approach offers several benefits including, but not limited to: ## Sample Reference Implementation In this tutorial we provide a reference implementation for Semantic Cache in -[semantic_caching.py](./artifacts/semantic_caching.py) +[semantic_caching.py.](./artifacts/semantic_caching.py) There are 3 key +dependencies: +* [SentenceTransformer](https://sbert.net/): a Python framework for computing +dense vector representations (embeddings) of sentences, paragraphs, and images. + - We use this library and `all-MiniLM-L6-v2` in particular to convert + incoming prompt into an embedding, enabling semantic comparison. + - Alternatives include [semantic search models](https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#semantic-search-models), + OpenAI Embeddings, etc. +* [Faiss](https://github.com/facebookresearch/faiss/wiki): an open-source library +developed by Facebook AI Research for efficient similarity search and +clustering of dense vectors. + - This library is used for the embedding store and extracting the most + similar embedded prompt from the cached requests (or from the index store). + - This is a mighty library with a great variety of CPu and GPU accelerated + algorithms. + - Alternatives include [annoy](https://github.com/spotify/annoy), or + [cuVS](https://github.com/rapidsai/cuvs). However, note that cuVS already + has an integration in Faiss, more on this can be found [here.](https://docs.rapids.ai/api/cuvs/nightly/integrations/faiss/) +* [Theine](https://github.com/Yiling-J/theine): High performance in-memory +cache. + - We will use it as our exact match cache backend. After the most similar + prompt is identified, the corresponding cached response id extracted from + the cache. This library supports multiple eviction policies, in this + tutorial we use "LRU". + - One may also look into [MemCached](https://memcached.org/about) as a + potential alternative. -## Further optimisations +Provided [script](./artifacts/semantic_caching.py) is heavily annotated and we +encourage users to look through the code to gain better clarity in all +the necessary stages. + +## Incorporating Semantic Cache into your workflow + +For this tutorial, we'll use the [vllm backend](https://github.com/triton-inference-server/vllm_backend) +as our example, focusing on demonstrating how to cache responses for the +non-streaming case. The principles covered here can be extended to handle +streaming scenarios as well. + +### Cutomising vllm backend + +First, let's start by cloning Triton's vllm backend repository. This will +provide the necessary codebase to implement our semantic caching example. + +``bash +git clone https://github.com/triton-inference-server/vllm_backend.git +``` + +With the repository cloned, the next step is to add the +[semantic_caching.py.](./artifacts/semantic_caching.py) script to +the appropriate directory. This script contains the logic for our semantic +caching implementation. + +```bash +wget -P vllm_backend/src/utils/ https://raw.githubusercontent.com/triton-inference-server/tutorials/refs/heads/main/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py +``` + +Now that we have added the semantic caching script, let's proceed by making +some adjustments in `/vllm_backend/src/model.py`. These changes will integrate +the semantic caching functionality into the model. + +First, ensure that you import the necessary classes from `semantic_caching.py`: + +```diff +... + +from utils.metrics import VllmStatLogger ++from utils.semantic_caching import SemanticCPUCacheConfig, SemanticCPUCache +``` + +Next, initialize the semantic cache during the initialization step. +This setup will prepare your model to utilize semantic caching during +its operations. + +```diff + def initialize(self, args): + self.args = args + self.logger = pb_utils.Logger + self.model_config = json.loads(args["model_config"]) + ... + + # Starting asyncio event loop to process the received requests asynchronously. + self._loop = asyncio.get_event_loop() + self._event_thread = threading.Thread( + target=self.engine_loop, args=(self._loop,) + ) + self._shutdown_event = asyncio.Event() + self._event_thread.start() ++ self.semantic_cache = SemanticCPUCache(SemanticCPUCacheConfig) + +``` + +Finally, we'll add logic to query and update the semantic cache during +request processing. This ensures that cached responses are efficiently utilized +whenever possible. + +```diff + async def generate(self, request): + ... + try: + request_id = random_uuid() + prompt = pb_utils.get_input_tensor_by_name( + request, "text_input" + ).as_numpy()[0] + ... + + if prepend_input and stream: + raise ValueError( + "When streaming, `exclude_input_in_output` = False is not allowed." + ) ++ cache_hit = self.semantic_cache.get(prompt) ++ if cache_hit: ++ try: ++ response_sender.send( ++ self.create_response(cache_hit, prepend_input), ++ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, ++ ) ++ if decrement_ongoing_request_count: ++ self.ongoing_request_count -= 1 ++ except Exception as err: ++ print(f"Unexpected {err=} for prompt {prompt}") ++ return None + ... + + async for output in response_iterator: + ... + + last_output = output + + if not stream: + response_sender.send( + self.create_response(last_output, prepend_input), + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, + ) ++ self.semantic_cache.set(prompt, last_output) + +``` + +### Launching Triton with Optimized vLLM Backend + +To evaluate or optimized vllm backend, let's start vllm docker container and +mount our implementation to `/opt/tritonserver/backends/vllm`. We'll +also mount sample model repository, provided in +`vllm_backend/samples/model_repository`. Feel free to set up your own. +Use the following docker command to start Triton's vllm docker container, +but make sure to specify proper paths to the cloned `vllm_backend` +repository and replace `` with the latest release of Triton. + +```bash +docker run --gpus all -it --net=host --rm -p 8001:8001 --shm-size=1G \ +--ulimit memlock=-1 --ulimit stack=67108864 \ +-v /path/to/vllm_backend/src/:/opt/tritonserver/backends/vllm \ +-v /path/to/vllm_backend/samples/model_repository:/work/model_repository \ +-w /work nvcr.io/nvidia/tritonserver:-vllm-python-py3 +``` + +When inside the container, make sure to install required dependencies: +```bash +pip install sentence_transformers faiss_gpu theine +``` + +Finally, let's launch Triton +```bash +tritonserver --model-repository=model_repository/ +``` + +After you start Triton you will see output on the console showing +the server starting up and loading the model. When you see output +like the following, Triton is ready to accept inference requests. + +``` +I1030 22:33:28.291908 1 grpc_server.cc:2513] Started GRPCInferenceService at 0.0.0.0:8001 +I1030 22:33:28.292879 1 http_server.cc:4497] Started HTTPService at 0.0.0.0:8000 +I1030 22:33:28.335154 1 http_server.cc:270] Started Metrics Service at 0.0.0.0:8002 +``` + +### Evaluation + +After you [start Triton](#launching-triton-with-optimized-vllm-backend) +with the sample model_repository, you can quickly run your first inference +request with the +[generate endpoint](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_generate.md). + +We'll also time this query: + +```bash +time curl -X POST localhost:8000/v2/models/vllm_model/generate -d '{"text_input": "Tell me, how do I create model repository for Triton Server?", "parameters": {"stream": false, "temperature": 0, "max_tokens":100}, "exclude_input_in_output":true}' +``` + +Upon success, you should see a response from the server like this one: +``` +{"model_name":"vllm_model","model_version":"1","text_output": } +real 0m1.128s +user 0m0.000s +sys 0m0.015s +``` + +Now, let's try a different response, but keep the semantics: + +```bash +time curl -X POST localhost:8000/v2/models/vllm_model/generate -d '{"text_input": "How do I set up model repository for Triton Inference Server?", "parameters": {"stream": false, "temperature": 0, "max_tokens":100}, "exclude_input_in_output":true} +``` + +Upon success, you should see a response from the server like this one: +``` +{"model_name":"vllm_model","model_version":"1","text_output": } +real 0m0.038s +user 0m0.000s +sys 0m0.017s +``` + +Let's try one more: + +```bash +time curl -X POST localhost:8000/v2/models/vllm_model/generate -d '{"text_input": "How model repository should be set up for Triton Server?", "parameters": {"stream": false, "temperature": 0, "max_tokens":100}, "exclude_input_in_output":true}' +``` + +Upon success, you should see a response from the server like this one: +``` +{"model_name":"vllm_model","model_version":"1","text_output": } +real 0m0.059s +user 0m0.016s +sys 0m0.000s +``` + +Clearly, the latter 2 requests are semantically similar to the first one, which +resulted in a cache hit scenario, which reduced the latency of our model from +approx 1.1s to the average of 0.048s per request. ## Interested in This Feature? diff --git a/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py b/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py index dbe550d8..703abe19 100644 --- a/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py +++ b/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py @@ -83,15 +83,23 @@ class SemanticCPUCacheConfig: Attributes: cache (Any): The cache object to use. + Default: Cache(policy="lru", size=1000). encoder (Any): The encoder object for embedding queries. + Default: SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") + encoder_dim (int): The encoder dimension. + Default: 384. The size of `all-MiniLM-L6-v2` embeddings. index (Any): The index object for similarity search. + Default: faiss.IndexIDMap(faiss.IndexFlatL2(encoder_dim)) threshold (float): The similarity threshold for considering a match. + Default: 0.25 key_mapper (Any): The key mapper object for managing key-ID mappings. + default: KeyMapper() """ cache: Any = Cache(policy="lru", size=1000) encoder: Any = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") - index: Any = faiss.IndexIDMap(faiss.IndexFlatL2(384)) + encoder_dim: int = 384 + index: Any = faiss.IndexIDMap(faiss.IndexFlatL2(encoder_dim)) threshold: float = 0.25 key_mapper: Any = KeyMapper() @@ -133,11 +141,19 @@ def get(self, key: Hashable, default: Any = None) -> Any: return default key_search = np.asarray([self.encoder.encode(key)]) + # The vector index returns two values, distance to the most similar + # embedding (1 indicates we only need top 1 similar result), and + # its numerical index. dist, ind = self.index.search(key_search, 1) + # If the distance between vectors above the set threshold, i.e. + # the most similar embedding is too far from the current prompt + # embedding, this considered as cache miss and we return the `default`. if dist[0][0] > self.threshold: return default + # To retrieve the cache hit from the cache store, we need to retrieve + # the corresponding prompt from the key_map store, given its index. key_str = self.key_map.get_key(ind[0][0]) return self.cache.get(key=key_str, default=default) @@ -164,7 +180,9 @@ def set(self, key: Hashable, value: Any) -> Optional[str]: self.index.add_with_ids( np.expand_dims(self.encoder.encode(key), axis=0), np.asarray([id]) ) - + # Adding a new entry into the cache can evict an old entry, according + # to the policy in-use. We need to make sure we evict the same entry + # from the vector index, stored in `self.index`. evicted_key = self.cache.set(key, value) self._handle_evicted_key(evicted_key=evicted_key) diff --git a/Conceptual_Guide/README.md b/Conceptual_Guide/README.md index 115f96e9..d0a44b5c 100644 --- a/Conceptual_Guide/README.md +++ b/Conceptual_Guide/README.md @@ -40,3 +40,4 @@ Conceptual guides have been designed as an onboarding experience to Triton Infer * [Part 5: Building Model Ensembles](./Part_5-Model_Ensembles/): Models are rarely used standalone. This guide will cover "how to build a deep learning inference pipeline?" * [Part 6: Using the BLS API to build complex pipelines](Part_6-building_complex_pipelines/): Often times there are scenarios where the pipeline requires control flows. Learn how to work with complex pipelines with models deployed on different backends. * [Part 7: Iterative Scheduling Tutorial](./Part_7-iterative_scheduling): Shows how to use the Triton Iterative Scheduler with a GPT2 model using HuggingFace Transformers. +* [Part 8: Semantic Caching](./Part_8-semantic_caching/): Shows benefits of adding semantic caching to you LLM-based workflow. From d11a5eab3ba331a9cc581db7aca0c496afbda663 Mon Sep 17 00:00:00 2001 From: Olga Andreeva <124622579+oandreeva-nv@users.noreply.github.com> Date: Wed, 23 Oct 2024 14:41:04 -0700 Subject: [PATCH 06/15] Apply suggestions from code review Co-authored-by: Ryan McCormick --- .../Part_8-semantic_caching/README.md | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/Conceptual_Guide/Part_8-semantic_caching/README.md b/Conceptual_Guide/Part_8-semantic_caching/README.md index 9f13bb80..45bbd3fa 100644 --- a/Conceptual_Guide/Part_8-semantic_caching/README.md +++ b/Conceptual_Guide/Part_8-semantic_caching/README.md @@ -26,7 +26,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. --> -# Semantic caching +# Semantic Caching When deploying large language models (LLMs) or LLM-based workflows there are two key factors to consider: the performance and cost-efficiency @@ -37,10 +37,9 @@ pressing need for optimization strategies that can maintain high-quality outputs while minimizing operational expenses. Semantic caching emerges as a powerful solution to reduce computational costs -for LLM-based applications. Unlike traditional caching, it considers -the content and context of incoming requests. +for LLM-based applications. -## Definition and Main Benefits +## Definition and Benefits **_Semantic caching_** is a caching mechanism that takes into account the semantics of the incoming request, rather than just the raw data itself. @@ -60,7 +59,7 @@ This approach offers several benefits including, but not limited to: - One of the primary benefits of semantic caching is its ability to significantly improve response times. By retrieving cached responses for similar queries, the system can bypass the need for full model inference, - resulting in the reduced latency. + resulting in reduced latency. + **Increased Throughput** @@ -85,7 +84,7 @@ This approach offers several benefits including, but not limited to: ## Sample Reference Implementation -In this tutorial we provide a reference implementation for Semantic Cache in +In this tutorial we provide a reference implementation for a Semantic Cache in [semantic_caching.py.](./artifacts/semantic_caching.py) There are 3 key dependencies: * [SentenceTransformer](https://sbert.net/): a Python framework for computing @@ -99,7 +98,7 @@ developed by Facebook AI Research for efficient similarity search and clustering of dense vectors. - This library is used for the embedding store and extracting the most similar embedded prompt from the cached requests (or from the index store). - - This is a mighty library with a great variety of CPu and GPU accelerated + - This is a mighty library with a great variety of CPU and GPU accelerated algorithms. - Alternatives include [annoy](https://github.com/spotify/annoy), or [cuVS](https://github.com/rapidsai/cuvs). However, note that cuVS already @@ -107,7 +106,7 @@ clustering of dense vectors. * [Theine](https://github.com/Yiling-J/theine): High performance in-memory cache. - We will use it as our exact match cache backend. After the most similar - prompt is identified, the corresponding cached response id extracted from + prompt is identified, the corresponding cached response is extracted from the cache. This library supports multiple eviction policies, in this tutorial we use "LRU". - One may also look into [MemCached](https://memcached.org/about) as a @@ -124,12 +123,12 @@ as our example, focusing on demonstrating how to cache responses for the non-streaming case. The principles covered here can be extended to handle streaming scenarios as well. -### Cutomising vllm backend +### Customising vLLM Backend First, let's start by cloning Triton's vllm backend repository. This will provide the necessary codebase to implement our semantic caching example. -``bash +```bash git clone https://github.com/triton-inference-server/vllm_backend.git ``` @@ -143,7 +142,7 @@ wget -P vllm_backend/src/utils/ https://raw.githubusercontent.com/triton-inferen ``` Now that we have added the semantic caching script, let's proceed by making -some adjustments in `/vllm_backend/src/model.py`. These changes will integrate +some adjustments in `vllm_backend/src/model.py`. These changes will integrate the semantic caching functionality into the model. First, ensure that you import the necessary classes from `semantic_caching.py`: @@ -234,11 +233,12 @@ but make sure to specify proper paths to the cloned `vllm_backend` repository and replace `` with the latest release of Triton. ```bash -docker run --gpus all -it --net=host --rm -p 8001:8001 --shm-size=1G \ ---ulimit memlock=-1 --ulimit stack=67108864 \ --v /path/to/vllm_backend/src/:/opt/tritonserver/backends/vllm \ --v /path/to/vllm_backend/samples/model_repository:/work/model_repository \ --w /work nvcr.io/nvidia/tritonserver:-vllm-python-py3 +docker run --gpus all -it --net=host --rm \ + --shm-size=1G --ulimit memlock=-1 --ulimit stack=67108864 \ + -v /path/to/vllm_backend/src/:/opt/tritonserver/backends/vllm \ + -v /path/to/vllm_backend/samples/model_repository:/workspace/model_repository \ + -w /workspace \ + nvcr.io/nvidia/tritonserver:-vllm-python-py3 ``` When inside the container, make sure to install required dependencies: From 210a4000b896315819b540cee9f95c0aadbfb65a Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Wed, 23 Oct 2024 15:28:53 -0700 Subject: [PATCH 07/15] Adjusted added codebase for clarity --- Conceptual_Guide/Part_8-semantic_caching/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Conceptual_Guide/Part_8-semantic_caching/README.md b/Conceptual_Guide/Part_8-semantic_caching/README.md index 45bbd3fa..c09ed4b2 100644 --- a/Conceptual_Guide/Part_8-semantic_caching/README.md +++ b/Conceptual_Guide/Part_8-semantic_caching/README.md @@ -172,7 +172,8 @@ its operations. ) self._shutdown_event = asyncio.Event() self._event_thread.start() -+ self.semantic_cache = SemanticCPUCache(SemanticCPUCacheConfig) ++ config = SemanticCPUCacheConfig() ++ self.semantic_cache = SemanticCPUCache(config=config) ``` From 841464c2cb93f9135bf8661bed09d4f47a3a00f7 Mon Sep 17 00:00:00 2001 From: Olga Andreeva <124622579+oandreeva-nv@users.noreply.github.com> Date: Wed, 23 Oct 2024 15:32:02 -0700 Subject: [PATCH 08/15] Update Conceptual_Guide/Part_8-semantic_caching/README.md Co-authored-by: Kris Hung --- Conceptual_Guide/Part_8-semantic_caching/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Conceptual_Guide/Part_8-semantic_caching/README.md b/Conceptual_Guide/Part_8-semantic_caching/README.md index c09ed4b2..0cce5824 100644 --- a/Conceptual_Guide/Part_8-semantic_caching/README.md +++ b/Conceptual_Guide/Part_8-semantic_caching/README.md @@ -286,7 +286,7 @@ sys 0m0.015s Now, let's try a different response, but keep the semantics: ```bash -time curl -X POST localhost:8000/v2/models/vllm_model/generate -d '{"text_input": "How do I set up model repository for Triton Inference Server?", "parameters": {"stream": false, "temperature": 0, "max_tokens":100}, "exclude_input_in_output":true} +time curl -X POST localhost:8000/v2/models/vllm_model/generate -d '{"text_input": "How do I set up model repository for Triton Inference Server?", "parameters": {"stream": false, "temperature": 0, "max_tokens":100}, "exclude_input_in_output":true}' ``` Upon success, you should see a response from the server like this one: From cba419671e882bdbd6dfdadc676190e647e77c8d Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Wed, 23 Oct 2024 15:33:32 -0700 Subject: [PATCH 09/15] Copyright --- .../artifacts/semantic_caching.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py b/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py index 703abe19..b4ec5618 100644 --- a/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py +++ b/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py @@ -1,3 +1,29 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + import itertools from dataclasses import dataclass from typing import Any, Dict, Hashable, Optional From dd4de13afa361b1fa6c0222bcdc77d39cc5615c7 Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Thu, 24 Oct 2024 12:08:33 -0700 Subject: [PATCH 10/15] Added patch --- .../artifacts/semantic_cache.patch | 238 ++++++++++++++++++ 1 file changed, 238 insertions(+) create mode 100644 Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_cache.patch diff --git a/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_cache.patch b/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_cache.patch new file mode 100644 index 00000000..5df4ceaf --- /dev/null +++ b/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_cache.patch @@ -0,0 +1,238 @@ +diff --git a/src/model.py b/src/model.py +index 3f6e23b..d4228d2 100644 +--- a/src/model.py ++++ b/src/model.py +@@ -42,6 +42,7 @@ from vllm.sampling_params import SamplingParams + from vllm.utils import random_uuid + + from utils.metrics import VllmStatLogger ++from utils.semantic_caching import SemanticCPUCache, SemanticCPUCacheConfig + + _VLLM_ENGINE_ARGS_FILENAME = "model.json" + _MULTI_LORA_ARGS_FILENAME = "multi_lora.json" +@@ -130,6 +131,8 @@ class TritonPythonModel: + ) + self._shutdown_event = asyncio.Event() + self._event_thread.start() ++ config = SemanticCPUCacheConfig() ++ self.semantic_cache = SemanticCPUCache(config=config) + + def init_engine(self): + # Currently, Triton needs to use decoupled policy for asynchronously +@@ -407,6 +410,18 @@ class TritonPythonModel: + raise ValueError( + "When streaming, `exclude_input_in_output` = False is not allowed." + ) ++ cache_hit = self.semantic_cache.get(prompt) ++ if cache_hit: ++ try: ++ response_sender.send( ++ self.create_response(cache_hit, prepend_input), ++ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, ++ ) ++ if decrement_ongoing_request_count: ++ self.ongoing_request_count -= 1 ++ except Exception as err: ++ print(f"Unexpected {err=} for prompt {prompt}") ++ return None + + # Request parameters are not yet supported via + # BLS. Provide an optional mechanism to receive serialized +@@ -481,6 +496,7 @@ class TritonPythonModel: + self.create_response(last_output, prepend_input), + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, + ) ++ self.semantic_cache.set(prompt, last_output) + + except Exception as e: + self.logger.log_error(f"[vllm] Error generating stream: {e}") +diff --git a/src/utils/semantic_caching.py b/src/utils/semantic_caching.py +new file mode 100644 +index 0000000..457a163 +--- /dev/null ++++ b/src/utils/semantic_caching.py +@@ -0,0 +1,184 @@ ++import itertools ++from dataclasses import dataclass ++from typing import Any, Dict, Hashable, Optional ++ ++import faiss ++import numpy as np ++from sentence_transformers import SentenceTransformer ++from theine import Cache ++ ++ ++class KeyMapper: ++ """ ++ A class to manage bidirectional mapping between hashable keys and integer IDs. ++ """ ++ ++ def __init__(self): ++ self.hk_map: Dict[Hashable, int] = {} ++ self.kh_map: Dict[int, Hashable] = {} ++ self.counter = itertools.count() ++ ++ def add_key(self, key: Hashable): ++ """ ++ Add a new key to the mapper and return its assigned ID. ++ ++ Args: ++ key (Hashable): The key to be added. ++ ++ Returns: ++ int: The assigned ID for the key. ++ """ ++ if key in self.hk_map.keys(): ++ return None ++ id = next(self.counter) ++ self.hk_map[key] = id ++ self.kh_map[id] = key ++ return id ++ ++ def remove_key(self, key: Hashable): ++ """ ++ Remove key from the mapper and return its ID. ++ ++ Args: ++ key (Hashable): The key to be removed. ++ ++ Returns: ++ int: The ID for the removed key. ++ """ ++ id = self.hk_map.pop(key, None) ++ if id is not None: ++ self.kh_map.pop(id, None) ++ return id ++ return None ++ ++ def get_key(self, id: int): ++ """ ++ Retrieve the key associated with the given ID. ++ ++ Args: ++ id (int): The ID to look up. ++ ++ Returns: ++ Optional[Hashable]: The associated key, or None if not found. ++ """ ++ return self.kh_map.get(id) ++ ++ def get_id(self, key: Hashable): ++ """ ++ Retrieve the ID associated with the given key. ++ ++ Args: ++ key (Hashable): The key to look up. ++ ++ Returns: ++ Optional[int]: The associated ID, or None if not found. ++ """ ++ return self.hk_map.get(key) ++ ++ ++@dataclass ++class SemanticCPUCacheConfig: ++ """ ++ Configuration class for SemanticCPUCache. ++ ++ Attributes: ++ cache (Any): The cache object to use. ++ encoder (Any): The encoder object for embedding queries. ++ index (Any): The index object for similarity search. ++ threshold (float): The similarity threshold for considering a match. ++ key_mapper (Any): The key mapper object for managing key-ID mappings. ++ """ ++ ++ cache: Any = Cache(policy="lru", size=1000) ++ encoder: Any = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") ++ index: Any = faiss.IndexIDMap(faiss.IndexFlatL2(384)) ++ threshold: float = 0.25 ++ key_mapper: Any = KeyMapper() ++ ++ ++class SemanticCPUCache: ++ """ ++ Semantic cache implementation. ++ """ ++ ++ def __init__(self, config: SemanticCPUCacheConfig): ++ """ ++ Initialize the SemanticCPUCache with the given configuration. ++ ++ Args: ++ config (SemanticCPUCacheConfig): The configuration object. ++ """ ++ self.encoder = config.encoder ++ self.index = config.index ++ self.cache = config.cache ++ self.key_map = config.key_mapper ++ self.threshold = config.threshold ++ ++ def get(self, key: Hashable, default: Any = None) -> Any: ++ """ ++ Retrieve a value from the cache based on the given key. ++ ++ First, a similarity search is performed. If a similar key is found ++ within the threshold, its associated value is returned. ++ Otherwise, the default value is returned. ++ ++ Args: ++ key (Hashable): The key to look up. ++ default (Any, optional): The default value to return if no match is found. Defaults to None. ++ ++ Returns: ++ Any: The retrieved value or the default value. ++ """ ++ if self.index.ntotal < 1: ++ return default ++ ++ key_search = np.asarray([self.encoder.encode(key)]) ++ dist, ind = self.index.search(key_search, 1) ++ # print(dist[0][0]) ++ ++ if dist[0][0] > self.threshold: ++ return default ++ ++ key_str = self.key_map.get_key(ind[0][0]) ++ ++ return self.cache.get(key=key_str, default=default) ++ ++ def set(self, key: Hashable, value: Any) -> Optional[str]: ++ """ ++ Set a key-value pair in the cache. ++ ++ This method adds the key to the key mapper, encodes the key, ++ adds the encoded key to the index, and sets the value in the cache. ++ ++ ++ Args: ++ key (Hashable): The key to set. ++ value (Any): The value to associate with the key. ++ ++ Returns: ++ Optional[str]: The result of setting the value in the cache. ++ ++ Raises: ++ AssertionError: If the key could not be added to the key mapper. ++ """ ++ id = self.key_map.add_key(key) ++ if id is not None: ++ # TODO: leaking implementation `add_with_ids`. add a layer ++ self.index.add_with_ids( ++ np.expand_dims(self.encoder.encode(key), axis=0), np.asarray([id]) ++ ) ++ ++ evicted_key = self.cache.set(key, value) ++ self._handle_evicted_key(evicted_key=evicted_key) ++ ++ return None ++ ++ def _handle_evicted_key(self, evicted_key: Hashable) -> None: ++ if evicted_key is None: ++ return None ++ # TODO: extremely coupled, remove dependency on key id? ++ evicted_id = self.key_map.remove_key(evicted_key) ++ print(evicted_id) ++ # TODO: leaking implementation `remove_ids`. add a layer ++ self.index.remove_ids(np.array([evicted_id])) ++ return None From 0ec9015d80a9b44391e392d49afa41b236049767 Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Thu, 24 Oct 2024 17:15:07 -0700 Subject: [PATCH 11/15] Added limitations sections + some clarifications --- .../Part_8-semantic_caching/README.md | 54 +++++++++++++------ 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/Conceptual_Guide/Part_8-semantic_caching/README.md b/Conceptual_Guide/Part_8-semantic_caching/README.md index 0cce5824..d300b300 100644 --- a/Conceptual_Guide/Part_8-semantic_caching/README.md +++ b/Conceptual_Guide/Part_8-semantic_caching/README.md @@ -71,9 +71,11 @@ This approach offers several benefits including, but not limited to: + **Scalability** - - The improved resource efficiency and reduced computational demands allows - applications to serve more users without a proportional increase in - infrastructure costs. + - As the user base and the volume of queries grow, the probability of cache + hits increases, provided that there is adequate storage and resources + available to support this scaling. The improved resource efficiency and + reduced computational demands allows applications to serve more users + without a proportional increase in infrastructure costs. + **Consistency in Responses** @@ -130,22 +132,34 @@ provide the necessary codebase to implement our semantic caching example. ```bash git clone https://github.com/triton-inference-server/vllm_backend.git +cd vllm_backend ``` -With the repository cloned, the next step is to add the -[semantic_caching.py.](./artifacts/semantic_caching.py) script to -the appropriate directory. This script contains the logic for our semantic -caching implementation. +With the repository successfully cloned, the next step is to apply all +necessary modifications. To simplify this process, we've prepared a +[semantic_cache.patch](tutorials/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_cache.patch) +that consolidates all changes into a single step: ```bash -wget -P vllm_backend/src/utils/ https://raw.githubusercontent.com/triton-inference-server/tutorials/refs/heads/main/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_caching.py +curl https://raw.githubusercontent.com/triton-inference-server/tutorials/refs/heads/main/Conceptual_Guide/Part_8-semantic_caching/artifacts/semantic_cache.patch | git apply -v ``` -Now that we have added the semantic caching script, let's proceed by making -some adjustments in `vllm_backend/src/model.py`. These changes will integrate -the semantic caching functionality into the model. +If you're eager to start using Triton with the optimized vLLM backend, +you can skip ahead to the +[Launching Triton with Optimized vLLM Backend](#launching-triton-with-optimized-vllm-backend) +section. However, for those interested in understanding the specifics, +let's explore what this patch includes. -First, ensure that you import the necessary classes from `semantic_caching.py`: +The patch introduces a new script, +[semantic_caching.py.](./artifacts/semantic_caching.py), which is added to the +appropriate directory. This script implements the core logic for our +semantic caching functionality. + +Next, the patch integrates semantic caching into the model. Let's walk through +these changes step-by-step. + +Firstly, it imports the necessary classes from +[semantic_caching.py.](./artifacts/semantic_caching.py) into the codebase: ```diff ... @@ -154,7 +168,7 @@ from utils.metrics import VllmStatLogger +from utils.semantic_caching import SemanticCPUCacheConfig, SemanticCPUCache ``` -Next, initialize the semantic cache during the initialization step. +Next, it sets up the semantic cache during the initialization step. This setup will prepare your model to utilize semantic caching during its operations. @@ -177,9 +191,9 @@ its operations. ``` -Finally, we'll add logic to query and update the semantic cache during -request processing. This ensures that cached responses are efficiently utilized -whenever possible. +Finally, the patch incorporates logic to query and update the semantic cache +during request processing. This ensures that cached responses are efficiently +utilized whenever possible. ```diff async def generate(self, request): @@ -315,6 +329,14 @@ Clearly, the latter 2 requests are semantically similar to the first one, which resulted in a cache hit scenario, which reduced the latency of our model from approx 1.1s to the average of 0.048s per request. +## Current limitations + +* The current implementation of the Semantic Cache only considers the prompt +itself for cache hits, without accounting for additional request parameters +such as `max_tokens` and `temperature`. As a result, these parameters are not +included in the cache hit evaluation, which may affect the accuracy of cached +responses when different configurations are used. + ## Interested in This Feature? While this reference implementation provides a glimpse into the potential From fbc046640dbd64a68cf92310d4c8d72c9d4e4157 Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Thu, 24 Oct 2024 17:16:49 -0700 Subject: [PATCH 12/15] heading format --- Conceptual_Guide/Part_8-semantic_caching/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Conceptual_Guide/Part_8-semantic_caching/README.md b/Conceptual_Guide/Part_8-semantic_caching/README.md index d300b300..e785e4c9 100644 --- a/Conceptual_Guide/Part_8-semantic_caching/README.md +++ b/Conceptual_Guide/Part_8-semantic_caching/README.md @@ -329,7 +329,7 @@ Clearly, the latter 2 requests are semantically similar to the first one, which resulted in a cache hit scenario, which reduced the latency of our model from approx 1.1s to the average of 0.048s per request. -## Current limitations +## Current Limitations * The current implementation of the Semantic Cache only considers the prompt itself for cache hits, without accounting for additional request parameters From b72be052672ed533d274c74708bfc2468d9319a0 Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Thu, 24 Oct 2024 18:05:12 -0700 Subject: [PATCH 13/15] Finilised Interested in this feature discussion --- Conceptual_Guide/Part_8-semantic_caching/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Conceptual_Guide/Part_8-semantic_caching/README.md b/Conceptual_Guide/Part_8-semantic_caching/README.md index e785e4c9..277fa7a3 100644 --- a/Conceptual_Guide/Part_8-semantic_caching/README.md +++ b/Conceptual_Guide/Part_8-semantic_caching/README.md @@ -344,8 +344,8 @@ of semantic caching, it's important to note that it's not an officially supported feature in Triton Inference Server. We value your input! If you're interested in seeing semantic caching as a -supported feature in future releases, we encourage you to [FILL IN] - -Provide details about why you think semantic caching would be valuable for -your use case. Your feedback helps shape our product roadmap, +supported feature in future releases, we invite you to join the ongoing +[discussion.](https://github.com/triton-inference-server/server/discussions/7742) +Provide details about why you think semantic caching would +be valuable for your use case. Your feedback helps shape our product roadmap, and we appreciate your contributions to making our software better for everyone. \ No newline at end of file From 70cac18dbcab845c238d62b2c80152886388a8f5 Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Thu, 24 Oct 2024 18:08:38 -0700 Subject: [PATCH 14/15] added a limitation --- Conceptual_Guide/Part_8-semantic_caching/README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/Conceptual_Guide/Part_8-semantic_caching/README.md b/Conceptual_Guide/Part_8-semantic_caching/README.md index 277fa7a3..83849696 100644 --- a/Conceptual_Guide/Part_8-semantic_caching/README.md +++ b/Conceptual_Guide/Part_8-semantic_caching/README.md @@ -337,6 +337,14 @@ such as `max_tokens` and `temperature`. As a result, these parameters are not included in the cache hit evaluation, which may affect the accuracy of cached responses when different configurations are used. +* Semantic Cache effectiveness is heavily reliant on the choice of embedding +model and application context. For instance, queries like "How to set up model +repository for Triton Inference Server?" and "How not to set up model +repository for Triton Inference Server?" may have high cosine similarity +despite differing semantically. This makes it challenging to set an optimal +threshold for cache hits, as a narrow similarity range might exclude useful +cache entries. + ## Interested in This Feature? While this reference implementation provides a glimpse into the potential From dc9ee05ddf28a3d7c8ee402e6e1fc4e78ad29c1d Mon Sep 17 00:00:00 2001 From: Olga Andreeva <124622579+oandreeva-nv@users.noreply.github.com> Date: Fri, 25 Oct 2024 14:12:27 -0700 Subject: [PATCH 15/15] Apply suggestions from code review Co-authored-by: Ryan McCormick --- Conceptual_Guide/Part_8-semantic_caching/README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Conceptual_Guide/Part_8-semantic_caching/README.md b/Conceptual_Guide/Part_8-semantic_caching/README.md index 83849696..cbdf363e 100644 --- a/Conceptual_Guide/Part_8-semantic_caching/README.md +++ b/Conceptual_Guide/Part_8-semantic_caching/README.md @@ -87,7 +87,7 @@ This approach offers several benefits including, but not limited to: ## Sample Reference Implementation In this tutorial we provide a reference implementation for a Semantic Cache in -[semantic_caching.py.](./artifacts/semantic_caching.py) There are 3 key +[semantic_caching.py](./artifacts/semantic_caching.py). There are 3 key dependencies: * [SentenceTransformer](https://sbert.net/): a Python framework for computing dense vector representations (embeddings) of sentences, paragraphs, and images. @@ -104,7 +104,7 @@ clustering of dense vectors. algorithms. - Alternatives include [annoy](https://github.com/spotify/annoy), or [cuVS](https://github.com/rapidsai/cuvs). However, note that cuVS already - has an integration in Faiss, more on this can be found [here.](https://docs.rapids.ai/api/cuvs/nightly/integrations/faiss/) + has an integration in Faiss, more on this can be found [here](https://docs.rapids.ai/api/cuvs/nightly/integrations/faiss/). * [Theine](https://github.com/Yiling-J/theine): High performance in-memory cache. - We will use it as our exact match cache backend. After the most similar @@ -151,7 +151,7 @@ section. However, for those interested in understanding the specifics, let's explore what this patch includes. The patch introduces a new script, -[semantic_caching.py.](./artifacts/semantic_caching.py), which is added to the +[semantic_caching.py](./artifacts/semantic_caching.py), which is added to the appropriate directory. This script implements the core logic for our semantic caching functionality. @@ -159,7 +159,7 @@ Next, the patch integrates semantic caching into the model. Let's walk through these changes step-by-step. Firstly, it imports the necessary classes from -[semantic_caching.py.](./artifacts/semantic_caching.py) into the codebase: +[semantic_caching.py](./artifacts/semantic_caching.py) into the codebase: ```diff ... @@ -353,7 +353,7 @@ supported feature in Triton Inference Server. We value your input! If you're interested in seeing semantic caching as a supported feature in future releases, we invite you to join the ongoing -[discussion.](https://github.com/triton-inference-server/server/discussions/7742) +[discussion](https://github.com/triton-inference-server/server/discussions/7742). Provide details about why you think semantic caching would be valuable for your use case. Your feedback helps shape our product roadmap, and we appreciate your contributions to making our software better for everyone. \ No newline at end of file