Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spark 3.5: Fix NotSerializableException when migrating Spark tables #11157

Merged
merged 2 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
Expand Down Expand Up @@ -263,6 +264,7 @@ private static DataFile buildDataFile(
* <p><b>Important:</b> Callers are responsible for shutting down the returned executor service
* when it is no longer needed to prevent resource leaks.
*/
@Nullable
public static ExecutorService migrationService(int parallelism) {
  // With a parallelism of one there is nothing to fan out: return null so the
  // caller performs the migration on the invoking thread.
  if (parallelism == 1) {
    return null;
  }
  // Fixed-size pool; per the Javadoc above, the caller must shut it down.
  return ThreadPools.newFixedThreadPool("table-migration", parallelism);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,26 @@ public void testAddFilesWithParallelism() {
sql("SELECT * FROM %s ORDER BY id", tableName));
}

@TestTemplate
public void testAddFilesPartitionedWithParallelism() {
  // Source: a partitioned Hive table; target: an Iceberg table partitioned by id.
  createPartitionedHiveTable();
  createIcebergTable(
      "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id)");

  // Run add_files with two worker threads and capture the procedure output.
  List<Object[]> output =
      sql(
          "CALL %s.system.add_files(table => '%s', source_table => '%s', parallelism => 2)",
          catalogName, tableName, sourceTableName);
  assertOutput(output, 8L, 4L);

  // The Iceberg table must now contain exactly the rows of the Hive source.
  assertEquals(
      "Iceberg table contains correct data",
      sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName),
      sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
}

private static final List<Object[]> EMPTY_QUERY_RESULT = Lists.newArrayList();

private static final StructField[] STRUCT = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,4 +273,22 @@ public void testMigrateWithInvalidParallelism() throws IOException {
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Parallelism should be larger than 0");
}

@TestTemplate
public void testMigratePartitionedWithParallelism() throws IOException {
assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog");

String location = Files.createTempDirectory(temp, "junit").toFile().toString();
sql(
"CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet PARTITIONED BY (id) LOCATION '%s'",
tableName, location);
sql("INSERT INTO TABLE %s (id, data) VALUES (1, 'a'), (2, 'b')", tableName);
List<Object[]> result =
sql("CALL %s.system.migrate(table => '%s', parallelism => %d)", catalogName, tableName, 2);
assertEquals("Procedure output must match", ImmutableList.of(row(2L)), result);
assertEquals(
"Should have expected rows",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick check do these tests fail without this patch? I just want to make sure because I'm pretty sure we are running this code in local mode and I want to make sure the serializers break without this patch.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

ImmutableList.of(row("a", 1L), row("b", 2L)),
sql("SELECT * FROM %s ORDER BY id", tableName));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -263,4 +263,22 @@ public void testSnapshotWithInvalidParallelism() throws IOException {
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Parallelism should be larger than 0");
}

@TestTemplate
public void testSnapshotPartitionedWithParallelism() throws IOException {
  // Create a partitioned parquet source table with two rows in a temp location.
  String location = Files.createTempDirectory(temp, "junit").toFile().toString();
  sql(
      "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet PARTITIONED BY (id) LOCATION '%s'",
      SOURCE_NAME, location);
  sql("INSERT INTO TABLE %s (id, data) VALUES (1, 'a'), (2, 'b')", SOURCE_NAME);
  // Snapshot into a new Iceberg table using two worker threads; the procedure
  // reports the number of imported files (one per partition here).
  List<Object[]> result =
      sql(
          "CALL %s.system.snapshot(source_table => '%s', table => '%s', parallelism => %d)",
          catalogName, SOURCE_NAME, tableName, 2);
  assertEquals("Procedure output must match", ImmutableList.of(row(2L)), result);
  // The snapshot table must contain exactly the source rows.
  assertEquals(
      "Should have expected rows",
      ImmutableList.of(row("a", 1L), row("b", 2L)),
      sql("SELECT * FROM %s ORDER BY id", tableName));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
Expand Down Expand Up @@ -92,6 +98,8 @@
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
manuzhang marked this conversation as resolved.
Show resolved Hide resolved
import scala.Function2;
import scala.Option;
import scala.Some;
Expand Down Expand Up @@ -487,7 +495,7 @@ public static void importSparkTable(
stagingDir,
partitionFilter,
checkDuplicateFiles,
TableMigrationUtil.migrationService(parallelism));
executorService(parallelism));
}

/**
Expand Down Expand Up @@ -711,7 +719,7 @@ public static void importSparkPartitions(
spec,
stagingDir,
checkDuplicateFiles,
TableMigrationUtil.migrationService(parallelism));
executorService(parallelism));
manuzhang marked this conversation as resolved.
Show resolved Hide resolved
}

/**
Expand Down Expand Up @@ -971,4 +979,109 @@ public int hashCode() {
return Objects.hashCode(values, uri, format);
}
}

@Nullable
public static ExecutorService executorService(int parallelism) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be public?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also this specifically makes ExecutorServices with the TableMigrationUtil.migrationService(parallelism); so we should probably indicate that in the name as well. (I think we have thread pool labeling in that method?)

return parallelism == 1 ? null : new ExecutorServiceFactory(parallelism);
}

private static class ExecutorServiceFactory implements ExecutorService, Serializable {
Copy link
Member

@RussellSpitzer RussellSpitzer Nov 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably needs a rename since this doesn't actually make Executor Services, probably just LazyExecutorService?


private final int parallelism;
private volatile ExecutorService service;

ExecutorServiceFactory(int parallelism) {
this.parallelism = parallelism;
}

@Override
public void shutdown() {
getService().shutdown();
}

@NotNull
@Override
public List<Runnable> shutdownNow() {
return getService().shutdownNow();
}

@Override
public boolean isShutdown() {
return getService().isShutdown();
}

@Override
public boolean isTerminated() {
return getService().isTerminated();
}

@Override
public boolean awaitTermination(long timeout, @NotNull TimeUnit unit)
throws InterruptedException {
return getService().awaitTermination(timeout, unit);
}

@NotNull
@Override
public <T> Future<T> submit(@NotNull Callable<T> task) {
return getService().submit(task);
}

@NotNull
@Override
public <T> Future<T> submit(@NotNull Runnable task, T result) {
return getService().submit(task, result);
}

@NotNull
@Override
public Future<?> submit(@NotNull Runnable task) {
return getService().submit(task);
}

@NotNull
@Override
public <T> List<Future<T>> invokeAll(@NotNull Collection<? extends Callable<T>> tasks)
throws InterruptedException {
return getService().invokeAll(tasks);
}

@NotNull
@Override
public <T> List<Future<T>> invokeAll(
@NotNull Collection<? extends Callable<T>> tasks, long timeout, @NotNull TimeUnit unit)
throws InterruptedException {
return getService().invokeAll(tasks, timeout, unit);
}

@NotNull
@Override
public <T> T invokeAny(@NotNull Collection<? extends Callable<T>> tasks)
throws InterruptedException, ExecutionException {
return getService().invokeAny(tasks);
}

@Override
public <T> T invokeAny(
@NotNull Collection<? extends Callable<T>> tasks, long timeout, @NotNull TimeUnit unit)
throws InterruptedException, ExecutionException, TimeoutException {
return getService().invokeAny(tasks, timeout, unit);
}

@Override
public void execute(@NotNull Runnable command) {
getService().execute(command);
}

private ExecutorService getService() {
if (service == null) {
synchronized (this) {
if (service == null) {
service = TableMigrationUtil.migrationService(parallelism);
}
}
}
return service;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.iceberg.actions.MigrateTable;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.spark.SparkTableUtil;
import org.apache.iceberg.spark.actions.MigrateTableSparkAction;
import org.apache.iceberg.spark.actions.SparkActions;
import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder;
Expand Down Expand Up @@ -110,7 +111,7 @@ public InternalRow[] call(InternalRow args) {
int parallelism = args.getInt(4);
Preconditions.checkArgument(parallelism > 0, "Parallelism should be larger than 0");
migrateTableSparkAction =
migrateTableSparkAction.executeWith(executorService(parallelism, "table-migration"));
migrateTableSparkAction.executeWith(SparkTableUtil.executorService(parallelism));
}

MigrateTable.Result result = migrateTableSparkAction.execute();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.iceberg.actions.SnapshotTable;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.spark.SparkTableUtil;
import org.apache.iceberg.spark.actions.SparkActions;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.TableCatalog;
Expand Down Expand Up @@ -106,7 +107,7 @@ public InternalRow[] call(InternalRow args) {
if (!args.isNullAt(4)) {
int parallelism = args.getInt(4);
Preconditions.checkArgument(parallelism > 0, "Parallelism should be larger than 0");
action = action.executeWith(executorService(parallelism, "table-snapshot"));
action = action.executeWith(SparkTableUtil.executorService(parallelism));
}

SnapshotTable.Result result = action.tableProperties(properties).execute();
Expand Down