Skip to content

Commit

Permalink
Add configurable random number generator
Browse files Browse the repository at this point in the history
  • Loading branch information
jboulter committed Aug 2, 2017
1 parent b5bee8e commit a3b9053
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 20 deletions.
23 changes: 19 additions & 4 deletions src/main/java/com/hubspot/jinjava/JinjavaConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import com.hubspot.jinjava.interpret.Context;
import com.hubspot.jinjava.interpret.Context.Library;
import com.hubspot.jinjava.random.RandomNumberGeneratorStrategy;

public class JinjavaConfig {

Expand All @@ -44,17 +45,18 @@ public class JinjavaConfig {
private Map<Context.Library, Set<String>> disabled;
private final boolean failOnUnknownTokens;
private final boolean nestedInterpretationEnabled;
private final RandomNumberGeneratorStrategy randomNumberGenerator;

public static Builder newBuilder() {
return new Builder();
}

public JinjavaConfig() {
this(StandardCharsets.UTF_8, Locale.ENGLISH, ZoneOffset.UTC, 10, new HashMap<>(), false, false, true, false, false, 0, true);
this(StandardCharsets.UTF_8, Locale.ENGLISH, ZoneOffset.UTC, 10, new HashMap<>(), false, false, true, false, false, 0, true, RandomNumberGeneratorStrategy.THREAD_LOCAL);
}

public JinjavaConfig(Charset charset, Locale locale, ZoneId timeZone, int maxRenderDepth) {
this(charset, locale, timeZone, maxRenderDepth, new HashMap<>(), false, false, true, false, false, 0, true);
this(charset, locale, timeZone, maxRenderDepth, new HashMap<>(), false, false, true, false, false, 0, true, RandomNumberGeneratorStrategy.THREAD_LOCAL);
}

private JinjavaConfig(Charset charset,
Expand All @@ -69,7 +71,8 @@ private JinjavaConfig(Charset charset,
boolean enableRecursiveMacroCalls,
boolean failOnUnknownTokens,
long maxOutputSize,
boolean nestedInterpretationEnabled) {
boolean nestedInterpretationEnabled,
RandomNumberGeneratorStrategy randomNumberGenerator) {
this.charset = charset;
this.locale = locale;
this.timeZone = timeZone;
Expand All @@ -82,6 +85,7 @@ private JinjavaConfig(Charset charset,
this.failOnUnknownTokens = failOnUnknownTokens;
this.maxOutputSize = maxOutputSize;
this.nestedInterpretationEnabled = nestedInterpretationEnabled;
this.randomNumberGenerator = randomNumberGenerator;
}

public Charset getCharset() {
Expand All @@ -104,6 +108,10 @@ public long getMaxOutputSize() {
return maxOutputSize;
}

public RandomNumberGeneratorStrategy getRandomNumberGeneratorStrategy() {
return randomNumberGenerator;
}

public boolean isTrimBlocks() {
return trimBlocks;
}
Expand Down Expand Up @@ -147,6 +155,7 @@ public static class Builder {
private boolean enableRecursiveMacroCalls;
private boolean failOnUnknownTokens;
private boolean nestedInterpretationEnabled = true;
private RandomNumberGeneratorStrategy randomNumberGeneratorStrategy = RandomNumberGeneratorStrategy.THREAD_LOCAL;

private Builder() {}

Expand Down Expand Up @@ -175,6 +184,12 @@ public Builder withMaxRenderDepth(int maxRenderDepth) {
return this;
}

public Builder withRandomNumberGeneratorStrategy(RandomNumberGeneratorStrategy randomNumberGeneratorStrategy) {
this.randomNumberGeneratorStrategy = randomNumberGeneratorStrategy;
return this;
}


public Builder withTrimBlocks(boolean trimBlocks) {
this.trimBlocks = trimBlocks;
return this;
Expand Down Expand Up @@ -211,7 +226,7 @@ public Builder withNestedInterpretationEnabled(boolean nestedInterpretationEnabl
}

public JinjavaConfig build() {
return new JinjavaConfig(charset, locale, timeZone, maxRenderDepth, disabled, trimBlocks, lstripBlocks, readOnlyResolver, enableRecursiveMacroCalls, failOnUnknownTokens, maxOutputSize, nestedInterpretationEnabled);
return new JinjavaConfig(charset, locale, timeZone, maxRenderDepth, disabled, trimBlocks, lstripBlocks, readOnlyResolver, enableRecursiveMacroCalls, failOnUnknownTokens, maxOutputSize, nestedInterpretationEnabled, randomNumberGeneratorStrategy);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.Stack;
import java.util.concurrent.ThreadLocalRandom;

import org.apache.commons.lang3.StringUtils;

Expand All @@ -36,6 +38,8 @@
import com.hubspot.jinjava.Jinjava;
import com.hubspot.jinjava.JinjavaConfig;
import com.hubspot.jinjava.el.ExpressionResolver;
import com.hubspot.jinjava.random.ConstantZeroRandomNumberGenerator;
import com.hubspot.jinjava.random.RandomNumberGeneratorStrategy;
import com.hubspot.jinjava.tree.Node;
import com.hubspot.jinjava.tree.TreeParser;
import com.hubspot.jinjava.tree.output.BlockPlaceholderOutputNode;
Expand All @@ -54,6 +58,7 @@ public class JinjavaInterpreter {

private final ExpressionResolver expressionResolver;
private final Jinjava application;
private final Random random;

private int lineNumber = -1;
private final List<TemplateError> errors = new LinkedList<>();
Expand All @@ -63,6 +68,14 @@ public JinjavaInterpreter(Jinjava application, Context context, JinjavaConfig re
this.config = renderConfig;
this.application = application;

if (config.getRandomNumberGeneratorStrategy() == RandomNumberGeneratorStrategy.THREAD_LOCAL) {
random = ThreadLocalRandom.current();
} else if (config.getRandomNumberGeneratorStrategy() == RandomNumberGeneratorStrategy.CONSTANT_ZERO) {
random = new ConstantZeroRandomNumberGenerator();
} else {
throw new IllegalStateException("No random number generator with strategy " + config.getRandomNumberGeneratorStrategy());
}

this.expressionResolver = new ExpressionResolver(this, application.getExpressionFactory());
}

Expand Down Expand Up @@ -116,6 +129,10 @@ public void leaveScope() {
}
}

public Random getRandom() {
return random;
}

public class InterpreterScopeClosable implements AutoCloseable {

@Override
Expand Down
11 changes: 5 additions & 6 deletions src/main/java/com/hubspot/jinjava/lib/filter/RandomFilter.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;

import com.hubspot.jinjava.doc.annotations.JinjavaDoc;
import com.hubspot.jinjava.doc.annotations.JinjavaParam;
Expand Down Expand Up @@ -52,7 +51,7 @@ public Object filter(Object object, JinjavaInterpreter interpreter, String... ar
if (size == 0) {
return null;
}
int index = ThreadLocalRandom.current().nextInt(size);
int index = interpreter.getRandom().nextInt(size);
while (index-- > 0) {
it.next();
}
Expand All @@ -64,7 +63,7 @@ public Object filter(Object object, JinjavaInterpreter interpreter, String... ar
if (size == 0) {
return null;
}
int index = ThreadLocalRandom.current().nextInt(size);
int index = interpreter.getRandom().nextInt(size);
return Array.get(object, index);
}
// map
Expand All @@ -75,20 +74,20 @@ public Object filter(Object object, JinjavaInterpreter interpreter, String... ar
if (size == 0) {
return null;
}
int index = ThreadLocalRandom.current().nextInt(size);
int index = interpreter.getRandom().nextInt(size);
while (index-- > 0) {
it.next();
}
return it.next();
}
// number
if (object instanceof Number) {
return ThreadLocalRandom.current().nextLong(((Number) object).longValue());
return interpreter.getRandom().nextInt(((Number) object).intValue());
}
// string
if (object instanceof String) {
try {
return ThreadLocalRandom.current().nextLong(new BigDecimal((String) object).longValue());
return interpreter.getRandom().nextInt(new BigDecimal((String) object).intValue());
} catch (Exception e) {
return 0;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ public String getName() {
@Override
public Object filter(Object var, JinjavaInterpreter interpreter, String... args) {
if (var instanceof Collection) {
List<?> list = new ArrayList<Object>((Collection<Object>) var);
Collections.shuffle(list);
List<?> list = new ArrayList<>((Collection<Object>) var);
Collections.shuffle(list, interpreter.getRandom());
return list;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package com.hubspot.jinjava.random;

import java.util.Random;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import java.util.stream.LongStream;

/**
* A random number generator that always returns 0. Useful for testing code when you want the output to be constant.
*/
public class ConstantZeroRandomNumberGenerator extends Random {

@Override
protected int next(int bits) {
return 0;
}

@Override
public int nextInt() {
return 0;
}

@Override
public int nextInt(int bound) {
return 0;
}

@Override
public long nextLong() {
return 0;
}

@Override
public boolean nextBoolean() {
return false;
}

@Override
public float nextFloat() {
return 0f;
}

@Override
public double nextDouble() {
return 0;
}

@Override
public synchronized double nextGaussian() {
return 0;
}

@Override
public void nextBytes(byte[] bytes) {
throw new UnsupportedOperationException();
}

@Override
public IntStream ints(long streamSize) {
throw new UnsupportedOperationException();
}

@Override
public IntStream ints() {
throw new UnsupportedOperationException();
}

@Override
public IntStream ints(long streamSize, int randomNumberOrigin, int randomNumberBound) {
throw new UnsupportedOperationException();
}

@Override
public IntStream ints(int randomNumberOrigin, int randomNumberBound) {
throw new UnsupportedOperationException();
}

@Override
public LongStream longs(long streamSize) {
throw new UnsupportedOperationException();
}

@Override
public LongStream longs() {
throw new UnsupportedOperationException();
}

@Override
public LongStream longs(long streamSize, long randomNumberOrigin, long randomNumberBound) {
throw new UnsupportedOperationException();
}

@Override
public LongStream longs(long randomNumberOrigin, long randomNumberBound) {
throw new UnsupportedOperationException();
}

@Override
public DoubleStream doubles(long streamSize) {
throw new UnsupportedOperationException();
}

@Override
public DoubleStream doubles() {
throw new UnsupportedOperationException();
}

@Override
public DoubleStream doubles(long streamSize, double randomNumberOrigin, double randomNumberBound) {
throw new UnsupportedOperationException();
}

@Override
public DoubleStream doubles(double randomNumberOrigin, double randomNumberBound) {
throw new UnsupportedOperationException();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package com.hubspot.jinjava.random;

public enum RandomNumberGeneratorStrategy {
THREAD_LOCAL,
CONSTANT_ZERO
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,31 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.failBecauseExceptionWasNotThrown;
import static org.mockito.Mockito.*;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

import org.junit.Before;
import org.junit.Test;

import com.hubspot.jinjava.interpret.JinjavaInterpreter;
import com.hubspot.jinjava.random.ConstantZeroRandomNumberGenerator;

public class ShuffleFilterTest {

ShuffleFilter filter;
ShuffleFilter filter = new ShuffleFilter();

@Before
public void setup() {
this.filter = new ShuffleFilter();
}
JinjavaInterpreter interpreter = mock(JinjavaInterpreter.class);

@SuppressWarnings("unchecked")
@Test
public void shuffleItems() {
public void itShufflesItems() {

when(interpreter.getRandom()).thenReturn(ThreadLocalRandom.current());

List<String> before = Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9");
List<String> after = (List<String>) filter.filter(before, null);
List<String> after = (List<String>) filter.filter(before, interpreter);

assertThat(before).isSorted();
assertThat(after).containsAll(before);
Expand All @@ -35,4 +39,19 @@ public void shuffleItems() {
}
}

@SuppressWarnings("unchecked")
@Test
public void itShufflesConsistentlyWithConstantRandom() {

when(interpreter.getRandom()).thenReturn(new ConstantZeroRandomNumberGenerator());

List<String> before = Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9");
List<String> after = (List<String>) filter.filter(before, interpreter);

assertThat(before).isSorted();
assertThat(after).containsAll(before);

assertThat(after).containsExactly("2", "3", "4", "5", "6", "7", "8", "9", "1");
}

}

0 comments on commit a3b9053

Please sign in to comment.