Skip to content

Commit

Permalink
Migrate tests to JUnit
Browse files Browse the repository at this point in the history
  • Loading branch information
martint committed Oct 5, 2023
1 parent 39fe9da commit 9a7cf10
Show file tree
Hide file tree
Showing 96 changed files with 777 additions and 549 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
import io.trino.spi.QueryId;
import io.trino.spiller.SpillSpaceTracker;
import io.trino.sql.planner.plan.PlanNodeId;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

import java.util.List;
import java.util.concurrent.ExecutorService;
Expand All @@ -48,11 +48,12 @@
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;

@Test(singleThreaded = true)
@TestInstance(PER_METHOD)
public class TestMemoryTracking
{
private static final DataSize queryMaxMemory = DataSize.of(1, GIGABYTE);
Expand All @@ -70,14 +71,7 @@ public class TestMemoryTracking
private ExecutorService notificationExecutor;
private ScheduledExecutorService yieldExecutor;

@BeforeClass
public void setUp()
{
notificationExecutor = newCachedThreadPool(daemonThreadsNamed("local-query-runner-executor-%s"));
yieldExecutor = newScheduledThreadPool(2, daemonThreadsNamed("local-query-runner-scheduler-%s"));
}

@AfterClass(alwaysRun = true)
@AfterEach
public void tearDown()
{
notificationExecutor.shutdownNow();
Expand All @@ -90,9 +84,12 @@ public void tearDown()
memoryPool = null;
}

@BeforeMethod
@BeforeEach
public void setUpTest()
{
notificationExecutor = newCachedThreadPool(daemonThreadsNamed("local-query-runner-executor-%s"));
yieldExecutor = newScheduledThreadPool(2, daemonThreadsNamed("local-query-runner-scheduler-%s"));

memoryPool = new MemoryPool(memoryPoolSize);
queryContext = new QueryContext(
new QueryId("test_query"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
import io.trino.sql.gen.JoinCompiler;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.MaterializedResult;
import org.testng.annotations.AfterClass;
import org.testng.annotations.Test;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

import java.util.List;
import java.util.Optional;
Expand All @@ -43,15 +44,17 @@
import static io.trino.testing.TestingTaskContext.createTaskContext;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
import static org.testng.Assert.assertEquals;

@TestInstance(PER_CLASS)
public class TestDistinctLimitOperator
{
private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%s"));
private final ScheduledExecutorService scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s"));
private final JoinCompiler joinCompiler = new JoinCompiler(new TypeOperators());

@AfterClass(alwaysRun = true)
@AfterAll
public void tearDown()
{
executor.shutdownNow();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@
import io.trino.spi.type.Type;
import io.trino.split.RemoteSplit;
import io.trino.sql.planner.plan.PlanNodeId;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -57,11 +57,12 @@
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.testing.TestingTaskContext.createTaskContext;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;

@Test(singleThreaded = true)
@TestInstance(PER_CLASS)
public class TestExchangeOperator
{
private static final List<Type> TYPES = ImmutableList.of(VARCHAR);
Expand All @@ -81,7 +82,7 @@ public class TestExchangeOperator
private ExecutorService pageBufferClientCallbackExecutor;

@SuppressWarnings("resource")
@BeforeClass
@BeforeAll
public void setUp()
{
scheduler = newScheduledThreadPool(4, daemonThreadsNamed(getClass().getSimpleName() + "-%s"));
Expand All @@ -104,7 +105,7 @@ public void setUp()
taskFailureListener);
}

@AfterClass(alwaysRun = true)
@AfterAll
public void tearDown()
{
httpClient.close();
Expand All @@ -120,17 +121,12 @@ public void tearDown()
pageBufferClientCallbackExecutor = null;
}

@BeforeMethod
public void setUpMethod()
{
// the test class is single-threaded, so there should be no ongoing loads and invalidation should be effective
taskBuffers.invalidateAll();
}

@Test
public void testSimple()
throws Exception
{
taskBuffers.invalidateAll();

SourceOperator operator = createExchangeOperator();

operator.addSplit(newRemoteSplit(TASK_1_ID));
Expand Down Expand Up @@ -159,6 +155,8 @@ private static Split newRemoteSplit(TaskId taskId)
public void testWaitForClose()
throws Exception
{
taskBuffers.invalidateAll();

SourceOperator operator = createExchangeOperator();

operator.addSplit(newRemoteSplit(TASK_1_ID));
Expand Down Expand Up @@ -195,6 +193,8 @@ public void testWaitForClose()
public void testWaitForNoMoreSplits()
throws Exception
{
taskBuffers.invalidateAll();

SourceOperator operator = createExchangeOperator();

// add a buffer location containing one page and close the buffer
Expand Down Expand Up @@ -228,6 +228,8 @@ public void testWaitForNoMoreSplits()
public void testFinish()
throws Exception
{
taskBuffers.invalidateAll();

SourceOperator operator = createExchangeOperator();

operator.addSplit(newRemoteSplit(TASK_1_ID));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import io.trino.spi.type.TypeOperators;
import io.trino.sql.gen.JoinCompiler;
import io.trino.testing.TestingSession;
import org.testng.annotations.Test;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.List;
Expand Down
37 changes: 8 additions & 29 deletions plugin/trino-ml/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -130,36 +130,15 @@
</dependency>

<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<dependencies>
<!-- allow both JUnit and TestNG -->
<dependency>
<groupId>org.apache.maven.surefire</groupId>
<artifactId>surefire-junit-platform</artifactId>
<version>${dep.plugin.surefire.version}</version>
</dependency>
<dependency>
<groupId>org.apache.maven.surefire</groupId>
<artifactId>surefire-testng</artifactId>
<version>${dep.plugin.surefire.version}</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<version>${dep.junit.version}</version>
</dependency>
</dependencies>
</plugin>
</plugins>
</build>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import org.testng.annotations.Test;
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.OptionalInt;
Expand All @@ -32,7 +32,7 @@
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE;
import static org.testng.Assert.assertEquals;
import static org.assertj.core.api.Assertions.assertThat;

public class TestEvaluateClassifierPredictions
{
Expand All @@ -49,8 +49,10 @@ public void testEvaluateClassifierPredictions()

String output = VARCHAR.getSlice(block, 0).toStringUtf8();
List<String> parts = ImmutableList.copyOf(Splitter.on('\n').omitEmptyStrings().split(output));
assertEquals(parts.size(), 7, output);
assertEquals(parts.get(0), "Accuracy: 1/2 (50.00%)");
assertThat(parts.size())
.describedAs(output)
.isEqualTo(7);
assertThat(parts.get(0)).isEqualTo("Accuracy: 1/2 (50.00%)");
}

private static Page getPage()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.Test;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import static io.trino.plugin.ml.TestUtils.getDataset;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
import static org.assertj.core.api.Assertions.assertThat;

public class TestFeatureTransformations
{
Expand All @@ -43,11 +42,11 @@ public void testUnitNormalizer()
}
}
// Make sure there is a feature that needs to be normalized
assertTrue(valueGreaterThanOne);
assertThat(valueGreaterThanOne).isTrue();
transformation.train(dataset);
for (FeatureVector vector : transformation.transform(dataset).getDatapoints()) {
for (double value : vector.getFeatures().values()) {
assertTrue(value <= 1);
assertThat(value <= 1).isTrue();
}
}
}
Expand All @@ -69,6 +68,6 @@ public void testUnitNormalizerSimple()
for (FeatureVector vector : transformation.transform(dataset).getDatapoints()) {
featureValues.addAll(vector.getFeatures().values());
}
assertEquals(featureValues, ImmutableSet.of(0.0, 0.5, 1.0));
assertThat(featureValues).isEqualTo(ImmutableSet.of(0.0, 0.5, 1.0));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.transaction.TransactionManager;
import org.testng.annotations.Test;
import org.junit.jupiter.api.Test;

import java.util.OptionalInt;
import java.util.Random;
Expand All @@ -45,8 +45,7 @@
import static io.trino.testing.StructuralTestUtil.mapBlockOf;
import static io.trino.transaction.InMemoryTransactionManager.createTestTransactionManager;
import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertTrue;
import static org.assertj.core.api.Assertions.assertThat;

public class TestLearnAggregations
{
Expand Down Expand Up @@ -90,8 +89,11 @@ private static void assertLearnClassifier(Aggregator aggregator)
Block block = finalOut.build();
Slice slice = aggregator.getType().getSlice(block, 0);
Model deserialized = ModelUtils.deserialize(slice);
assertNotNull(deserialized, "deserialization failed");
assertTrue(deserialized instanceof Classifier, "deserialized model is not a classifier");
assertThat(deserialized)
.describedAs("deserialization failed")
.isNotNull();

assertThat(deserialized).isInstanceOf(Classifier.class);
}

private static Page getPage()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import io.trino.testing.AbstractTestQueryFramework;
import io.trino.testing.LocalQueryRunner;
import io.trino.testing.QueryRunner;
import org.testng.annotations.Test;
import org.junit.jupiter.api.Test;

import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME;
import static io.trino.testing.TestingSession.testSessionBuilder;
Expand Down
Loading

0 comments on commit 9a7cf10

Please sign in to comment.