Spark 3.5: Fix NotSerializableException when migrating Spark tables (#11157) (#11605)

Co-authored-by: Manu Zhang <[email protected]>
bryanck and manuzhang authored Nov 20, 2024
1 parent eaea286 commit 9e6595f
Showing 7 changed files with 177 additions and 4 deletions.
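
For context, here is a minimal, self-contained sketch of the failure mode this commit fixes (illustrative only, not part of the commit): a ThreadPoolExecutor is not Serializable, so a task closure that captures an eagerly created pool fails as soon as Spark tries to serialize it for shipping to executors.

import java.io.ByteArrayOutputStream;
import java.io.NotSerializableException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class NotSerializableDemo {
  public static void main(String[] args) throws Exception {
    // Executors.newFixedThreadPool returns a ThreadPoolExecutor, which is not Serializable.
    ExecutorService pool = Executors.newFixedThreadPool(2);
    // A serializable lambda that captures the pool, similar to a Spark task closure.
    Runnable task = (Runnable & Serializable) () -> pool.execute(() -> {});
    try (ObjectOutputStream out = new ObjectOutputStream(new ByteArrayOutputStream())) {
      out.writeObject(task); // throws java.io.NotSerializableException: ThreadPoolExecutor
    } catch (NotSerializableException e) {
      System.out.println("fails as expected: " + e);
    } finally {
      pool.shutdownNow();
    }
  }
}

The changes below avoid this by passing a small Serializable proxy instead and deferring thread pool creation until the service is first used.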
TableMigrationUtil.java
@@ -25,6 +25,7 @@
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
@@ -263,6 +264,7 @@ private static DataFile buildDataFile(
* <p><b>Important:</b> Callers are responsible for shutting down the returned executor service
* when it is no longer needed to prevent resource leaks.
*/
@Nullable
public static ExecutorService migrationService(int parallelism) {
return parallelism == 1 ? null : ThreadPools.newFixedThreadPool("table-migration", parallelism);
}
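
A hedged usage sketch of the contract documented above (the class and method names here are illustrative, and it assumes TableMigrationUtil lives in org.apache.iceberg.data): a null return means parallelism == 1 and work runs on the calling thread; any non-null pool must be shut down by the caller.

import java.util.concurrent.ExecutorService;
import org.apache.iceberg.data.TableMigrationUtil;

class MigrationServiceUsage {
  static void runMigration(int parallelism) {
    // Null signals "no pool": callers fall back to doing the work inline.
    ExecutorService service = TableMigrationUtil.migrationService(parallelism);
    try {
      // ... submit file listing and data-file building work to `service` if non-null ...
    } finally {
      if (service != null) {
        service.shutdown(); // the caller owns the pool and must release it
      }
    }
  }
}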
TestAddFilesProcedure.java
@@ -948,6 +948,26 @@ public void testAddFilesWithParallelism() {
sql("SELECT * FROM %s ORDER BY id", tableName));
}

@TestTemplate
public void testAddFilesPartitionedWithParallelism() {
createPartitionedHiveTable();

createIcebergTable(
"id Integer, name String, dept String, subdept String", "PARTITIONED BY (id)");

List<Object[]> result =
sql(
"CALL %s.system.add_files(table => '%s', source_table => '%s', parallelism => 2)",
catalogName, tableName, sourceTableName);

assertOutput(result, 8L, 4L);

assertEquals(
"Iceberg table contains correct data",
sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName),
sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
}

private static final List<Object[]> EMPTY_QUERY_RESULT = Lists.newArrayList();

private static final StructField[] STRUCT = {
TestMigrateTableProcedure.java
@@ -273,4 +273,22 @@ public void testMigrateWithInvalidParallelism() throws IOException {
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Parallelism should be larger than 0");
}

@TestTemplate
public void testMigratePartitionedWithParallelism() throws IOException {
assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog");

String location = Files.createTempDirectory(temp, "junit").toFile().toString();
sql(
"CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet PARTITIONED BY (id) LOCATION '%s'",
tableName, location);
sql("INSERT INTO TABLE %s (id, data) VALUES (1, 'a'), (2, 'b')", tableName);
List<Object[]> result =
sql("CALL %s.system.migrate(table => '%s', parallelism => %d)", catalogName, tableName, 2);
assertEquals("Procedure output must match", ImmutableList.of(row(2L)), result);
assertEquals(
"Should have expected rows",
ImmutableList.of(row("a", 1L), row("b", 2L)),
sql("SELECT * FROM %s ORDER BY id", tableName));
}
}
TestSnapshotTableProcedure.java
@@ -263,4 +263,22 @@ public void testSnapshotWithInvalidParallelism() throws IOException {
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Parallelism should be larger than 0");
}

@TestTemplate
public void testSnapshotPartitionedWithParallelism() throws IOException {
String location = Files.createTempDirectory(temp, "junit").toFile().toString();
sql(
"CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet PARTITIONED BY (id) LOCATION '%s'",
SOURCE_NAME, location);
sql("INSERT INTO TABLE %s (id, data) VALUES (1, 'a'), (2, 'b')", SOURCE_NAME);
List<Object[]> result =
sql(
"CALL %s.system.snapshot(source_table => '%s', table => '%s', parallelism => %d)",
catalogName, SOURCE_NAME, tableName, 2);
assertEquals("Procedure output must match", ImmutableList.of(row(2L)), result);
assertEquals(
"Should have expected rows",
ImmutableList.of(row("a", 1L), row("b", 2L)),
sql("SELECT * FROM %s ORDER BY id", tableName));
}
}
SparkTableUtil.java
@@ -23,12 +23,18 @@
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
@@ -92,6 +98,8 @@
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import scala.Function2;
import scala.Option;
import scala.Some;
@@ -487,7 +495,7 @@ public static void importSparkTable(
stagingDir,
partitionFilter,
checkDuplicateFiles,
-        TableMigrationUtil.migrationService(parallelism));
+        migrationService(parallelism));
}

/**
@@ -711,7 +719,7 @@ public static void importSparkPartitions(
spec,
stagingDir,
checkDuplicateFiles,
-        TableMigrationUtil.migrationService(parallelism));
+        migrationService(parallelism));
}

/**
@@ -971,4 +979,109 @@ public int hashCode() {
return Objects.hashCode(values, uri, format);
}
}

@Nullable
public static ExecutorService migrationService(int parallelism) {
return parallelism == 1 ? null : new LazyExecutorService(parallelism);
}

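/**
 * A Serializable ExecutorService proxy that creates its thread pool lazily, on first use, so
 * instances can be captured in Spark task closures and Java-serialized before any
 * (non-serializable) threads exist.
 */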
private static class LazyExecutorService implements ExecutorService, Serializable {

private final int parallelism;
private volatile ExecutorService service;

LazyExecutorService(int parallelism) {
this.parallelism = parallelism;
}

@Override
public void shutdown() {
getService().shutdown();
}

@NotNull
@Override
public List<Runnable> shutdownNow() {
return getService().shutdownNow();
}

@Override
public boolean isShutdown() {
return getService().isShutdown();
}

@Override
public boolean isTerminated() {
return getService().isTerminated();
}

@Override
public boolean awaitTermination(long timeout, @NotNull TimeUnit unit)
throws InterruptedException {
return getService().awaitTermination(timeout, unit);
}

@NotNull
@Override
public <T> Future<T> submit(@NotNull Callable<T> task) {
return getService().submit(task);
}

@NotNull
@Override
public <T> Future<T> submit(@NotNull Runnable task, T result) {
return getService().submit(task, result);
}

@NotNull
@Override
public Future<?> submit(@NotNull Runnable task) {
return getService().submit(task);
}

@NotNull
@Override
public <T> List<Future<T>> invokeAll(@NotNull Collection<? extends Callable<T>> tasks)
throws InterruptedException {
return getService().invokeAll(tasks);
}

@NotNull
@Override
public <T> List<Future<T>> invokeAll(
@NotNull Collection<? extends Callable<T>> tasks, long timeout, @NotNull TimeUnit unit)
throws InterruptedException {
return getService().invokeAll(tasks, timeout, unit);
}

@NotNull
@Override
public <T> T invokeAny(@NotNull Collection<? extends Callable<T>> tasks)
throws InterruptedException, ExecutionException {
return getService().invokeAny(tasks);
}

@Override
public <T> T invokeAny(
@NotNull Collection<? extends Callable<T>> tasks, long timeout, @NotNull TimeUnit unit)
throws InterruptedException, ExecutionException, TimeoutException {
return getService().invokeAny(tasks, timeout, unit);
}

@Override
public void execute(@NotNull Runnable command) {
getService().execute(command);
}

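// Double-checked locking: the volatile field makes the unsynchronized first read safe,
// and the underlying pool is created at most once, on first use.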
private ExecutorService getService() {
if (service == null) {
synchronized (this) {
if (service == null) {
service = TableMigrationUtil.migrationService(parallelism);
}
}
}
return service;
}
}
}
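
A hedged round-trip sketch of the new behavior (the class name LazyServiceRoundTrip is illustrative; it relies only on the public migrationService method shown above): the lazy proxy serializes cleanly as long as its pool has not been created yet, and builds the pool on first use after deserialization.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.concurrent.ExecutorService;
import org.apache.iceberg.spark.SparkTableUtil;

class LazyServiceRoundTrip {
  public static void main(String[] args) throws Exception {
    ExecutorService lazy = SparkTableUtil.migrationService(4);

    // Serialize before first use, as Spark would when shipping a closure to executors.
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
      out.writeObject(lazy); // succeeds: no ThreadPoolExecutor has been created yet
    }

    // Deserialize and use it; the real pool is created lazily at this point.
    try (ObjectInputStream in =
        new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
      ExecutorService revived = (ExecutorService) in.readObject();
      revived.execute(() -> System.out.println("pool created on first use"));
      revived.shutdown();
    }
  }
}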
MigrateTableProcedure.java
@@ -22,6 +22,7 @@
import org.apache.iceberg.actions.MigrateTable;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.spark.SparkTableUtil;
import org.apache.iceberg.spark.actions.MigrateTableSparkAction;
import org.apache.iceberg.spark.actions.SparkActions;
import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder;
@@ -110,7 +111,7 @@ public InternalRow[] call(InternalRow args) {
int parallelism = args.getInt(4);
Preconditions.checkArgument(parallelism > 0, "Parallelism should be larger than 0");
migrateTableSparkAction =
-          migrateTableSparkAction.executeWith(executorService(parallelism, "table-migration"));
+          migrateTableSparkAction.executeWith(SparkTableUtil.migrationService(parallelism));
}

MigrateTable.Result result = migrateTableSparkAction.execute();
SnapshotTableProcedure.java
@@ -22,6 +22,7 @@
import org.apache.iceberg.actions.SnapshotTable;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.spark.SparkTableUtil;
import org.apache.iceberg.spark.actions.SparkActions;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.TableCatalog;
@@ -106,7 +107,7 @@ public InternalRow[] call(InternalRow args) {
if (!args.isNullAt(4)) {
int parallelism = args.getInt(4);
Preconditions.checkArgument(parallelism > 0, "Parallelism should be larger than 0");
-      action = action.executeWith(executorService(parallelism, "table-snapshot"));
+      action = action.executeWith(SparkTableUtil.migrationService(parallelism));
}

SnapshotTable.Result result = action.tableProperties(properties).execute();