Skip to content

Commit

Permalink
Add IAST propagation to String valueOf (#8013)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mariovido authored and PerfectSlayer committed Nov 28, 2024
1 parent 60c4c42 commit 5d6e519
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import static datadog.trace.api.telemetry.LogCollector.SEND_TELEMETRY;

import com.datadog.iast.model.Range;
import com.datadog.iast.model.Source;
import com.datadog.iast.taint.Ranges;
import com.datadog.iast.taint.TaintedObject;
import com.datadog.iast.taint.TaintedObjects;
import com.datadog.iast.util.RangeBuilder;
import com.datadog.iast.util.Ranged;
import com.datadog.iast.util.StringUtils;
import datadog.trace.api.iast.IastContext;
import datadog.trace.api.iast.Taintable;
import datadog.trace.api.iast.propagation.StringModule;
import de.thetaphi.forbiddenapis.SuppressForbidden;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
Expand Down Expand Up @@ -742,6 +744,54 @@ public String onStringReplace(
numReplacements);
}

@Override
@SuppressFBWarnings("ES_COMPARING_PARAMETER_STRING_WITH_EQ")
public void onStringValueOf(Object param, @Nonnull String result) {
if (param == null || !canBeTainted(result)) {
return;
}
final IastContext ctx = IastContext.Provider.get();
if (ctx == null) {
return;
}
final TaintedObjects taintedObjects = ctx.getTaintedObjects();

if (param instanceof Taintable) {
final Taintable taintable = (Taintable) param;
if (!taintable.$DD$isTainted()) {
return;
}
final Source source = (Source) taintable.$$DD$getSource();
final Range[] ranges =
Ranges.forCharSequence(
result, new Source(source.getOrigin(), source.getName(), source.getValue()));

taintedObjects.taint(result, ranges);
} else {
final TaintedObject taintedParam = taintedObjects.get(param);
if (taintedParam == null) {
return;
}

final Range[] rangesParam = taintedParam.getRanges();
if (rangesParam.length == 0) {
return;
}

// Special objects like InputStream...
if (rangesParam[0].getLength() == Integer.MAX_VALUE) {
final Source source = rangesParam[0].getSource();
final Range[] ranges =
Ranges.forCharSequence(
result, new Source(source.getOrigin(), source.getName(), source.getValue()));

taintedObjects.taint(result, ranges);
} else {
taintedObjects.taint(result, rangesParam);
}
}
}

/**
* Adds the tainted ranges belonging to the current parameter added via placeholder taking care of
* an optional tainted placeholder.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package com.datadog.iast.propagation

import com.datadog.iast.IastModuleImplTestBase
import com.datadog.iast.model.Source
import com.datadog.iast.taint.TaintedObjects
import datadog.trace.api.gateway.RequestContext
import datadog.trace.api.gateway.RequestContextSlot
import datadog.trace.api.iast.SourceTypes
import datadog.trace.api.iast.Taintable
import datadog.trace.api.iast.propagation.StringModule
import datadog.trace.bootstrap.instrumentation.api.AgentSpan
import datadog.trace.bootstrap.instrumentation.api.AgentTracer
Expand All @@ -16,6 +20,7 @@ import static com.datadog.iast.taint.TaintUtils.fromTaintFormat
import static com.datadog.iast.taint.TaintUtils.getStringFromTaintFormat
import static com.datadog.iast.taint.TaintUtils.taint
import static com.datadog.iast.taint.TaintUtils.taintFormat
import static com.datadog.iast.taint.TaintUtils.taintObject

@CompileDynamic
class StringModuleTest extends IastModuleImplTestBase {
Expand Down Expand Up @@ -1298,6 +1303,65 @@ class StringModuleTest extends IastModuleImplTestBase {
"==>my_o<==u==>tput<====>my_o<==u==>tput<==" | 'out' | '==>in<==' | 0 | "==>my_o<==u==>tput<====>my_o<==u==>tput<=="
}
void 'test valueOf with (#param) and make sure IastRequestContext is called'() {
given:
final taintedObjects = ctx.getTaintedObjects()
def paramTainted = addFromTaintFormat(taintedObjects, param)
def result = String.valueOf(paramTainted)
when:
module.onStringValueOf(paramTainted, result)
def taintedObject = taintedObjects.get(result)
then:
1 * tracer.activeSpan() >> span
taintFormat(result, taintedObject.getRanges()) == expected
where:
param | expected
"==>test<==" | "==>test<=="
sb("==>test<==") | "==>test<=="
sbf("==>my_input<==") | "==>my_input<=="
}
void 'test valueOf with taintable object and make sure IastRequestContext is called'() {
given:
final taintedObjects = ctx.getTaintedObjects()
final source = taintedSource()
final param = taintable(taintedObjects, source)
final result = String.valueOf(param)
when:
module.onStringValueOf(param, result)
final taintedObject = taintedObjects.get(result)
then:
1 * tracer.activeSpan() >> span
taintFormat(result, taintedObject.getRanges()) == "==>my_input<=="
}
void 'test valueOf with special objects and make sure IastRequestContext is called'() {
given:
final taintedObjects = ctx.getTaintedObjects()
final source = taintedSource()
final param = new Object() {
@Override
String toString() {
return "my_input"
}
}
taintObject(taintedObjects, param, source)
final result = String.valueOf(param)
when:
module.onStringValueOf(param, result)
final taintedObject = taintedObjects.get(result)
then:
1 * tracer.activeSpan() >> span
taintFormat(result, taintedObject.getRanges()) == "==>my_input<=="
}
private static Date date(final String pattern, final String value) {
return new SimpleDateFormat(pattern).parse(value)
}
Expand All @@ -1310,11 +1374,44 @@ class StringModuleTest extends IastModuleImplTestBase {
return new StringBuilder(string)
}
private static StringBuilder sbf() {
private static StringBuffer sbf() {
return sbf('')
}
private static StringBuffer sbf(final String string) {
return new StringBuffer(string)
}
private static Source taintedSource(String value = 'value') {
return new Source(SourceTypes.REQUEST_PARAMETER_VALUE, 'name', value)
}
private static Taintable taintable(TaintedObjects tos, Source source = null) {
final result = new MockTaintable()
if (source != null) {
taintObject(tos, result, source)
}
return result
}
private static class MockTaintable implements Taintable {
private Source source
@SuppressWarnings('CodeNarc')
@Override
Source $$DD$getSource() {
return source
}
@SuppressWarnings('CodeNarc')
@Override
void $$DD$setSource(Source source) {
this.source = source
}
@Override
String toString() {
return "my_input"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.datadog.iast.taint
import com.datadog.iast.model.Range
import com.datadog.iast.model.Source
import datadog.trace.api.iast.SourceTypes
import datadog.trace.api.iast.Taintable

import static datadog.trace.api.iast.VulnerabilityMarks.NOT_MARKED

Expand Down Expand Up @@ -81,6 +82,14 @@ class TaintUtils {
getStringFromTaintFormat(appendable.toString())
}

static TaintedObject getTaintedObject(final TaintedObjects tos, final Object target) {
if (target instanceof Taintable) {
final source = (target as Taintable).$$DD$getSource() as Source
return source == null ? null : new TaintedObject(target, Ranges.forObject(source))
}
return tos.get(target)
}

static <E> E taint(final TaintedObjects tos, final E value) {
if (value instanceof String) {
return addFromTaintFormat(tos, value as String)
Expand All @@ -89,6 +98,17 @@ class TaintUtils {
return value
}

static TaintedObject taintObject(final TaintedObjects tos, final Object target, Source source, int mark = NOT_MARKED) {
if (target instanceof Taintable) {
target.$$DD$setSource(source)
} else if (target instanceof CharSequence) {
tos.taint(target, Ranges.forCharSequence(target, source, mark))
} else {
tos.taint(target, Ranges.forObject(source, mark))
}
return getTaintedObject(tos, target)
}

static String addFromTaintFormat(final TaintedObjects tos, final String s) {
return addFromTaintFormat(tos, s, NOT_MARKED)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,4 +300,18 @@ public static String afterReplaceChar(
}
return result;
}

@CallSite.After("java.lang.String java.lang.String.valueOf(java.lang.Object)")
public static String afterValueOf(
@CallSite.Argument(0) final Object obj, @CallSite.Return final String result) {
final StringModule module = InstrumentationBridge.STRING;
if (module != null) {
try {
module.onStringValueOf(obj, result);
} catch (final Throwable e) {
module.onUnexpectedException("afterValueOf threw", e);
}
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -265,4 +265,24 @@ class StringCallSiteTest extends AgentTestRunner {
"test" | 't' | 'T' | "TesT"
"test" | 'e' | 'E' | "tEst"
}

def 'test string valueOf call site'() {
setup:
final stringModule = Mock(StringModule)
InstrumentationBridge.registerIastModule(stringModule)

when:
final result = TestStringSuite.valueOf(input)

then:
result == expected
1 * stringModule.onStringValueOf(input, expected)
0 * _

where:
input | expected
"test" | "test"
new StringBuilder("test") | "test"
new StringBuffer("test") | "test"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,11 @@ public static String replaceFirst(
LOGGER.debug("After replace first {}", result);
return result;
}

public static String valueOf(final Object param) {
LOGGER.debug("Before valueOf {}", param);
String result = String.valueOf(param);
LOGGER.debug("After valueOf {}", result);
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,6 @@ void onStringFormat(

String onStringReplace(
@Nonnull String self, String regex, String replacement, int numReplacements);

void onStringValueOf(Object param, @Nullable String result);
}

0 comments on commit 5d6e519

Please sign in to comment.