Skip to content

Commit

Permalink
Use QueryDefinition to avoid injection
Browse files: browse the repository at this point in the history
  • Loading branch information
Pilchie committed May 19, 2024
1 parent d246bf4 commit 79f748d
Showing 1 changed file with 31 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,12 @@ public async Task<bool> DoesCollectionExistAsync(
string collectionName,
CancellationToken cancellationToken = default)
{
var queryDefinition = new QueryDefinition("SELECT VALUE(c.id) FROM c WHERE c.id = @collectionName");
queryDefinition.WithParameter("@collectionName", collectionName);
using var feedIterator = this.
_cosmosClient
.GetDatabase(this._databaseName)
.GetContainerQueryIterator<string>($"SELECT VALUE(c.id) FROM c WHERE c.id = '{collectionName}'");
.GetContainerQueryIterator<string>(queryDefinition);

while (feedIterator.HasMoreResults)
{
Expand Down Expand Up @@ -212,20 +214,36 @@ public async IAsyncEnumerable<MemoryRecord> GetBatchAsync(
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
const string OR = " OR ";

// Optimistically create the entire query string.
var whereClause = string.Join(OR, keys.Select(k => $"(x.id = \"{k}\" AND x.key = \"{k}\")"));
var queryDefinition = new QueryDefinition($"""
var queryStart = $"""
SELECT x.id,x.key,x.metadata,x.timestamp{(withEmbeddings ? ",x.embedding" : "")}
FROM x
WHERE {whereClause}
""");

// NOTE: Cosmos DB queries are limited to 512kB, so if this is larger than that, break it into segments.
var byteCount = Encoding.UTF8.GetByteCount(whereClause);
var ratio = byteCount / ((float)(512 * 1024));
if (ratio < 1)
WHERE
""";
// NOTE: Cosmos DB queries are limited to 512kB, so we split the keys into batches
// of roughly 500kB each. Stopping short of 512kB leaves headroom, so the clause
// that pushes a batch past 500kB can stay in the query instead of being removed.
int keyIndex = 0;
var keyList = keys.ToList();
while (keyIndex < keyList.Count)
{
var length = queryStart.Length;
var countThisBatch = 0;
var whereClauses = new StringBuilder();
for (int i = keyIndex; i < keyList.Count && length <= 500 * 1024; i++, countThisBatch++)
{
string keyId = $"@key{i:D}";
var clause = $"(x.id = {keyId} AND x.key = {keyId})";
whereClauses.Append(clause).Append(OR);
length += clause.Length + OR.Length + 4 + keyId.Length + Encoding.UTF8.GetByteCount(keyList[keyIndex]);
}
whereClauses.Length -= OR.Length;

var queryDefinition = new QueryDefinition(queryStart + whereClauses);
for (int i = keyIndex; i < keyIndex + countThisBatch; i++)
{
queryDefinition.WithParameter($"@key{i:D}", keyList[i]);
}

var feedIterator = this._cosmosClient
.GetDatabase(this._databaseName)
.GetContainer(collectionName)
Expand All @@ -238,54 +256,8 @@ FROM x
yield return memoryRecord;
}
}
}
else
{
// We're in the very large case, we'll need to split this into multiple queries.
// We add one to catch any fractional piece left in the last segment
var segments = (int)(ratio + 1);
var keyList = keys.ToList();
var keysPerQuery = keyList.Count / segments;
// Make a guess as to how long this query will be. We need at least 26 chars for each "OR" block, so
// put a few extra for the values of the keys.
var estimatedWhereLength = 30 * keysPerQuery;
var localWhere = new StringBuilder(estimatedWhereLength);
for (var i = 0; i < segments; i++)
{
localWhere.Clear();
for (var q = i * keysPerQuery; q < (i + 1) * keysPerQuery && q < keyList.Count; q++)
{
var k = keyList[q];
#if NET6_0_OR_GREATER
localWhere.Append(CultureInfo.InvariantCulture, $"(x.id = \"{k}\" AND x.key = \"{k}\")").Append(OR);
#else
localWhere.Append($"(x.id = \"{k}\" AND x.key = \"{k}\")").Append(OR);
#endif
}

if (localWhere.Length >= OR.Length)
{
localWhere.Length -= OR.Length;

var localQueryDefinition = new QueryDefinition($"""
SELECT x.id,x.key,x.metadata,x.timestamp{(withEmbeddings ? ",x.embedding" : "")}
FROM x
WHERE {localWhere}
""");
var feedIterator = this._cosmosClient
.GetDatabase(this._databaseName)
.GetContainer(collectionName)
.GetItemQueryIterator<MemoryRecord>(localQueryDefinition);

while (feedIterator.HasMoreResults)
{
foreach (var memoryRecord in await feedIterator.ReadNextAsync(cancellationToken).ConfigureAwait(false))
{
yield return memoryRecord;
}
}
}
}
keyIndex += countThisBatch;
}
}

Expand Down

0 comments on commit 79f748d

Please sign in to comment.