diff --git a/src/main/java/com/hubspot/jinjava/JinjavaConfig.java b/src/main/java/com/hubspot/jinjava/JinjavaConfig.java index 6fbfd9404..1db4686b2 100644 --- a/src/main/java/com/hubspot/jinjava/JinjavaConfig.java +++ b/src/main/java/com/hubspot/jinjava/JinjavaConfig.java @@ -26,6 +26,7 @@ import com.hubspot.jinjava.interpret.Context; import com.hubspot.jinjava.interpret.Context.Library; +import com.hubspot.jinjava.random.RandomNumberGeneratorStrategy; public class JinjavaConfig { @@ -44,17 +45,18 @@ public class JinjavaConfig { private Map> disabled; private final boolean failOnUnknownTokens; private final boolean nestedInterpretationEnabled; + private final RandomNumberGeneratorStrategy randomNumberGenerator; public static Builder newBuilder() { return new Builder(); } public JinjavaConfig() { - this(StandardCharsets.UTF_8, Locale.ENGLISH, ZoneOffset.UTC, 10, new HashMap<>(), false, false, true, false, false, 0, true); + this(StandardCharsets.UTF_8, Locale.ENGLISH, ZoneOffset.UTC, 10, new HashMap<>(), false, false, true, false, false, 0, true, RandomNumberGeneratorStrategy.THREAD_LOCAL); } public JinjavaConfig(Charset charset, Locale locale, ZoneId timeZone, int maxRenderDepth) { - this(charset, locale, timeZone, maxRenderDepth, new HashMap<>(), false, false, true, false, false, 0, true); + this(charset, locale, timeZone, maxRenderDepth, new HashMap<>(), false, false, true, false, false, 0, true, RandomNumberGeneratorStrategy.THREAD_LOCAL); } private JinjavaConfig(Charset charset, @@ -69,7 +71,8 @@ private JinjavaConfig(Charset charset, boolean enableRecursiveMacroCalls, boolean failOnUnknownTokens, long maxOutputSize, - boolean nestedInterpretationEnabled) { + boolean nestedInterpretationEnabled, + RandomNumberGeneratorStrategy randomNumberGenerator) { this.charset = charset; this.locale = locale; this.timeZone = timeZone; @@ -82,6 +85,7 @@ private JinjavaConfig(Charset charset, this.failOnUnknownTokens = failOnUnknownTokens; this.maxOutputSize = maxOutputSize; this.nestedInterpretationEnabled = nestedInterpretationEnabled; + this.randomNumberGenerator = randomNumberGenerator; } public Charset getCharset() { @@ -104,6 +108,10 @@ public long getMaxOutputSize() { return maxOutputSize; } + public RandomNumberGeneratorStrategy getRandomNumberGeneratorStrategy() { + return randomNumberGenerator; + } + public boolean isTrimBlocks() { return trimBlocks; } @@ -147,6 +155,7 @@ public static class Builder { private boolean enableRecursiveMacroCalls; private boolean failOnUnknownTokens; private boolean nestedInterpretationEnabled = true; + private RandomNumberGeneratorStrategy randomNumberGeneratorStrategy = RandomNumberGeneratorStrategy.THREAD_LOCAL; private Builder() {} @@ -175,6 +184,12 @@ public Builder withMaxRenderDepth(int maxRenderDepth) { return this; } + public Builder withRandomNumberGeneratorStrategy(RandomNumberGeneratorStrategy randomNumberGeneratorStrategy) { + this.randomNumberGeneratorStrategy = randomNumberGeneratorStrategy; + return this; + } + + public Builder withTrimBlocks(boolean trimBlocks) { this.trimBlocks = trimBlocks; return this; @@ -211,7 +226,7 @@ public Builder withNestedInterpretationEnabled(boolean nestedInterpretationEnabl } public JinjavaConfig build() { - return new JinjavaConfig(charset, locale, timeZone, maxRenderDepth, disabled, trimBlocks, lstripBlocks, readOnlyResolver, enableRecursiveMacroCalls, failOnUnknownTokens, maxOutputSize, nestedInterpretationEnabled); + return new JinjavaConfig(charset, locale, timeZone, maxRenderDepth, disabled, trimBlocks, lstripBlocks, readOnlyResolver, enableRecursiveMacroCalls, failOnUnknownTokens, maxOutputSize, nestedInterpretationEnabled, randomNumberGeneratorStrategy); } } diff --git a/src/main/java/com/hubspot/jinjava/interpret/JinjavaInterpreter.java b/src/main/java/com/hubspot/jinjava/interpret/JinjavaInterpreter.java index 55a64e74c..869328296 100644 --- a/src/main/java/com/hubspot/jinjava/interpret/JinjavaInterpreter.java +++ b/src/main/java/com/hubspot/jinjava/interpret/JinjavaInterpreter.java @@ -25,8 +25,10 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Random; import java.util.Set; import java.util.Stack; +import java.util.concurrent.ThreadLocalRandom; import org.apache.commons.lang3.StringUtils; @@ -36,6 +38,8 @@ import com.hubspot.jinjava.Jinjava; import com.hubspot.jinjava.JinjavaConfig; import com.hubspot.jinjava.el.ExpressionResolver; +import com.hubspot.jinjava.random.ConstantZeroRandomNumberGenerator; +import com.hubspot.jinjava.random.RandomNumberGeneratorStrategy; import com.hubspot.jinjava.tree.Node; import com.hubspot.jinjava.tree.TreeParser; import com.hubspot.jinjava.tree.output.BlockPlaceholderOutputNode; @@ -54,6 +58,7 @@ public class JinjavaInterpreter { private final ExpressionResolver expressionResolver; private final Jinjava application; + private final Random random; private int lineNumber = -1; private final List errors = new LinkedList<>(); @@ -63,6 +68,14 @@ public JinjavaInterpreter(Jinjava application, Context context, JinjavaConfig re this.config = renderConfig; this.application = application; + if (config.getRandomNumberGeneratorStrategy() == RandomNumberGeneratorStrategy.THREAD_LOCAL) { + random = ThreadLocalRandom.current(); + } else if (config.getRandomNumberGeneratorStrategy() == RandomNumberGeneratorStrategy.CONSTANT_ZERO) { + random = new ConstantZeroRandomNumberGenerator(); + } else { + throw new IllegalStateException("No random number generator with strategy " + config.getRandomNumberGeneratorStrategy()); + } + this.expressionResolver = new ExpressionResolver(this, application.getExpressionFactory()); } @@ -116,6 +129,10 @@ public void leaveScope() { } } + public Random getRandom() { + return random; + } + public class InterpreterScopeClosable implements AutoCloseable { @Override diff --git a/src/main/java/com/hubspot/jinjava/lib/filter/RandomFilter.java b/src/main/java/com/hubspot/jinjava/lib/filter/RandomFilter.java index 3042f74ac..9da58e3eb 100644 --- a/src/main/java/com/hubspot/jinjava/lib/filter/RandomFilter.java +++ b/src/main/java/com/hubspot/jinjava/lib/filter/RandomFilter.java @@ -20,7 +20,6 @@ import java.util.Collection; import java.util.Iterator; import java.util.Map; -import java.util.concurrent.ThreadLocalRandom; import com.hubspot.jinjava.doc.annotations.JinjavaDoc; import com.hubspot.jinjava.doc.annotations.JinjavaParam; @@ -52,7 +51,7 @@ public Object filter(Object object, JinjavaInterpreter interpreter, String... ar if (size == 0) { return null; } - int index = ThreadLocalRandom.current().nextInt(size); + int index = interpreter.getRandom().nextInt(size); while (index-- > 0) { it.next(); } @@ -64,7 +63,7 @@ public Object filter(Object object, JinjavaInterpreter interpreter, String... ar if (size == 0) { return null; } - int index = ThreadLocalRandom.current().nextInt(size); + int index = interpreter.getRandom().nextInt(size); return Array.get(object, index); } // map @@ -75,7 +74,7 @@ public Object filter(Object object, JinjavaInterpreter interpreter, String... ar if (size == 0) { return null; } - int index = ThreadLocalRandom.current().nextInt(size); + int index = interpreter.getRandom().nextInt(size); while (index-- > 0) { it.next(); } @@ -83,12 +82,12 @@ public Object filter(Object object, JinjavaInterpreter interpreter, String... ar } // number if (object instanceof Number) { - return ThreadLocalRandom.current().nextLong(((Number) object).longValue()); + return interpreter.getRandom().nextInt(((Number) object).intValue()); } // string if (object instanceof String) { try { - return ThreadLocalRandom.current().nextLong(new BigDecimal((String) object).longValue()); + return interpreter.getRandom().nextInt(new BigDecimal((String) object).intValue()); } catch (Exception e) { return 0; } diff --git a/src/main/java/com/hubspot/jinjava/lib/filter/ShuffleFilter.java b/src/main/java/com/hubspot/jinjava/lib/filter/ShuffleFilter.java index d9c0f861b..eed0da0c7 100644 --- a/src/main/java/com/hubspot/jinjava/lib/filter/ShuffleFilter.java +++ b/src/main/java/com/hubspot/jinjava/lib/filter/ShuffleFilter.java @@ -29,8 +29,8 @@ public String getName() { @Override public Object filter(Object var, JinjavaInterpreter interpreter, String... args) { if (var instanceof Collection) { - List list = new ArrayList((Collection) var); - Collections.shuffle(list); + List list = new ArrayList<>((Collection) var); + Collections.shuffle(list, interpreter.getRandom()); return list; } diff --git a/src/main/java/com/hubspot/jinjava/random/ConstantZeroRandomNumberGenerator.java b/src/main/java/com/hubspot/jinjava/random/ConstantZeroRandomNumberGenerator.java new file mode 100644 index 000000000..83eddd05e --- /dev/null +++ b/src/main/java/com/hubspot/jinjava/random/ConstantZeroRandomNumberGenerator.java @@ -0,0 +1,117 @@ +package com.hubspot.jinjava.random; + +import java.util.Random; +import java.util.stream.DoubleStream; +import java.util.stream.IntStream; +import java.util.stream.LongStream; + +/** + * A random number generator that always returns 0. Useful for testing code when you want the output to be constant. + */ +public class ConstantZeroRandomNumberGenerator extends Random { + + @Override + protected int next(int bits) { + return 0; + } + + @Override + public int nextInt() { + return 0; + } + + @Override + public int nextInt(int bound) { + return 0; + } + + @Override + public long nextLong() { + return 0; + } + + @Override + public boolean nextBoolean() { + return false; + } + + @Override + public float nextFloat() { + return 0f; + } + + @Override + public double nextDouble() { + return 0; + } + + @Override + public synchronized double nextGaussian() { + return 0; + } + + @Override + public void nextBytes(byte[] bytes) { + throw new UnsupportedOperationException(); + } + + @Override + public IntStream ints(long streamSize) { + throw new UnsupportedOperationException(); + } + + @Override + public IntStream ints() { + throw new UnsupportedOperationException(); + } + + @Override + public IntStream ints(long streamSize, int randomNumberOrigin, int randomNumberBound) { + throw new UnsupportedOperationException(); + } + + @Override + public IntStream ints(int randomNumberOrigin, int randomNumberBound) { + throw new UnsupportedOperationException(); + } + + @Override + public LongStream longs(long streamSize) { + throw new UnsupportedOperationException(); + } + + @Override + public LongStream longs() { + throw new UnsupportedOperationException(); + } + + @Override + public LongStream longs(long streamSize, long randomNumberOrigin, long randomNumberBound) { + throw new UnsupportedOperationException(); + } + + @Override + public LongStream longs(long randomNumberOrigin, long randomNumberBound) { + throw new UnsupportedOperationException(); + } + + @Override + public DoubleStream doubles(long streamSize) { + throw new UnsupportedOperationException(); + } + + @Override + public DoubleStream doubles() { + throw new UnsupportedOperationException(); + } + + @Override + public DoubleStream doubles(long streamSize, double randomNumberOrigin, double randomNumberBound) { + throw new UnsupportedOperationException(); + } + + @Override + public DoubleStream doubles(double randomNumberOrigin, double randomNumberBound) { + throw new UnsupportedOperationException(); + } +} diff --git a/src/main/java/com/hubspot/jinjava/random/RandomNumberGeneratorStrategy.java b/src/main/java/com/hubspot/jinjava/random/RandomNumberGeneratorStrategy.java new file mode 100644 index 000000000..57ca80ab3 --- /dev/null +++ b/src/main/java/com/hubspot/jinjava/random/RandomNumberGeneratorStrategy.java @@ -0,0 +1,6 @@ +package com.hubspot.jinjava.random; + +public enum RandomNumberGeneratorStrategy { + THREAD_LOCAL, + CONSTANT_ZERO +} diff --git a/src/test/java/com/hubspot/jinjava/lib/filter/ShuffleFilterTest.java b/src/test/java/com/hubspot/jinjava/lib/filter/ShuffleFilterTest.java index 48b4f4bb4..4a4d1fe83 100644 --- a/src/test/java/com/hubspot/jinjava/lib/filter/ShuffleFilterTest.java +++ b/src/test/java/com/hubspot/jinjava/lib/filter/ShuffleFilterTest.java @@ -2,27 +2,31 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.failBecauseExceptionWasNotThrown; +import static org.mockito.Mockito.*; import java.util.Arrays; import java.util.List; +import java.util.concurrent.ThreadLocalRandom; -import org.junit.Before; import org.junit.Test; +import com.hubspot.jinjava.interpret.JinjavaInterpreter; +import com.hubspot.jinjava.random.ConstantZeroRandomNumberGenerator; + public class ShuffleFilterTest { - ShuffleFilter filter; + ShuffleFilter filter = new ShuffleFilter(); - @Before - public void setup() { - this.filter = new ShuffleFilter(); - } + JinjavaInterpreter interpreter = mock(JinjavaInterpreter.class); @SuppressWarnings("unchecked") @Test - public void shuffleItems() { + public void itShufflesItems() { + + when(interpreter.getRandom()).thenReturn(ThreadLocalRandom.current()); + List before = Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"); - List after = (List) filter.filter(before, null); + List after = (List) filter.filter(before, interpreter); assertThat(before).isSorted(); assertThat(after).containsAll(before); @@ -35,4 +39,19 @@ public void shuffleItems() { } } + @SuppressWarnings("unchecked") + @Test + public void itShufflesConsistentlyWithConstantRandom() { + + when(interpreter.getRandom()).thenReturn(new ConstantZeroRandomNumberGenerator()); + + List before = Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"); + List after = (List) filter.filter(before, interpreter); + + assertThat(before).isSorted(); + assertThat(after).containsAll(before); + + assertThat(after).containsExactly("2", "3", "4", "5", "6", "7", "8", "9", "1"); + } + }