Add explicit schema support to JdbcIO read and xlang transform. #34128

Merged (4 commits) on Mar 3, 2025
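For orientation before the diff: a minimal usage sketch of the option this PR adds. It is not code taken from the PR; the data source settings, query, table, and field names below are placeholder assumptions, and only the withSchema(...) call is the new surface being introduced.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.jdbc.JdbcIO;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;

// Explicit Beam schema for the query result; without withSchema(...) the schema
// would be inferred by opening a JDBC connection at pipeline-construction time.
Schema customerSchema =
    Schema.of(
        Schema.Field.of("CUSTOMER_NAME", Schema.FieldType.STRING).withNullable(true),
        Schema.Field.of("CUSTOMER_ID", Schema.FieldType.INT64).withNullable(true));

Pipeline pipeline = Pipeline.create();

PCollection<Row> rows =
    pipeline.apply(
        JdbcIO.readRows()
            .withDataSourceConfiguration(
                JdbcIO.DataSourceConfiguration.create(
                    "org.postgresql.Driver",                       // placeholder driver
                    "jdbc:postgresql://localhost:5432/testdb"))    // placeholder URL
            .withQuery("SELECT name, id FROM customers")           // placeholder query
            .withSchema(customerSchema));                          // new in this PR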
@@ -745,6 +745,9 @@ public abstract static class ReadRows extends PTransform<PBegin, PCollection<Row
@Pure
abstract boolean getDisableAutoCommit();

@Pure
abstract @Nullable Schema getSchema();

abstract Builder toBuilder();

@AutoValue.Builder
@@ -762,6 +765,8 @@ abstract Builder setDataSourceProviderFn(

abstract Builder setDisableAutoCommit(boolean disableAutoCommit);

abstract Builder setSchema(@Nullable Schema schema);

abstract ReadRows build();
}

@@ -789,6 +794,10 @@ public ReadRows withStatementPreparator(StatementPreparator statementPreparator)
return toBuilder().setStatementPreparator(statementPreparator).build();
}

public ReadRows withSchema(Schema schema) {
return toBuilder().setSchema(schema).build();
}

/**
* This method is used to set the size of the data that is going to be fetched and loaded in
* memory per every database call. Please refer to: {@link java.sql.Statement#setFetchSize(int)}
@@ -830,7 +839,14 @@ public PCollection<Row> expand(PBegin input) {
getDataSourceProviderFn(),
"withDataSourceConfiguration() or withDataSourceProviderFn() is required");

Schema schema = inferBeamSchema(dataSourceProviderFn.apply(null), query.get());
// Don't infer schema if explicitly provided.
Schema schema;
if (getSchema() != null) {
schema = getSchema();
} else {
schema = inferBeamSchema(dataSourceProviderFn.apply(null), query.get());
}

PCollection<Row> rows =
input.apply(
JdbcIO.<Row>read()
@@ -1292,6 +1308,9 @@ public abstract static class ReadWithPartitions<T, PartitionColumnT>
@Pure
abstract boolean getUseBeamSchema();

@Pure
abstract @Nullable Schema getSchema();

@Pure
abstract @Nullable PartitionColumnT getLowerBound();

@@ -1333,6 +1352,8 @@ abstract Builder<T, PartitionColumnT> setDataSourceProviderFn(

abstract Builder<T, PartitionColumnT> setUseBeamSchema(boolean useBeamSchema);

abstract Builder setSchema(@Nullable Schema schema);

abstract Builder<T, PartitionColumnT> setFetchSize(int fetchSize);

abstract Builder<T, PartitionColumnT> setTable(String tableName);
@@ -1424,6 +1445,10 @@ public ReadWithPartitions<T, PartitionColumnT> withTable(String tableName) {
return toBuilder().setTable(tableName).build();
}

public ReadWithPartitions<T, PartitionColumnT> withSchema(Schema schema) {
return toBuilder().setSchema(schema).build();
}

private static final int EQUAL = 0;

@Override
@@ -1532,8 +1557,11 @@ public KV<Long, KV<PartitionColumnT, PartitionColumnT>> apply(
Schema schema = null;
if (getUseBeamSchema()) {
schema =
ReadRows.inferBeamSchema(
dataSourceProviderFn.apply(null), String.format("SELECT * FROM %s", getTable()));
getSchema() != null
? getSchema()
: ReadRows.inferBeamSchema(
dataSourceProviderFn.apply(null),
String.format("SELECT * FROM %s", getTable()));
rowMapper = (RowMapper<T>) SchemaUtil.BeamRowMapper.of(schema);
} else {
rowMapper = getRowMapper();
@@ -84,7 +84,7 @@ public Schema configurationSchema() {
*/
@Override
public JdbcSchemaIO from(String location, Row configuration, @Nullable Schema dataSchema) {
return new JdbcSchemaIO(location, configuration);
return new JdbcSchemaIO(location, configuration, dataSchema);
Contributor: I see, previously the dataSchema parameter wasn't passed downstream.

Collaborator (Author): Yeah, and in JdbcIO#readRows and readWithPartitions the schema case wasn't handled.

I also need to add the ability to explicitly pass the partition lower and upper bounds through xlang; otherwise that part will still happen at pipeline construction time. Will do in a follow-up later this week.
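To make the construction-time point concrete, a hedged sketch of the SchemaIO path under discussion, assuming the same imports and pipeline as the snippet near the top plus org.apache.beam.sdk.io.jdbc.JdbcSchemaIOProvider; the connection values, query, and field names are placeholders. The only point is that the third argument to from(...) now reaches ReadRows.withSchema(...), so the schema no longer has to be inferred over a live JDBC connection when the transform is expanded (partition bounds, as noted above, still are for now).

JdbcSchemaIOProvider provider = new JdbcSchemaIOProvider();

// Configuration row matching the provider's configuration schema.
Row config =
    Row.withSchema(provider.configurationSchema())
        .withFieldValue("driverClassName", "org.postgresql.Driver")           // placeholder
        .withFieldValue("jdbcUrl", "jdbc:postgresql://localhost:5432/testdb") // placeholder
        .withFieldValue("username", "")
        .withFieldValue("password", "")
        .withFieldValue("readQuery", "SELECT name, id FROM customers")        // placeholder
        .build();

// Explicit schema handed to the provider; previously this argument was dropped.
Schema dataSchema =
    Schema.of(
        Schema.Field.of("CUSTOMER_NAME", Schema.FieldType.STRING).withNullable(true),
        Schema.Field.of("CUSTOMER_ID", Schema.FieldType.INT32).withNullable(true));

PCollection<Row> rows =
    pipeline.apply(provider.from("customers", config, dataSchema).buildReader());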

}

@Override
@@ -101,10 +101,12 @@ public PCollection.IsBounded isBounded() {
static class JdbcSchemaIO implements SchemaIO, Serializable {
protected final Row config;
protected final String location;
protected final @Nullable Schema dataSchema;

JdbcSchemaIO(String location, Row config) {
JdbcSchemaIO(String location, Row config, @Nullable Schema dataSchema) {
this.config = config;
this.location = location;
this.dataSchema = dataSchema;
}

@Override
@@ -147,6 +149,10 @@ public PCollection<Row> expand(PBegin input) {
readRows = readRows.withDisableAutoCommit(disableAutoCommit);
}

if (dataSchema != null) {
readRows = readRows.withSchema(dataSchema);
}

return input.apply(readRows);
} else {

@@ -175,6 +181,9 @@ public PCollection<Row> expand(PBegin input) {
readRows = readRows.withDisableAutoCommit(disableAutoCommit);
}

if (dataSchema != null) {
readRows = readRows.withSchema(dataSchema);
}
return input.apply(readRows);
}
}
@@ -470,6 +470,51 @@ public void testReadWithSchema() {
pipeline.run();
}

@Test
public void testReadRowsWithExplicitSchema() {
Schema customSchema =
Schema.of(
Schema.Field.of("CUSTOMER_NAME", Schema.FieldType.STRING).withNullable(true),
Schema.Field.of("CUSTOMER_ID", Schema.FieldType.INT64).withNullable(true));

PCollection<Row> rows =
pipeline.apply(
JdbcIO.readRows()
.withDataSourceConfiguration(DATA_SOURCE_CONFIGURATION)
.withQuery(String.format("select name,id from %s where name = ?", READ_TABLE_NAME))
.withStatementPreparator(
preparedStatement -> preparedStatement.setString(1, TestRow.getNameForSeed(1)))
.withSchema(customSchema));

assertEquals(customSchema, rows.getSchema());

PCollection<Row> output = rows.apply(Select.fieldNames("CUSTOMER_NAME", "CUSTOMER_ID"));
PAssert.that(output)
.containsInAnyOrder(
ImmutableList.of(Row.withSchema(customSchema).addValues("Testval1", 1L).build()));

pipeline.run();
}

@Test
@SuppressWarnings({"UnusedVariable"})
public void testIncompatibleSchemaThrowsError() {
Schema incompatibleSchema =
Schema.of(
Schema.Field.of("WRONG_TYPE_NAME", Schema.FieldType.INT64),
Schema.Field.of("WRONG_TYPE_ID", Schema.FieldType.STRING));

Pipeline pipeline = Pipeline.create();
pipeline.apply(
JdbcIO.readRows()
.withDataSourceConfiguration(DATA_SOURCE_CONFIGURATION)
.withQuery(String.format("select name,id from %s limit 10", READ_TABLE_NAME))
.withSchema(incompatibleSchema));

PipelineExecutionException exception =
assertThrows(PipelineExecutionException.class, () -> pipeline.run().waitUntilFinish());
}

@Test
public void testReadWithPartitions() {
PCollection<TestRow> rows =
@@ -486,6 +531,32 @@ public void testReadWithPartitions() {
pipeline.run();
}

@Test
public void testReadWithPartitionsWithExplicitSchema() {
Schema customSchema =
Schema.of(
Schema.Field.of("CUSTOMER_NAME", Schema.FieldType.STRING).withNullable(true),
Schema.Field.of("CUSTOMER_ID", Schema.FieldType.INT32).withNullable(true));

PCollection<Row> rows =
pipeline.apply(
JdbcIO.<Row>readWithPartitions()
.withDataSourceConfiguration(DATA_SOURCE_CONFIGURATION)
.withTable(String.format("(select name,id from %s) as subq", READ_TABLE_NAME))
.withNumPartitions(5)
.withPartitionColumn("id")
.withLowerBound(0L)
.withUpperBound(1000L)
.withRowOutput()
.withSchema(customSchema));

assertEquals(customSchema, rows.getSchema());

PAssert.thatSingleton(rows.apply("Count All", Count.globally())).isEqualTo(1000L);

pipeline.run();
}

@Test
public void testReadWithPartitionsBySubqery() {
PCollection<TestRow> rows =
@@ -18,6 +18,7 @@
package org.apache.beam.sdk.io.jdbc;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

import java.sql.Connection;
import java.sql.PreparedStatement;
@@ -85,6 +86,92 @@ public void testPartitionedRead() {
pipeline.run();
}

@Test
public void testPartitionedReadWithExplicitSchema() {
JdbcSchemaIOProvider provider = new JdbcSchemaIOProvider();

Schema customSchema =
Schema.of(
Schema.Field.of("CUSTOMER_NAME", Schema.FieldType.STRING).withNullable(true),
Schema.Field.of("CUSTOMER_ID", Schema.FieldType.INT32).withNullable(true));

Row config =
Row.withSchema(provider.configurationSchema())
.withFieldValue("driverClassName", DATA_SOURCE_CONFIGURATION.getDriverClassName().get())
.withFieldValue("jdbcUrl", DATA_SOURCE_CONFIGURATION.getUrl().get())
.withFieldValue("username", "")
.withFieldValue("password", "")
.withFieldValue("partitionColumn", "id")
.withFieldValue("partitions", (short) 10)
.build();

JdbcSchemaIOProvider.JdbcSchemaIO schemaIO =
provider.from(
String.format("(select name,id from %s) as subq", READ_TABLE_NAME),
config,
customSchema);

PCollection<Row> output = pipeline.apply(schemaIO.buildReader());

assertEquals(customSchema, output.getSchema());

Long expected = Long.valueOf(EXPECTED_ROW_COUNT);
PAssert.that(output.apply(Count.globally())).containsInAnyOrder(expected);

PAssert.that(output)
.satisfies(
rows -> {
for (Row row : rows) {
assertNotNull(row.getString("CUSTOMER_NAME"));
assertNotNull(row.getInt32("CUSTOMER_ID"));
}
return null;
});

pipeline.run();
}

@Test
public void testReadWithExplicitSchema() {
JdbcSchemaIOProvider provider = new JdbcSchemaIOProvider();

Schema customSchema =
Schema.of(
Schema.Field.of("CUSTOMER_NAME", Schema.FieldType.STRING).withNullable(true),
Schema.Field.of("CUSTOMER_ID", Schema.FieldType.INT32).withNullable(true));

Row config =
Row.withSchema(provider.configurationSchema())
.withFieldValue("driverClassName", DATA_SOURCE_CONFIGURATION.getDriverClassName().get())
.withFieldValue("jdbcUrl", DATA_SOURCE_CONFIGURATION.getUrl().get())
.withFieldValue("username", "")
.withFieldValue("password", "")
.withFieldValue("readQuery", "SELECT name, id FROM " + READ_TABLE_NAME)
.build();

JdbcSchemaIOProvider.JdbcSchemaIO schemaIO =
provider.from(READ_TABLE_NAME, config, customSchema);

PCollection<Row> output = pipeline.apply(schemaIO.buildReader());

assertEquals(customSchema, output.getSchema());

Long expected = Long.valueOf(EXPECTED_ROW_COUNT);
PAssert.that(output.apply(Count.globally())).containsInAnyOrder(expected);

PAssert.that(output)
.satisfies(
rows -> {
for (Row row : rows) {
assertNotNull(row.getString("CUSTOMER_NAME"));
assertNotNull(row.getInt32("CUSTOMER_ID"));
}
return null;
});

pipeline.run();
}

// This test shouldn't work because we only support numeric and datetime columns and we are trying
// to use a string column as our partition source.
@Test