Skip to content

Commit

Permalink
added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vga91 committed Jan 16, 2025
1 parent 23fe06b commit 49e5315
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 199 deletions.
91 changes: 40 additions & 51 deletions extended-it/src/test/java/apoc/load/MySQLJdbcTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public static class MySQLJdbcLatestVersionTest {
public static void setUpContainer() {
mysql.start();
TestUtil.registerProcedure(db, Jdbc.class, Analytics.class);
String movies = Util.readResourceFile(MOVIES_CYPHER_FILE);
String movies = Util.readResourceFile(ANALYTICS_CYPHER_FILE);
try (Transaction tx = db.beginTx()) {
tx.execute(movies);
tx.commit();
Expand All @@ -68,42 +68,42 @@ public void testIssue3496() {

@Test
public void testLoadJdbcAnalytics() {
String cypher = "MATCH (n:Movie) RETURN n.title AS title, n.released AS released, n.language AS language, n.tagline AS tagline";
String cypher = "MATCH (n:City) RETURN n.country AS country, n.name AS name, n.year AS year, n.population AS population";

String sql = """
SELECT
title,
released,
language,
tagline,
RANK() OVER (PARTITION BY language ORDER BY released DESC) AS 'rank'
FROM temp_table
ORDER BY title, tagline
""";
country,
name,
year,
population,
RANK() OVER (PARTITION BY country ORDER BY year DESC) AS 'rank'
FROM %s
ORDER BY country, name;
""".formatted(Analytics.TABLE_NAME_DEFAULT_CONF_KEY);
testResult(db, "CALL apoc.load.jdbc.analytics($queryCypher, $url, $sql, [], $config)",
map(
"queryCypher", cypher,
"sql", sql,
"url", mysql.getJdbcUrl(),
"config", map(PROVIDER_CONF_KEY, Analytics.Provider.MYSQL.name())
),
r -> commonAnalyticsAssertions(r, "1"));
r -> commonAnalyticsAssertions(r, "1", "3", "5"));
}

@Test
public void testLoadJdbcAnalyticsWindow() {
String cypher = "MATCH (n:Movie) RETURN n.title AS title, n.released AS released, n.language AS language, n.tagline AS tagline, n.qty AS qty";
String cypher = "MATCH (n:City) RETURN n.country AS country, n.name AS name, n.year AS year, n.population AS population";

String sql = """
SELECT
title,
released,
language,
tagline,
ROW_NUMBER() OVER (PARTITION BY language ORDER BY released DESC) AS 'rank'
FROM temp_table
ORDER BY title, tagline
""";
country,
name,
year,
population,
ROW_NUMBER() OVER (PARTITION BY country ORDER BY year DESC) AS 'rank'
FROM %s
ORDER BY country, name;
""".formatted(Analytics.TABLE_NAME_DEFAULT_CONF_KEY);

testResult(db, "CALL apoc.load.jdbc.analytics($queryCypher, $url, $sql, [], $config)",
map(
Expand All @@ -112,39 +112,28 @@ public void testLoadJdbcAnalyticsWindow() {
"url", mysql.getJdbcUrl(),
"config", map(PROVIDER_CONF_KEY, Analytics.Provider.MYSQL.name())
),
r -> commonAnalyticsAssertions(r, "2"));
r -> commonAnalyticsAssertions(r, "2", "4", "6"));
}

private static void commonAnalyticsAssertions(Result r, String expected4thResult) {
Map<String, Object> row = r.next();
var result = (Map) row.get("row");
var rank = (String) result.get("rank");
assertEquals("1", rank);

row = r.next();
result = (Map) row.get("row");
rank = (String) result.get("rank");
assertEquals("4", rank);

row = r.next();
result = (Map) row.get("row");
rank = (String) result.get("rank");
assertEquals("3", rank);

row = r.next();
result = (Map) row.get("row");
rank = (String) result.get("rank");
assertEquals("1", rank);

row = r.next();
result = (Map) row.get("row");
rank = (String) result.get("rank");
assertEquals(expected4thResult, rank);

row = r.next();
result = (Map) row.get("row");
rank = (String) result.get("rank");
assertEquals("2", rank);
private static void commonAnalyticsAssertions(Result r,
String expected4thResult, String expected5thResult, String expected6thResult) {
assertRowRank(r.next(), "1");

assertRowRank(r.next(), "2");

assertRowRank(r.next(), "3");

assertRowRank(r.next(), expected4thResult);

assertRowRank(r.next(), expected5thResult);

assertRowRank(r.next(), expected6thResult);

assertRowRank(r.next(), "1");

assertRowRank(r.next(), "3");

assertRowRank(r.next(), "5");

assertFalse(r.hasNext());
}
Expand Down
87 changes: 38 additions & 49 deletions extended-it/src/test/java/apoc/load/PostgresJdbcTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public static void setUp() throws Exception {
TestUtil.registerProcedure(db,Jdbc.class, Periodic.class, Strings.class, Analytics.class);
db.executeTransactionally("CALL apoc.load.driver('org.postgresql.Driver')");

String movies = Util.readResourceFile(MOVIES_CYPHER_FILE);
String movies = Util.readResourceFile(ANALYTICS_CYPHER_FILE);
try (Transaction tx = db.beginTx()) {
tx.execute(movies);
tx.commit();
Expand Down Expand Up @@ -145,18 +145,19 @@ public void testIssue4141PeriodicIterateWithJdbc() throws Exception {

@Test
public void testLoadJdbcAnalytics() {
String cypher = "MATCH (n:Movie) RETURN n.title AS title, n.released AS released, n.language AS language, n.tagline AS tagline";
String cypher = "MATCH (n:City) RETURN n.country AS country, n.name AS name, n.year AS year, n.population AS population";

String sql = """
SELECT
title,
released,
language,
tagline,
RANK() OVER (PARTITION BY language ORDER BY released DESC) rank
FROM temp_table
ORDER BY rank, title, tagline;
""";
SELECT
country,
name,
year,
population,
RANK() OVER (PARTITION BY country ORDER BY year DESC) rank
FROM %s
ORDER BY rank, country, name;
""".formatted(Analytics.TABLE_NAME_DEFAULT_CONF_KEY);

testResult(db, "CALL apoc.load.jdbc.analytics($queryCypher, $url, $sql, [], $config)",
map(
"queryCypher", cypher,
Expand All @@ -169,18 +170,18 @@ public void testLoadJdbcAnalytics() {

@Test
public void testLoadJdbcAnalyticsWindow() {
String cypher = "MATCH (n:Movie) RETURN n.title AS title, n.released AS released, n.language AS language, n.tagline AS tagline, n.qty AS qty";
String cypher = "MATCH (n:City) RETURN n.country AS country, n.name AS name, n.year AS year, n.population AS population";

String sql = """
SELECT
title,
released,
language,
tagline,
ROW_NUMBER() OVER (PARTITION BY language ORDER BY released DESC) rank
FROM temp_table
ORDER BY rank, title, tagline
""";
country,
name,
year,
population,
ROW_NUMBER() OVER (PARTITION BY country ORDER BY year DESC) rank
FROM %s
ORDER BY rank, country, name;
""".formatted(Analytics.TABLE_NAME_DEFAULT_CONF_KEY);

testResult(db, "CALL apoc.load.jdbc.analytics($queryCypher, $url, $sql, [], $config)",
map(
Expand All @@ -193,35 +194,23 @@ public void testLoadJdbcAnalyticsWindow() {
}

private static void commonAnalyticsAssertions(Result r, int expected3rdResult) {
Map<String, Object> row = r.next();
var result = (Map) row.get("row");
var rank = (long) result.get("rank");
assertEquals(1, rank);

row = r.next();
result = (Map) row.get("row");
rank = (long) result.get("rank");
assertEquals(1, rank);

row = r.next();
result = (Map) row.get("row");
rank = (long) result.get("rank");
assertEquals(expected3rdResult, rank);

row = r.next();
result = (Map) row.get("row");
rank = (long) result.get("rank");
assertEquals(2, rank);

row = r.next();
result = (Map) row.get("row");
rank = (long) result.get("rank");
assertEquals(3, rank);

row = r.next();
result = (Map) row.get("row");
rank = (long) result.get("rank");
assertEquals(4, rank);
assertRowRank(r.next(), 1);

assertRowRank(r.next(), 1);

assertRowRank(r.next(), expected3rdResult);

assertRowRank(r.next(), 2);

assertRowRank(r.next(), 3);

assertRowRank(r.next(), 3);

assertRowRank(r.next(), 4);

assertRowRank(r.next(), 5);

assertRowRank(r.next(), 6);

assertFalse(r.hasNext());
}
Expand Down
42 changes: 27 additions & 15 deletions extended/src/main/java/apoc/load/Analytics.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@

@Extended
public class Analytics {

public static final String PROVIDER_CONF_KEY = "provider";
public static final String TABLE_NAME_CONF_KEY = "tableName";
public static final String TABLE_NAME_DEFAULT_CONF_KEY = "neo4j_tmp_table";

enum Provider {
POSTGRES,
Expand All @@ -52,9 +53,10 @@ public Stream<RowResult> aggregate(
@Name("jdbc") String urlOrKey,
@Name("sqlQuery") String sqlQuery,
@Name(value = "params", defaultValue = "[]") List<Object> params,
@Name(value = "config",defaultValue = "{}") Map<String, Object> config) throws Exception {
@Name(value = "config",defaultValue = "{}") Map<String, Object> config) {
AtomicReference<String> createTable = new AtomicReference<>();
final Provider provider = Provider.valueOf((String) config.getOrDefault(PROVIDER_CONF_KEY, Provider.DUCKDB.name()));
final String tableName = (String) config.getOrDefault(TABLE_NAME_CONF_KEY, TABLE_NAME_DEFAULT_CONF_KEY);

AtomicReference<String> columns = new AtomicReference<>();
AtomicReference<String> queryInsert = new AtomicReference<>();
Expand All @@ -66,7 +68,7 @@ public Stream<RowResult> aggregate(
result.forEachRemaining(map -> {

if (createTable.get() == null) {
String tempTableClause = getTempTableClause(map, provider);
String tempTableClause = getTempTableClause(map, provider, tableName);
createTable.set(tempTableClause);
}

Expand All @@ -82,7 +84,11 @@ public Stream<RowResult> aggregate(
});

// add values to `INSERT INTO ...` clause
queryInsert.set("INSERT INTO temp_table VALUES " + StringUtils.join(sqlValuesForQueryInsert, ","));
String sqlValues = StringUtils.join(sqlValuesForQueryInsert, ",");
String insertClause = String.format("INSERT INTO %s VALUES %s",
tableName, sqlValues
);
queryInsert.set(insertClause);

// columns to handle error msg
String neo4jResultColumns = result.columns().stream()
Expand All @@ -94,34 +100,40 @@ public Stream<RowResult> aggregate(

String url = getUrlOrKey(urlOrKey);
LoadJdbcConfig jdbcConfig = new LoadJdbcConfig(config);
Connection connection = (Connection) getConnection(url, jdbcConfig, Connection.class);
Connection connection;
try {
connection = (Connection) getConnection(url, jdbcConfig, Connection.class);
} catch (Exception e) {
throw new RuntimeException("Connection error", e);
}

Object[] paramsArray = params.toArray(new Object[params.size()]);

// Step 1. Create temporary table
executeUpdate(urlOrKey, createTable.get(), config, connection, log, paramsArray);

// Step 2. Insert data in temp table
executeUpdate(urlOrKey, queryInsert.get(), config, connection, log, paramsArray);

try {
// Step 1. Create temporary table
executeUpdate(urlOrKey, createTable.get(), config, connection, log, paramsArray);

// Step 2. Insert data in temp table
executeUpdate(urlOrKey, queryInsert.get(), config, connection, log, paramsArray);

// Step 3. Return data from temp table
return executeQuery(urlOrKey, sqlQuery, config, connection, log, paramsArray);
} catch (Exception e) {
throw new RuntimeException(String.format("Make sure the SQL is consistent with Cypher query which has columns: %s", columns.get()), e);
String checkColConsistency = String.format("Make sure the SQL is consistent with Cypher query which has columns: %s", columns.get());
throw new RuntimeException(checkColConsistency, e);
}
}

/**
* add fields to be added to the insert temp table clause
* e.g. `CREATE TEMPORARY TABLE temp_table (tagline VARCHAR, language VARCHAR, title VARCHAR, released INTEGER)`
* e.g. `CREATE TEMPORARY TABLE <tempTable> (tagline VARCHAR, language VARCHAR, title VARCHAR, released INTEGER)`
*/
private String getTempTableClause(Map<String, Object> map, Provider provider) {
private String getTempTableClause(Map<String, Object> map, Provider provider, String tableName) {
String sqlFields = getStreamSortedByKey(map)
.map(e -> e.getKey() + " " + mapSqlType(provider, e.getValue()))
.collect(Collectors.joining(","));

return "CREATE TEMPORARY TABLE temp_table (%s)".formatted(sqlFields);
return "CREATE TEMPORARY TABLE %s (%s)".formatted(tableName, sqlFields);
}

private static Stream<Map.Entry<String, Object>> getStreamSortedByKey(Map<String, Object> map) {
Expand Down
8 changes: 7 additions & 1 deletion extended/src/test/java/apoc/load/AbstractJdbcTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,17 @@ public abstract class AbstractJdbcTest {

protected static java.sql.Time time = java.sql.Time.valueOf("15:37:00");

protected static final String MOVIES_CYPHER_FILE = "movies-analytics.cypher";
protected static final String ANALYTICS_CYPHER_FILE = "movies-analytics.cypher";

public void assertResult(Map<String, Object> row) {
Map<String, Object> expected = Util.map("NAME", "John", "SURNAME", null, "HIRE_DATE", hireDate.toLocalDate(), "EFFECTIVE_FROM_DATE",
effectiveFromDate.toLocalDateTime(), "TEST_TIME", time.toLocalTime(), "NULL_DATE", null);
assertEquals(expected, row.get("row"));
}

protected static void assertRowRank(Map<String, Object> row, Object expected) {
var result = (Map) row.get("row");
Object rank = result.get("rank");
assertEquals(expected, rank);
}
}
Loading

0 comments on commit 49e5315

Please sign in to comment.