Skip to content

Commit

Permalink
[shared] Implement SQL sanitization for MSSQL (open-telemetry#2330)
Browse files Browse the repository at this point in the history
  • Loading branch information
alanwest authored Nov 18, 2024
1 parent 15190d0 commit a7d5aeb
Show file tree
Hide file tree
Showing 6 changed files with 436 additions and 0 deletions.
1 change: 1 addition & 0 deletions opentelemetry-dotnet-contrib.sln
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Shared", "Shared", "{1FCC8E
src\Shared\ServerCertificateValidationProvider.cs = src\Shared\ServerCertificateValidationProvider.cs
src\Shared\SpanAttributeConstants.cs = src\Shared\SpanAttributeConstants.cs
src\Shared\SpanHelper.cs = src\Shared\SpanHelper.cs
src\Shared\SqlProcessor.cs = src\Shared\SqlProcessor.cs
src\Shared\UriHelper.cs = src\Shared\UriHelper.cs
EndProjectSection
EndProject
Expand Down
234 changes: 234 additions & 0 deletions src/Shared/SqlProcessor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

using System.Text;

namespace OpenTelemetry.Instrumentation;

/// <summary>
/// Produces a sanitized copy of a SQL statement: comments are removed and
/// string, hexadecimal and numeric literals are each replaced with <c>?</c>.
/// </summary>
/// <remarks>
/// Every private scanner follows the same cursor convention: on a match it
/// leaves <c>index</c> on the last character it consumed (or at
/// <c>sql.Length</c> when the construct runs to the end of the input), so the
/// <c>++i</c> of the main loop in <see cref="GetSanitizedSql"/> steps past it.
/// When nothing matches, <c>index</c> is left untouched.
/// </remarks>
public static class SqlProcessor
{
    /// <summary>
    /// Returns <paramref name="sql"/> with comments stripped and literals replaced by <c>?</c>.
    /// </summary>
    /// <param name="sql">The SQL text to sanitize.</param>
    /// <returns>
    /// The sanitized SQL, or <see cref="string.Empty"/> when <paramref name="sql"/> is <c>null</c>.
    /// </returns>
    public static string GetSanitizedSql(string sql)
    {
        if (sql == null)
        {
            return string.Empty;
        }

        // The output is never longer than the input, so presize to avoid regrowth.
        var sb = new StringBuilder(capacity: sql.Length);
        for (var i = 0; i < sql.Length; ++i)
        {
            // Comments are checked first so that "--" is not mistaken for a signed number.
            if (SkipComment(sql, ref i))
            {
                continue;
            }

            // Hex is checked before numeric so that "0x.." is consumed as one literal.
            if (SanitizeStringLiteral(sql, ref i) ||
                SanitizeHexLiteral(sql, ref i) ||
                SanitizeNumericLiteral(sql, ref i))
            {
                sb.Append('?');
                continue;
            }

            WriteToken(sql, ref i, sb);
        }

        return sb.ToString();
    }

    /// <summary>
    /// Skips a block comment (<c>/* ... */</c>, including nested block comments)
    /// or a single-line comment (<c>-- ...</c>) that starts at <paramref name="index"/>.
    /// </summary>
    /// <param name="sql">The SQL text being scanned.</param>
    /// <param name="index">Position to inspect; advanced past the comment on a match.</param>
    /// <returns><c>true</c> if a comment started at <paramref name="index"/>; otherwise <c>false</c>.</returns>
    private static bool SkipComment(string sql, ref int index)
    {
        var i = index;
        var ch = sql[i];
        var length = sql.Length;

        // Scan past multi-line comment
        if (ch == '/' && i + 1 < length && sql[i + 1] == '*')
        {
            // T-SQL block comments nest: every "/*" inside a comment needs its own
            // "*/". Track the depth so the tail of a nested comment is not emitted.
            var depth = 1;
            for (i += 2; i < length; ++i)
            {
                ch = sql[i];
                if (ch == '/' && i + 1 < length && sql[i + 1] == '*')
                {
                    depth += 1;
                    i += 1;
                    continue;
                }

                if (ch == '*' && i + 1 < length && sql[i + 1] == '/')
                {
                    i += 1;
                    depth -= 1;
                    if (depth == 0)
                    {
                        break;
                    }
                }
            }

            // An unterminated comment leaves i == length, which ends the main loop.
            index = i;
            return true;
        }

        // Scan past single-line comment
        if (ch == '-' && i + 1 < length && sql[i + 1] == '-')
        {
            for (i += 2; i < length; ++i)
            {
                ch = sql[i];
                if (ch == '\r' || ch == '\n')
                {
                    // Step back so the line break itself is still written to the output.
                    i -= 1;
                    break;
                }
            }

            index = i;
            return true;
        }

        return false;
    }

    /// <summary>
    /// Consumes a single-quoted string literal that starts at <paramref name="index"/>.
    /// </summary>
    /// <param name="sql">The SQL text being scanned.</param>
    /// <param name="index">Position to inspect; advanced to the closing quote on a match.</param>
    /// <returns><c>true</c> if a string literal started at <paramref name="index"/>; otherwise <c>false</c>.</returns>
    private static bool SanitizeStringLiteral(string sql, ref int index)
    {
        var ch = sql[index];
        if (ch == '\'')
        {
            var i = index + 1;
            var length = sql.Length;
            for (; i < length; ++i)
            {
                ch = sql[i];

                // A doubled quote is an escaped quote inside the literal, not its end.
                if (ch == '\'' && i + 1 < length && sql[i + 1] == '\'')
                {
                    ++i;
                    continue;
                }

                if (ch == '\'')
                {
                    break;
                }
            }

            // An unterminated literal leaves i == length, which ends the main loop.
            index = i;
            return true;
        }

        return false;
    }

    /// <summary>
    /// Consumes a hexadecimal literal (<c>0x</c> or <c>0X</c> followed by hex digits)
    /// that starts at <paramref name="index"/>.
    /// </summary>
    /// <param name="sql">The SQL text being scanned.</param>
    /// <param name="index">Position to inspect; advanced to the last hex digit on a match.</param>
    /// <returns><c>true</c> if a hex literal started at <paramref name="index"/>; otherwise <c>false</c>.</returns>
    private static bool SanitizeHexLiteral(string sql, ref int index)
    {
        var i = index;
        var ch = sql[i];
        var length = sql.Length;

        if (ch == '0' && i + 1 < length && (sql[i + 1] == 'x' || sql[i + 1] == 'X'))
        {
            for (i += 2; i < length; ++i)
            {
                ch = sql[i];
                if (char.IsDigit(ch) ||
                    ch == 'A' || ch == 'a' ||
                    ch == 'B' || ch == 'b' ||
                    ch == 'C' || ch == 'c' ||
                    ch == 'D' || ch == 'd' ||
                    ch == 'E' || ch == 'e' ||
                    ch == 'F' || ch == 'f')
                {
                    continue;
                }

                // First non-hex character: step back so the main loop processes it.
                i -= 1;
                break;
            }

            index = i;
            return true;
        }

        return false;
    }

    /// <summary>
    /// Consumes a numeric literal that starts at <paramref name="index"/>: an optional
    /// sign, digits with at most one decimal point, and at most one exponent marker
    /// (<c>e</c>/<c>E</c>) with an optional sign.
    /// </summary>
    /// <param name="sql">The SQL text being scanned.</param>
    /// <param name="index">Position to inspect; advanced to the last character of the literal on a match.</param>
    /// <returns><c>true</c> if a numeric literal started at <paramref name="index"/>; otherwise <c>false</c>.</returns>
    private static bool SanitizeNumericLiteral(string sql, ref int index)
    {
        var i = index;
        var ch = sql[i];
        var length = sql.Length;

        // Scan past leading sign
        if ((ch == '-' || ch == '+') && i + 1 < length && (char.IsDigit(sql[i + 1]) || sql[i + 1] == '.'))
        {
            i += 1;
            ch = sql[i];
        }

        // Scan past leading decimal point
        var periodMatched = false;
        if (ch == '.' && i + 1 < length && char.IsDigit(sql[i + 1]))
        {
            periodMatched = true;
            i += 1;
            ch = sql[i];
        }

        // Only a digit at this point makes it a number; a lone sign or period
        // falls through with index untouched and is written as an ordinary token.
        if (char.IsDigit(ch))
        {
            var exponentMatched = false;
            for (i += 1; i < length; ++i)
            {
                ch = sql[i];
                if (char.IsDigit(ch))
                {
                    continue;
                }

                if (!periodMatched && ch == '.')
                {
                    periodMatched = true;
                    continue;
                }

                if (!exponentMatched && (ch == 'e' || ch == 'E'))
                {
                    // Scan past sign in exponent
                    if (i + 1 < length && (sql[i + 1] == '-' || sql[i + 1] == '+'))
                    {
                        i += 1;
                    }

                    exponentMatched = true;
                    continue;
                }

                // First character that is not part of the number: step back so the
                // main loop processes it.
                i -= 1;
                break;
            }

            index = i;
            return true;
        }

        return false;
    }

    /// <summary>
    /// Copies the token at <paramref name="index"/> to <paramref name="sb"/>: a whole
    /// identifier-like run (letters, digits, underscores, starting with a letter or
    /// underscore), or otherwise a single character.
    /// </summary>
    /// <param name="sql">The SQL text being scanned.</param>
    /// <param name="index">Position to copy from; advanced to the last character written.</param>
    /// <param name="sb">The builder receiving the sanitized output.</param>
    private static void WriteToken(string sql, ref int index, StringBuilder sb)
    {
        var i = index;
        var ch = sql[i];

        if (char.IsLetter(ch) || ch == '_')
        {
            // Writing the whole identifier here keeps digits inside names such as
            // "t1" from being picked up later as numeric literals.
            for (; i < sql.Length; i++)
            {
                ch = sql[i];
                if (char.IsLetter(ch) || ch == '_' || char.IsDigit(ch))
                {
                    sb.Append(ch);
                    continue;
                }

                break;
            }

            // i is one past the identifier; step back onto its last character.
            i -= 1;
        }
        else
        {
            sb.Append(ch);
        }

        index = i;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Configuration" Version="$(MicrosoftExtensionsConfigurationPkgVer)"/>
<PackageReference Include="OpenTelemetry.Api" Version="$(OpenTelemetryCoreLatestVersion)" />
<PackageReference Include="System.Text.Json" Version="$(SystemTextJsonLatestNet8OutOfBandPkgVer)" />
</ItemGroup>

<ItemGroup>
Expand All @@ -23,6 +24,13 @@
<Compile Include="$(RepoRoot)\src\Shared\RedactionHelper.cs" Link="Includes\RedactionHelper.cs" />
<Compile Include="$(RepoRoot)\src\Shared\RequestDataHelper.cs" Link="Includes\RequestDataHelper.cs" />
<Compile Include="$(RepoRoot)\src\Shared\SemanticConventions.cs" Link="Includes\SemanticConventions.cs" />
<Compile Include="$(RepoRoot)\src\Shared\SqlProcessor.cs" Link="Includes\SqlProcessor.cs" />
</ItemGroup>

<ItemGroup>
<EmbeddedResource Include="SqlProcessorTestCases.json">
<LogicalName>SqlProcessorTestCases.json</LogicalName>
</EmbeddedResource>
</ItemGroup>

</Project>
46 changes: 46 additions & 0 deletions test/OpenTelemetry.Contrib.Shared.Tests/SqlProcessorTestCases.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

using System.Reflection;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace OpenTelemetry.Instrumentation.Tests;

/// <summary>
/// Supplies the SQL sanitization test cases stored in the embedded
/// <c>SqlProcessorTestCases.json</c> resource of this assembly.
/// </summary>
public static class SqlProcessorTestCases
{
    private const string ResourceName = "SqlProcessorTestCases.json";

    private static readonly JsonSerializerOptions JsonSerializerOptions = new()
    {
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
        Converters = { new JsonStringEnumConverter() },
    };

    /// <summary>
    /// Reads the embedded JSON resource and yields each test case wrapped in a
    /// single-element <c>object[]</c>.
    /// </summary>
    /// <returns>One single-element array per <see cref="TestCase"/> in the resource.</returns>
    /// <exception cref="InvalidOperationException">
    /// The embedded resource is not present in the executing assembly.
    /// </exception>
    public static IEnumerable<object[]> GetTestCases()
    {
        var assembly = Assembly.GetExecutingAssembly();

        // Dispose the resource stream once enumeration finishes, and fail with a
        // message that names the resource instead of passing null to the deserializer.
        using var stream = assembly.GetManifestResourceStream(ResourceName)
            ?? throw new InvalidOperationException($"Embedded resource '{ResourceName}' was not found.");

        var input = JsonSerializer.Deserialize<TestCase[]>(stream, JsonSerializerOptions)!;

        foreach (var testCase in input)
        {
            yield return new object[] { testCase };
        }
    }

    /// <summary>
    /// One entry of the test-case JSON file.
    /// </summary>
    public class TestCase
    {
        /// <summary>Gets or sets the test case name; also returned by <see cref="ToString"/>.</summary>
        public string Name { get; set; } = string.Empty;

        /// <summary>Gets or sets the <c>sql</c> value of the entry.</summary>
        public string Sql { get; set; } = string.Empty;

        /// <summary>Gets or sets the <c>sanitized</c> value of the entry.</summary>
        public string Sanitized { get; set; } = string.Empty;

        /// <summary>Gets or sets the <c>dialects</c> value of the entry.</summary>
        public IEnumerable<string> Dialects { get; set; } = [];

        /// <summary>Returns <see cref="Name"/>.</summary>
        /// <returns>The test case name.</returns>
        public override string ToString()
        {
            return this.Name;
        }
    }
}
Loading

0 comments on commit a7d5aeb

Please sign in to comment.